#ifndef AMT_THREAD_HPP
#define AMT_THREAD_HPP

#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <optional>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace amt {

// NOTE(review): the #include targets and template argument lists of this
// header were lost in extraction and have been reconstructed from usage —
// confirm against the original before merging.

/// Mutex-protected FIFO queue.
///
/// Member functions without a suffix lock the internal mutex. The `*_unsafe`
/// variants read the container WITHOUT locking and are only safe when no other
/// thread can access the queue at the same time.
// NOTE: Could implement lock-free queue.
template <typename T>
struct Queue {
    using base_type = std::deque<T>;
    using value_type = typename base_type::value_type;
    using pointer = typename base_type::pointer;
    using const_pointer = typename base_type::const_pointer;
    using reference = typename base_type::reference;
    using const_reference = typename base_type::const_reference;
    using iterator = typename base_type::iterator;
    using const_iterator = typename base_type::const_iterator;
    using reverse_iterator = typename base_type::reverse_iterator;
    using const_reverse_iterator = typename base_type::const_reverse_iterator;
    using difference_type = typename base_type::difference_type;
    using size_type = typename base_type::size_type;

    Queue() = default;
    // The owned std::mutex is neither copyable nor movable, so the queue is
    // not either (the previously defaulted moves were implicitly deleted).
    Queue(Queue const&) = delete;
    Queue(Queue&&) = delete;
    Queue& operator=(Queue const&) = delete;
    Queue& operator=(Queue&&) = delete;
    ~Queue() = default;

    /// Appends `u` (copied or moved) to the back of the queue.
    /// Only participates in overload resolution when the cv/ref-stripped type
    /// of `U` is exactly `value_type`.
    template <typename U,
              typename = std::enable_if_t<
                  std::is_same_v<std::remove_cv_t<std::remove_reference_t<U>>, value_type>>>
    void push(U&& u) {
        std::lock_guard lock(m_mutex);
        m_data.push_back(std::forward<U>(u));
    }

    /// Constructs an element in place at the back of the queue.
    template <typename... Args>
    void emplace(Args&&... args) {
        std::lock_guard lock(m_mutex);
        m_data.emplace_back(std::forward<Args>(args)...);
    }

    /// Removes and returns the front element, or std::nullopt when empty.
    std::optional<value_type> pop() {
        std::lock_guard lock(m_mutex);
        if (m_data.empty()) return std::nullopt;
        std::optional<value_type> front{std::move(m_data.front())};
        m_data.pop_front();
        return front;
    }

    /// Number of queued elements (locks).
    auto size() const noexcept -> size_type {
        std::lock_guard lock(m_mutex);
        return m_data.size();
    }

    /// Whether the queue is empty (locks).
    auto empty() const noexcept -> bool {
        std::lock_guard lock(m_mutex);
        return m_data.empty();
    }

    /// Unlocked size; caller must guarantee no concurrent access.
    auto size_unsafe() const noexcept -> size_type { return m_data.size(); }

    /// Unlocked emptiness check; caller must guarantee no concurrent access.
    auto empty_unsafe() const noexcept -> bool { return m_data.empty(); }

private:
    base_type m_data;
    mutable std::mutex m_mutex;
};

template <typename Fn>
struct ThreadPool;

/// One pool thread together with its own task queue.
///
/// The thread sleeps until its own queue is non-empty (or it is stopped). It
/// then pops the front task; if that queue was drained in the meantime (e.g.
/// a sibling stole the task) it asks the parent pool for a task to steal.
template <typename Fn>
struct Worker {
    using parent_t = ThreadPool<Fn>*;
    using work_t = Fn;
    using size_type = std::size_t;

    Worker() = default;
    // The running thread captures `this`, so a worker must never be copied or
    // relocated. (The previously defaulted copies/moves were implicitly
    // deleted by the std::mutex / std::thread members anyway.)
    Worker(Worker const&) = delete;
    Worker(Worker&&) = delete;
    Worker& operator=(Worker const&) = delete;
    Worker& operator=(Worker&&) = delete;
    ~Worker() { stop(); }

