#ifndef AMT_THREAD_HPP
#define AMT_THREAD_HPP

#include <algorithm>
#include <atomic>
#include <cassert>
#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <optional>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace amt {

// NOTE: Could implement a lock-free queue.
template <typename T>
struct Queue {
    using base_type = std::deque<T>;
    using value_type = typename base_type::value_type;
    using pointer = typename base_type::pointer;
    using const_pointer = typename base_type::const_pointer;
    using reference = typename base_type::reference;
    using const_reference = typename base_type::const_reference;
    using iterator = typename base_type::iterator;
    using const_iterator = typename base_type::const_iterator;
    using reverse_iterator = typename base_type::reverse_iterator;
    using const_reverse_iterator = typename base_type::const_reverse_iterator;
    using difference_type = typename base_type::difference_type;
    using size_type = typename base_type::size_type;

    constexpr Queue() noexcept = default;
    // std::mutex is neither copyable nor movable, so the defaulted move
    // operations would be implicitly deleted anyway; spell it out.
    Queue(Queue const&) = delete;
    Queue(Queue&&) = delete;
    Queue& operator=(Queue const&) = delete;
    Queue& operator=(Queue&&) = delete;
    ~Queue() = default;

    template <typename U>
        requires std::same_as<std::decay_t<U>, value_type>
    void push(U&& u) {
        std::lock_guard lock(m_mutex);
        m_data.push_back(std::forward<U>(u));
    }

    template <typename... Args>
    void emplace(Args&&... args) {
        std::lock_guard lock(m_mutex);
        m_data.emplace_back(std::forward<Args>(args)...);
    }

    std::optional<value_type> pop() {
        std::lock_guard lock(m_mutex);
        if (empty_unsafe()) return std::nullopt;
        auto el = std::move(m_data.front());
        m_data.pop_front();
        return el;
    }

    auto size() const noexcept -> size_type {
        std::lock_guard lock(m_mutex);
        return m_data.size();
    }

    auto empty() const noexcept -> bool {
        std::lock_guard lock(m_mutex);
        return m_data.empty();
    }

    // Unsynchronized variants; only call these while holding an external lock.
    constexpr auto size_unsafe() const noexcept -> size_type { return m_data.size(); }
    constexpr auto empty_unsafe() const noexcept -> bool { return m_data.empty(); }

private:
    base_type m_data;
    mutable std::mutex m_mutex;
};
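
// Illustrative only (not part of the original header): a minimal sketch of the
// intended producer/consumer usage. pop() returns std::nullopt once the queue
// is empty, so a consumer can drain it with a simple loop.
//
//     amt::Queue<int> q;
//     q.push(1);
//     q.emplace(2);
//     while (auto v = q.pop()) {
//         // consume *v
//     }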

template <typename Fn>
struct ThreadPool;

template <typename Fn>
struct Worker {
    using parent_t = ThreadPool<Fn>*;
    using work_t = Fn;
    using size_type = std::size_t;

    Worker() noexcept = default;
    // A worker owns a thread, a mutex, and a condition variable, so it can be
    // neither copied nor moved.
    Worker(Worker const&) = delete;
    Worker(Worker&&) = delete;
    Worker& operator=(Worker const&) = delete;
    Worker& operator=(Worker&&) = delete;
    ~Worker() {
        stop();
    }

    void start(parent_t pool, size_type id) {
        assert(!m_running.load(std::memory_order::acquire) && "Thread is already running");
        m_running.store(true);
        m_parent.store(pool);
        m_id = id;
        m_thread = std::thread([this] {
            while (m_running.load(std::memory_order::relaxed)) {
                std::unique_lock lk(m_mutex);
                m_cv.wait(lk, [this] {
                    return !m_queue.empty_unsafe() || !m_running.load(std::memory_order::relaxed);
                });

                auto item = pop_task();
                if (!item) {
                    item = try_steal();
                    if (!item) continue;
                }

                // Release the lock before running user code; otherwise a task
                // that submits work back to this worker would deadlock on m_mutex.
                lk.unlock();
                process_work(std::move(*item));
            }
        });
    }

    // Runs the task, then reports completion to the owning pool. Note the
    // noexcept: a task that throws will terminate the program.
    void process_work(work_t&& work) const noexcept {
        std::invoke(std::move(work));
        auto ptr = m_parent.load();
        if (ptr) ptr->task_completed();
    }

    void stop() {
        if (!m_running.load()) return;
        {
            std::lock_guard lock(m_mutex);
            m_running.store(false);
        }
        m_cv.notify_all();
        if (m_thread.joinable()) m_thread.join();
        m_parent.store(nullptr);
    }

    void add(work_t&& work) {
        std::lock_guard lock(m_mutex);
        m_queue.push(std::move(work));
        m_cv.notify_one();
    }

    std::optional<work_t> pop_task() noexcept {
        return m_queue.pop();
    }

    std::optional<work_t> try_steal() noexcept {
        auto ptr = m_parent.load();
        if (ptr) return ptr->try_steal(m_id);
        return {};
    }

    constexpr bool empty() const noexcept { return m_queue.empty_unsafe(); }
    constexpr size_type size() const noexcept { return m_queue.size_unsafe(); }
    constexpr size_type id() const noexcept { return m_id; }
    bool running() const noexcept { return m_running.load(std::memory_order::relaxed); }

private:
    Queue<work_t> m_queue{};
    std::thread m_thread;
    std::atomic<bool> m_running{false};
    std::mutex m_mutex{};
    std::condition_variable m_cv{};
    std::atomic<parent_t> m_parent{nullptr};
    size_type m_id{};
};

template <typename Fn>
struct ThreadPool {
    using worker_t = Worker<Fn>;
    using work_t = typename worker_t::work_t;
    using size_type = std::size_t;

    // Workers keep a pointer back to their pool, so the pool cannot be
    // copied or moved without invalidating them.
    ThreadPool(ThreadPool const&) = delete;
    ThreadPool(ThreadPool&&) = delete;
    ThreadPool& operator=(ThreadPool const&) = delete;
    ThreadPool& operator=(ThreadPool&&) = delete;
    ~ThreadPool() {
        stop();
    }

    explicit ThreadPool(size_type n = std::thread::hardware_concurrency())
        : m_workers(std::max(n, size_type{1}))
    {
        for (auto i = size_type{}; i < m_workers.size(); ++i) {
            m_workers[i].start(this, i);
        }
    }

    void stop() {
        for (auto& w : m_workers) w.stop();
    }

    // NOTE: m_last_added is unsynchronized; submit tasks from a single thread.
    void add(Fn&& work) {
        m_active_tasks.fetch_add(1, std::memory_order::relaxed);
        m_workers[m_last_added].add(std::move(work));
        m_last_added = (m_last_added + 1) % m_workers.size();
    }

    std::optional<work_t> try_steal(size_type id) {
        for (auto& w : m_workers) {
            if (w.id() == id) continue;
            auto item = w.pop_task();
            if (item) return item;
        }
        return {};
    }

    void task_completed() {
        if (m_active_tasks.fetch_sub(1, std::memory_order::release) == 1) {
            // Hold the lock while notifying so a waiter cannot observe a
            // non-zero count and then miss this wake-up before it sleeps.
            std::lock_guard lock(m_wait_mutex);
            m_wait_cv.notify_all();
        }
    }

    void wait() {
        std::unique_lock lock(m_wait_mutex);
        m_wait_cv.wait(lock, [this] {
            return m_active_tasks.load(std::memory_order::acquire) == 0;
        });
    }

private:
    std::vector<worker_t> m_workers;
    size_type m_last_added{};
    std::mutex m_wait_mutex;
    std::condition_variable m_wait_cv;
    std::atomic<std::size_t> m_active_tasks{0};
};

using thread_pool_t = ThreadPool<std::function<void()>>;
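
// Illustrative only (not part of the original header): a minimal sketch of
// direct pool usage. add() hands tasks out round-robin across the workers,
// idle workers try to steal from their siblings, and wait() blocks until the
// in-flight task counter drops to zero.
//
//     amt::thread_pool_t pool{4};
//     std::atomic<int> sum{0};
//     for (int i = 0; i < 100; ++i) {
//         pool.add([&sum, i] { sum.fetch_add(i); });
//     }
//     pool.wait(); // `sum` is safe to read from here on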

// WARNING: Do not capture stack variables by reference if you defer the wait
// on the pool. If you need them, either capture them by value or call
// "pool.wait()" before the end of their scope.
template <std::size_t Split, typename Fn>
    requires (std::is_invocable_v<Fn, std::size_t>)
auto parallel_for(thread_pool_t& pool, std::size_t start, std::size_t end, Fn&& body) {
    if (start >= end) return;

    auto const size = end - start;
    auto const chunk_size = std::max(std::size_t{1}, (size + Split - 1) / Split);
    auto const num_chunks = (size + chunk_size - 1) / chunk_size;

    for (auto chunk = std::size_t{}; chunk < num_chunks; ++chunk) {
        auto const chunk_start = std::min(start + chunk * chunk_size, end);
        auto const chunk_end = std::min(chunk_start + chunk_size, end);
        // Each task copies `body`, so the callable only has to outlive the
        // enqueueing, not the execution.
        pool.add([chunk_start, chunk_end, body] {
            for (auto i = chunk_start; i < chunk_end; ++i) {
                std::invoke(body, i);
            }
        });
    }
}
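
// Illustrative only (not part of the original header): a sketch of
// parallel_for that heeds the warning above. `data` outlives the tasks
// because pool.wait() runs before it goes out of scope, so the by-reference
// capture inside the body is safe here.
//
//     amt::thread_pool_t pool;
//     std::vector<float> data(1'000'000, 1.0f);
//     amt::parallel_for<8>(pool, 0, data.size(), [&data](std::size_t i) {
//         data[i] *= 2.0f;
//     });
//     pool.wait(); // must happen before `data` is destroyed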

} // namespace amt

#endif // AMT_THREAD_HPP