#ifndef AMT_THREAD_HPP
#define AMT_THREAD_HPP

#include <algorithm>
#include <atomic>
#include <cassert>
#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <optional>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

namespace amt {

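	// A minimal thread-safe FIFO: a std::deque guarded by a mutex. The
	// *_unsafe accessors skip the lock and are intended for callers that
	// already synchronize access externally.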
	// NOTE: Could implement lock-free queue.
	template <typename T>
	struct Queue {
		using base_type = std::deque<T>;
		using value_type = typename base_type::value_type;
		using pointer = typename base_type::pointer;
		using const_pointer = typename base_type::const_pointer;
		using reference = typename base_type::reference;
		using const_reference = typename base_type::const_reference;
		using iterator = typename base_type::iterator;
		using const_iterator = typename base_type::const_iterator;
		using reverse_iterator = typename base_type::reverse_iterator;
		using const_reverse_iterator = typename base_type::const_reverse_iterator;
		using difference_type = typename base_type::difference_type;
		using size_type = typename base_type::size_type;

		Queue() = default;
		Queue(Queue const&) = delete;
		Queue(Queue&&) = delete; // std::mutex is neither copyable nor movable
		Queue& operator=(Queue const&) = delete;
		Queue& operator=(Queue&&) = delete;
		~Queue() = default;

		template <typename U>
			requires std::same_as<std::decay_t<U>, value_type>
		void push(U&& u) {
			std::lock_guard m(m_mutex);
			m_data.push_back(std::forward<U>(u));
		}

		template <typename... Args>
		void emplace(Args&&... args) {
			std::lock_guard m(m_mutex);
			m_data.emplace_back(std::forward<Args>(args)...);
		}

		std::optional<value_type> pop() {
			std::lock_guard m(m_mutex);
			if (empty_unsafe()) return std::nullopt;
			auto el = std::move(m_data.front());
			m_data.pop_front();
			return std::move(el);
		}

		auto size() const -> size_type {
			std::lock_guard m(m_mutex);
			return m_data.size();
		}
		auto empty() const -> bool {
			std::lock_guard m(m_mutex);
			return m_data.empty();
		}
		auto size_unsafe() const noexcept -> size_type { return m_data.size(); }
		auto empty_unsafe() const noexcept -> bool { return m_data.empty(); }

	private:
		base_type m_data;
		mutable std::mutex m_mutex;
	};

	template <typename Fn>
	struct ThreadPool;

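	// A single worker: owns a task queue and a worker thread that drains it.
	// When its own queue is empty, the worker tries to steal a task from a
	// sibling through the owning ThreadPool.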
	template <typename Fn>
	struct Worker {
		using parent_t = ThreadPool<Fn>*;
		using work_t = Fn;
		using size_type = std::size_t;
		Worker() = default;
		Worker(Worker const&) = delete;
		Worker(Worker&&) = delete; // std::thread is not copyable; std::mutex and std::condition_variable are not movable
		Worker& operator=(Worker const&) = delete;
		Worker& operator=(Worker&&) = delete;
		~Worker() {
			stop();
		}

		void start(parent_t pool, size_type id) {
			assert(!m_running.load(std::memory_order::acquire) && "Thread is already running");
			m_running.store(true);
			m_parent.store(pool);
			m_id = id;
			m_thread = std::thread([this] {
				while (m_running.load(std::memory_order::relaxed)) {
					std::unique_lock lk(m_mutex);
					m_cv.wait(lk, [this] {
						return !m_queue.empty_unsafe() || !m_running.load(std::memory_order::relaxed);
					});
					// Release the lock before running the task so that add()
					// is not blocked for the duration of the work item.
					lk.unlock();

					auto item = pop_task();
					if (!item) {
						item = try_steal();
						if (!item) continue;
					}

					process_work(std::move(*item));
				}
			});
		}

		void process_work(work_t&& work) const noexcept {
			std::invoke(std::move(work));
			auto ptr = m_parent.load();
			if (ptr) ptr->task_completed();
		}

		void stop() {
			if (!m_running.load()) return;
			{
				std::lock_guard<std::mutex> lock(m_mutex);
				m_running.store(false);
			}
			m_cv.notify_all();
			if (m_thread.joinable()) m_thread.join();
			m_parent.store(nullptr);
		}

		void add(work_t&& work) {
			{
				std::lock_guard<std::mutex> lock(m_mutex);
				m_queue.push(std::move(work));
			}
			m_cv.notify_one();
		}

		std::optional<work_t> pop_task() {
			return m_queue.pop();
		}

		std::optional<work_t> try_steal() {
			auto ptr = m_parent.load();
			if (ptr) return ptr->try_steal(m_id);
			return {};
		}

		bool empty() const noexcept { return m_queue.empty_unsafe(); }
		size_type size() const noexcept { return m_queue.size_unsafe(); }
		constexpr size_type id() const noexcept { return m_id; }
		bool running() const noexcept { return m_running.load(std::memory_order::relaxed); }

	private:
		Queue<work_t> m_queue{};
		std::thread m_thread;
		std::atomic<bool> m_running{false};
		std::mutex m_mutex{};
		std::condition_variable m_cv{};
		std::atomic<parent_t> m_parent{nullptr};
		size_type m_id{};
	};

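	// A fixed-size pool of Workers. Tasks are dispatched round-robin; wait()
	// blocks until every task added so far has finished executing.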
	template <typename Fn>
	struct ThreadPool {
		using worker_t = Worker<Fn>;
		using work_t = typename worker_t::work_t;
		using size_type = std::size_t;

		ThreadPool(ThreadPool const&) = delete;
		ThreadPool(ThreadPool&&) = delete; // workers hold a pointer back to this pool
		ThreadPool& operator=(ThreadPool const&) = delete;
		ThreadPool& operator=(ThreadPool&&) = delete;
		~ThreadPool() {
			stop();
		}

		explicit ThreadPool(size_type n = std::thread::hardware_concurrency())
			: m_workers(std::max(n, size_type{1}))
		{
			for (size_type i = 0; i < m_workers.size(); ++i) {
				m_workers[i].start(this, i);
			}
		}

		void stop() {
			for (auto& w: m_workers) w.stop();
		}

		void add(Fn&& work) {
			m_active_tasks.fetch_add(1, std::memory_order::relaxed);
			// Round-robin dispatch; the counter is atomic so add() is safe
			// to call from multiple threads (e.g. tasks enqueueing tasks).
			auto const idx = m_last_added.fetch_add(1, std::memory_order::relaxed) % m_workers.size();
			m_workers[idx].add(std::move(work));
		}

		std::optional<work_t> try_steal(size_type id) {
			for (auto& w: m_workers) {
				if (w.id() == id) continue;
				auto item = w.pop_task();
				if (item) return item;
			}
			return {};
		}

		void task_completed() {
			if (m_active_tasks.fetch_sub(1, std::memory_order::release) == 1) {
				// Take the lock before notifying so a thread in wait() cannot
				// check the predicate and then miss the wake-up.
				std::lock_guard lock(m_wait_mutex);
				m_wait_cv.notify_all();
			}
		}

		void wait() {
			std::unique_lock lock(m_wait_mutex);
			m_wait_cv.wait(lock, [this] {
				return m_active_tasks.load(std::memory_order::acquire) == 0;
			});
		}
	private:
		std::vector<worker_t> m_workers;
		std::atomic<size_type> m_last_added{};
		std::mutex m_wait_mutex;
		std::condition_variable m_wait_cv;
		std::atomic<size_type> m_active_tasks{0};
	};

	using thread_pool_t = ThreadPool<std::function<void()>>;
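
	// Example usage (illustrative sketch):
	//
	//   amt::thread_pool_t pool{4};      // 4 worker threads
	//   pool.add([] { /* do work */ });
	//   pool.wait();                     // block until every queued task has run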

	// WARNING: Do not capture stack variables by reference if you defer waiting on the pool.
	// If you need to capture them, either capture by value or call pool.wait() before leaving the scope.
	template <std::size_t Split, typename Fn>
		requires (std::is_invocable_v<Fn, std::size_t>)
	void parallel_for(thread_pool_t& pool, std::size_t start, std::size_t end, Fn&& body) {
		if (start >= end) return;

		auto const size = (end - start);
		auto const chunk_size = std::max(std::size_t{1}, (size + Split - 1) / Split);
		auto const num_chunks = (size + chunk_size - 1) / chunk_size;

		for (std::size_t chunk = 0; chunk < num_chunks; ++chunk) {
			auto const chunk_start = std::min(start + (chunk * chunk_size), end);
			auto const chunk_end = std::min(chunk_start + chunk_size, end);
			pool.add([chunk_start, chunk_end, body] {
				for (auto i = chunk_start; i < chunk_end; ++i) {
					std::invoke(body, i);
				}
			});
		}
	}
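
	// Example (illustrative): sum of squares over [0, 1000) split into 8 chunks.
	//
	//   amt::thread_pool_t pool;
	//   std::atomic<std::size_t> sum{0};
	//   amt::parallel_for<8>(pool, 0, 1000, [&sum](std::size_t i) {
	//       sum.fetch_add(i * i, std::memory_order::relaxed);
	//   });
	//   pool.wait(); // `sum` is captured by reference, so wait before leaving scope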
} // namespace amt

#endif // AMT_THREAD_HPP