    /// Launches the worker thread.
    /// @param pool Owning pool; used to report completions and to steal work.
    /// @param id   Index of this worker inside the pool.
    void start(parent_t pool, size_type id) {
        assert(!m_running.load(std::memory_order_acquire) && "Thread is already running");
        m_parent.store(pool);
        m_id = id;
        m_running.store(true);
        m_thread = std::thread([this] { run_loop(); });
    }

    /// Runs one task, then notifies the parent pool (if any) that it finished.
    /// Declared noexcept: a task that throws terminates the program.
    void process_work(work_t&& work) const noexcept {
        std::invoke(std::move(work));
        if (auto* pool = m_parent.load()) pool->task_completed();
    }

    /// Asks the thread to exit and joins it. Tasks still in the queue when the
    /// thread exits are not executed. Calling it on a stopped (or never
    /// started) worker is a no-op.
    void stop() {
        bool was_running = false;
        {
            // Flip the flag under the worker mutex so the thread cannot miss
            // the wake-up between checking its wait predicate and blocking.
            // exchange() makes concurrent stop() calls join at most once.
            std::lock_guard lock(m_mutex);
            was_running = m_running.exchange(false);
        }
        if (!was_running) return;
        m_cv.notify_all();
        m_thread.join();
        m_parent.store(nullptr);
    }

    /// Enqueues a task on this worker and wakes its thread.
    void add(work_t&& work) {
        {
            // Pushing under the worker mutex pairs with the predicate check
            // in run_loop(), so the notification cannot be lost.
            std::lock_guard lock(m_mutex);
            m_queue.push(std::move(work));
        }
        m_cv.notify_one();
    }

    /// Pops the front task of this worker's own queue (thread-safe).
    std::optional<work_t> pop_task() noexcept { return m_queue.pop(); }

    /// Tries to take a task from another worker of the parent pool.
    std::optional<work_t> try_steal() noexcept {
        if (auto* pool = m_parent.load()) return pool->try_steal(m_id);
        return std::nullopt;
    }

    /// Whether this worker's queue is empty (locks the queue; siblings may be
    /// stealing concurrently, so the unlocked variant would race).
    bool empty() const noexcept { return m_queue.empty(); }

    /// Number of tasks queued on this worker (locks the queue).
    size_type size() const noexcept { return m_queue.size(); }

    constexpr size_type id() const noexcept { return m_id; }

    bool running() const noexcept { return m_running.load(std::memory_order_relaxed); }

private:
    /// Thread body: wait for work or stop, then run at most one task per turn.
    void run_loop() {
        while (m_running.load(std::memory_order_relaxed)) {
            {
                std::unique_lock lock(m_mutex);
                m_cv.wait(lock, [this] {
                    return !m_queue.empty() || !m_running.load(std::memory_order_relaxed);
                });
            }
            // The worker mutex is released BEFORE running the task: a task
            // that calls add() on this same worker would otherwise re-lock a
            // non-recursive mutex (deadlock), and add() from other threads
            // would block for the whole task duration.
            auto item = pop_task();
            if (!item) item = try_steal();
            if (item) process_work(std::move(*item));
        }
    }

    Queue<work_t> m_queue{};
    std::thread m_thread;
    std::atomic<bool> m_running{false};
    std::mutex m_mutex{};
    std::condition_variable m_cv{};
    std::atomic<parent_t> m_parent{nullptr};
    size_type m_id{};
};

/// Fixed-size pool of Workers. Tasks are distributed round-robin; idle
/// workers may steal from siblings; wait() blocks until all submitted tasks
/// have completed.
template <typename Fn>
struct ThreadPool {
    using worker_t = Worker<Fn>;
    using work_t = typename worker_t::work_t;
    using size_type = std::size_t;

    // Workers keep a raw pointer to their pool, so the pool must stay put
    // (the previously defaulted moves were implicitly deleted anyway).
    ThreadPool(ThreadPool const&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(ThreadPool const&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;
    ~ThreadPool() { stop(); }

    /// Starts `n` workers (at least one, since hardware_concurrency() may be 0).
    ThreadPool(size_type n = std::thread::hardware_concurrency())
        : m_workers(std::max(n, size_type{1})) {
        for (size_type i = 0; i < m_workers.size(); ++i) {
            m_workers[i].start(this, i);
        }
    }

    /// Stops and joins every worker; queued but unstarted tasks are not run.
    void stop() {
        for (auto& w : m_workers) w.stop();
    }

    /// Submits a task to the next worker in round-robin order.
    /// Safe to call concurrently, including from inside running tasks.
    void add(Fn&& work) {
        m_active_tasks.fetch_add(1, std::memory_order_relaxed);
        auto const slot =
            m_last_added.fetch_add(1, std::memory_order_relaxed) % m_workers.size();
        try {
            m_workers[slot].add(std::move(work));
        } catch (...) {
            // The task was never enqueued: undo the count so wait() can't hang.
            task_completed();
            throw;
        }
    }

    /// Pops a task from any worker other than `id`, or returns std::nullopt.
    std::optional<work_t> try_steal(size_type id) {
        // Workers are started with their index as id, so compare indices
        // instead of reading sibling workers' m_id (which start() writes).
        for (size_type i = 0; i < m_workers.size(); ++i) {
            if (i == id) continue;
            if (auto item = m_workers[i].pop_task()) return item;
        }
        return std::nullopt;
    }

    /// Called by a worker after each finished task.
    void task_completed() {
        if (m_active_tasks.fetch_sub(1, std::memory_order_release) == 1) {
            // Acquire the wait mutex before notifying: a waiter is then either
            // before its predicate check (and will see 0) or already blocked
            // (and will receive the notification) — no lost wake-up.
            { std::lock_guard lock(m_wait_mutex); }
            m_wait_cv.notify_all();
        }
    }

    /// Blocks until every submitted task has completed.
    void wait() {
        std::unique_lock lock(m_wait_mutex);
        m_wait_cv.wait(lock, [this] {
            return m_active_tasks.load(std::memory_order_acquire) == 0;
        });
    }

private:
    std::vector<worker_t> m_workers;
    std::atomic<size_type> m_last_added{0};
    std::mutex m_wait_mutex;
    std::condition_variable m_wait_cv;
    std::atomic<size_type> m_active_tasks{0};
};

using thread_pool_t = ThreadPool<std::function<void()>>;

/// Splits [start, end) into at most `Split` contiguous chunks and enqueues one
/// pool task per chunk; each task calls `body(i)` for every index in its chunk.
/// This function only enqueues — it does NOT wait for the tasks to finish.
///
/// WARNING: Do not capture stack variables by reference if you defer waiting
/// on the pool. Either capture them by value, or call `pool.wait()` before the
/// end of the scope that owns them.
///
/// NOTE(review): the default value of `Split` was reconstructed — confirm.
template <std::size_t Split = 8, typename Fn,
          typename = std::enable_if_t<std::is_invocable_v<Fn, std::size_t>>>
void parallel_for(thread_pool_t& pool, std::size_t start, std::size_t end, Fn&& body) {
    static_assert(Split > 0, "parallel_for: Split must be at least 1");
    if (start >= end) return;

    auto const size = end - start;
    // Overflow-free ceiling divisions.
    auto const chunk_size = std::max(std::size_t{1}, size / Split + (size % Split != 0));
    auto const num_chunks = size / chunk_size + (size % chunk_size != 0);

    for (std::size_t chunk = 0; chunk < num_chunks; ++chunk) {
        // chunk * chunk_size < size, so chunk_start < end; clamp the length
        // instead of adding first to avoid overflow near SIZE_MAX.
        auto const chunk_start = start + chunk * chunk_size;
        auto const chunk_end = chunk_start + std::min(chunk_size, end - chunk_start);
        pool.add([chunk_start, chunk_end, body] {
            for (auto i = chunk_start; i < chunk_end; ++i) {
                std::invoke(body, i);
            }
        });
    }
}

} // namespace amt

#endif // AMT_THREAD_HPP