diff options
Diffstat (limited to 'src/zencore')
| -rw-r--r-- | src/zencore/include/zencore/blockingqueue.h | 27 | ||||
| -rw-r--r-- | src/zencore/include/zencore/workthreadpool.h | 8 | ||||
| -rw-r--r-- | src/zencore/workthreadpool.cpp | 230 |
3 files changed, 216 insertions, 49 deletions
diff --git a/src/zencore/include/zencore/blockingqueue.h b/src/zencore/include/zencore/blockingqueue.h index b6c93e937..6ac43b1ac 100644 --- a/src/zencore/include/zencore/blockingqueue.h +++ b/src/zencore/include/zencore/blockingqueue.h @@ -5,6 +5,7 @@ #include <zencore/zencore.h> // For ZEN_ASSERT #include <atomic> +#include <chrono> #include <condition_variable> #include <deque> #include <mutex> @@ -50,6 +51,32 @@ public: return true; } + // Returns: true if item dequeued, false on timeout or completion + template<typename Rep, typename Period> + bool WaitAndDequeueFor(T& Item, std::chrono::duration<Rep, Period> Timeout) + { + std::unique_lock Lock(m_Lock); + if (m_Queue.empty()) + { + if (m_CompleteAdding) + { + return false; + } + if (!m_NewItemSignal.wait_for(Lock, Timeout, [this]() { return !m_Queue.empty() || m_CompleteAdding; })) + { + return false; // Timed out + } + if (m_Queue.empty()) + { + ZEN_ASSERT(m_CompleteAdding); + return false; + } + } + Item = std::move(m_Queue.front()); + m_Queue.pop_front(); + return true; + } + void CompleteAdding() { std::unique_lock Lock(m_Lock); diff --git a/src/zencore/include/zencore/workthreadpool.h b/src/zencore/include/zencore/workthreadpool.h index 4c38dd651..cb0b8f491 100644 --- a/src/zencore/include/zencore/workthreadpool.h +++ b/src/zencore/include/zencore/workthreadpool.h @@ -27,8 +27,9 @@ private: class WorkerThreadPool { public: - explicit WorkerThreadPool(int InThreadCount); - WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName); + explicit WorkerThreadPool(int InThreadCount, bool UseExplicitThreads = true); + WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName, bool UseExplicitThreads = true); + WorkerThreadPool(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName, bool UseExplicitThreads = true); ~WorkerThreadPool(); // Decides what to do if there are no free workers in the pool when the work is submitted @@ -48,6 +49,9 @@ private: struct Impl; struct ThreadStartInfo; + friend struct WinTpImpl; + friend struct ExplicitImpl; + std::unique_ptr<Impl> m_Impl; }; diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp index 1cb338c66..b179527d7 100644 --- a/src/zencore/workthreadpool.cpp +++ b/src/zencore/workthreadpool.cpp @@ -11,6 +11,7 @@ #include <zencore/thread.h> #include <zencore/trace.h> +#include <algorithm> #include <thread> #include <vector> @@ -18,9 +19,7 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <gsl/gsl-lite.hpp> ZEN_THIRD_PARTY_INCLUDES_END -#define ZEN_USE_WINDOWS_THREADPOOL 1 - -#if ZEN_PLATFORM_WINDOWS && ZEN_USE_WINDOWS_THREADPOOL +#if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> #endif @@ -38,17 +37,26 @@ namespace detail { ////////////////////////////////////////////////////////////////////////// -#if ZEN_USE_WINDOWS_THREADPOOL && ZEN_PLATFORM_WINDOWS +struct WorkerThreadPool::Impl +{ + virtual ~Impl() = default; + [[nodiscard]] virtual Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_PLATFORM_WINDOWS namespace { thread_local bool t_IsThreadNamed{false}; } -struct WorkerThreadPool::Impl +struct WinTpImpl : WorkerThreadPool::Impl { - const int m_ThreadCount = 0; - PTP_POOL m_ThreadPool = nullptr; - PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; + const int m_MinThreadCount = 0; + const int m_MaxThreadCount = 0; + PTP_POOL m_ThreadPool = nullptr; + PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; TP_CALLBACK_ENVIRON m_CallbackEnvironment; PTP_WORK m_Work = nullptr; @@ -59,10 +67,11 @@ struct WorkerThreadPool::Impl mutable RwLock m_QueueLock; std::deque<Ref<IWork>> m_WorkQueue; - Impl(int InThreadCount, std::string_view WorkerThreadBaseName) - : m_ThreadCount(InThreadCount) + WinTpImpl(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName) + : m_MinThreadCount(InMinThreadCount) + , m_MaxThreadCount(InMaxThreadCount < InMinThreadCount ? InMinThreadCount : InMaxThreadCount) , m_WorkerThreadBaseName(WorkerThreadBaseName) - , m_FreeWorkerCount(m_ThreadCount) + , m_FreeWorkerCount(m_MinThreadCount) { // Thread pool setup @@ -72,11 +81,11 @@ struct WorkerThreadPool::Impl ThrowLastError("CreateThreadpool failed"); } - if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_ThreadCount)) + if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_MinThreadCount)) { ThrowLastError("SetThreadpoolThreadMinimum failed"); } - SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_ThreadCount); + SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_MaxThreadCount); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); @@ -96,14 +105,14 @@ struct WorkerThreadPool::Impl } } - ~Impl() + ~WinTpImpl() override { WaitForThreadpoolWorkCallbacks(m_Work, /* CancelPendingCallbacks */ TRUE); CloseThreadpoolWork(m_Work); CloseThreadpool(m_ThreadPool); } - [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) + [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) override { if (Mode == WorkerThreadPool::EMode::DisableBacklog) { @@ -130,7 +139,7 @@ struct WorkerThreadPool::Impl static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work) { ZEN_UNUSED(Instance, Work); - Impl* ThisPtr = reinterpret_cast<Impl*>(Context); + WinTpImpl* ThisPtr = reinterpret_cast<WinTpImpl*>(Context); ThisPtr->DoWork(); } @@ -175,7 +184,9 @@ struct WorkerThreadPool::Impl } }; -#else +#endif + +////////////////////////////////////////////////////////////////////////// struct WorkerThreadPool::ThreadStartInfo { @@ -183,42 +194,54 @@ struct WorkerThreadPool::ThreadStartInfo zen::Latch* Latch; }; -struct WorkerThreadPool::Impl +struct ExplicitImpl : WorkerThreadPool::Impl { - const int m_ThreadCount = 0; - void WorkerThreadFunction(ThreadStartInfo Info); - std::string m_WorkerThreadBaseName; - std::vector<std::thread> m_WorkerThreads; - BlockingQueue<Ref<IWork>> m_WorkQueue; - std::atomic<int> m_FreeWorkerCount{0}; - - Impl(int InThreadCount, std::string_view WorkerThreadBaseName) - : m_ThreadCount(InThreadCount) + const int m_MinThreads; + const int m_MaxThreads; + std::atomic<int> m_TotalThreads{0}; + std::atomic<int> m_ActiveCount{0}; + void WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info); + void SpawnWorkerThread(); + void PruneExitedThreads(); + std::string m_WorkerThreadBaseName; + RwLock m_ThreadListLock; + std::vector<std::thread> m_WorkerThreads; + std::vector<std::thread::id> m_ExitedThreadIds; + BlockingQueue<Ref<IWork>> m_WorkQueue; + std::atomic<int> m_FreeWorkerCount{0}; + + bool ScalingEnabled() const { return m_MinThreads != m_MaxThreads; } + + ExplicitImpl(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName) + : m_MinThreads(InMinThreadCount) + , m_MaxThreads(InMaxThreadCount < InMinThreadCount ? InMinThreadCount : InMaxThreadCount) , m_WorkerThreadBaseName(WorkerThreadBaseName) - , m_FreeWorkerCount(m_ThreadCount) + , m_FreeWorkerCount(InMinThreadCount) { -# if ZEN_WITH_TRACE +#if ZEN_WITH_TRACE trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str()); -# endif +#endif - zen::Latch WorkerLatch{m_ThreadCount}; + zen::Latch WorkerLatch{m_MinThreads}; - for (int i = 0; i < m_ThreadCount; ++i) + for (int i = 0; i < m_MinThreads; ++i) { - m_WorkerThreads.emplace_back(&Impl::WorkerThreadFunction, this, ThreadStartInfo{i + 1, &WorkerLatch}); + m_TotalThreads.fetch_add(1, std::memory_order::relaxed); + m_WorkerThreads.emplace_back(&ExplicitImpl::WorkerThreadFunction, this, WorkerThreadPool::ThreadStartInfo{i + 1, &WorkerLatch}); } WorkerLatch.Wait(); -# if ZEN_WITH_TRACE +#if ZEN_WITH_TRACE trace::ThreadGroupEnd(); -# endif +#endif } - ~Impl() + ~ExplicitImpl() override { m_WorkQueue.CompleteAdding(); + RwLock::ExclusiveLockScope _(m_ThreadListLock); for (std::thread& Thread : m_WorkerThreads) { if (Thread.joinable()) @@ -230,7 +253,7 @@ struct WorkerThreadPool::Impl m_WorkerThreads.clear(); } - [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) + [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) override { if (Mode == WorkerThreadPool::EMode::DisableBacklog) { @@ -245,24 +268,97 @@ struct WorkerThreadPool::Impl } } m_WorkQueue.Enqueue(std::move(Work)); + + // Scale up: if all workers are busy and we haven't hit the max, spawn a new thread + if (ScalingEnabled()) + { + const int Active = m_ActiveCount.load(std::memory_order::acquire); + const int Total = m_TotalThreads.load(std::memory_order::acquire); + if (Active >= Total && Total < m_MaxThreads) + { + int Expected = Total; + if (m_TotalThreads.compare_exchange_strong(Expected, Total + 1, std::memory_order::acq_rel)) + { + ZEN_DEBUG("scaling up worker thread pool '{}', {} -> {} threads", m_WorkerThreadBaseName, Total, Total + 1); + SpawnWorkerThread(); + } + } + } + return {}; } }; void -WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info) +ExplicitImpl::PruneExitedThreads() +{ + // Must be called under m_ThreadListLock + if (m_ExitedThreadIds.empty()) + { + return; + } + + for (auto It = m_WorkerThreads.begin(); It != m_WorkerThreads.end();) + { + auto IdIt = std::find(m_ExitedThreadIds.begin(), m_ExitedThreadIds.end(), It->get_id()); + if (IdIt != m_ExitedThreadIds.end()) + { + It->join(); + It = m_WorkerThreads.erase(It); + m_ExitedThreadIds.erase(IdIt); + } + else + { + ++It; + } + } +} + +void +ExplicitImpl::SpawnWorkerThread() +{ + static std::atomic<int> s_DynamicThreadIndex{0}; + const int ThreadNumber = ++s_DynamicThreadIndex; + + RwLock::ExclusiveLockScope _(m_ThreadListLock); + PruneExitedThreads(); + m_WorkerThreads.emplace_back(&ExplicitImpl::WorkerThreadFunction, this, WorkerThreadPool::ThreadStartInfo{ThreadNumber, nullptr}); +} + +void +ExplicitImpl::WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info) { SetCurrentThreadName(fmt::format("{}_{}", m_WorkerThreadBaseName, Info.ThreadNumber)); - Info.Latch->CountDown(); + if (Info.Latch) + { + Info.Latch->CountDown(); + } + + static constexpr auto kIdleTimeout = std::chrono::seconds(15); do { Ref<IWork> Work; - if (m_WorkQueue.WaitAndDequeue(Work)) + + bool Dequeued; + if (ScalingEnabled()) + { + Dequeued = m_WorkQueue.WaitAndDequeueFor(Work, kIdleTimeout); + } + else + { + Dequeued = m_WorkQueue.WaitAndDequeue(Work); + } + + if (Dequeued) { m_FreeWorkerCount--; - auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); + m_ActiveCount.fetch_add(1, std::memory_order::acq_rel); + auto _ = MakeGuard([&]() { + m_ActiveCount.fetch_sub(1, std::memory_order::release); + m_FreeWorkerCount++; + }); try { @@ -281,25 +377,65 @@ WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info) ZEN_ERROR("Caught exception in worker thread: {}", e.what()); } } + else if (ScalingEnabled()) + { + // Timed out - consider scaling down + const int CurrentTotal = m_TotalThreads.load(std::memory_order::acquire); + if (CurrentTotal > m_MinThreads) + { + int Expected = CurrentTotal; + if (m_TotalThreads.compare_exchange_strong(Expected, CurrentTotal - 1, std::memory_order::acq_rel)) + { + ZEN_DEBUG("scaling down worker thread pool '{}' (idle timeout), {} threads remaining", + m_WorkerThreadBaseName, + CurrentTotal - 1); + m_FreeWorkerCount--; + { + RwLock::ExclusiveLockScope _(m_ThreadListLock); + m_ExitedThreadIds.push_back(std::this_thread::get_id()); + } + return; // Thread exits + } + } + // CAS failed or at min threads - continue waiting + } else { + // CompleteAdding was called - exit return; } } while (true); } -#endif - ////////////////////////////////////////////////////////////////////////// -WorkerThreadPool::WorkerThreadPool(int InThreadCount) : WorkerThreadPool(InThreadCount, "workerthread") +WorkerThreadPool::WorkerThreadPool(int InThreadCount, bool UseExplicitThreads) +: WorkerThreadPool(InThreadCount, "workerthread", UseExplicitThreads) +{ +} + +WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName, bool UseExplicitThreads) +: WorkerThreadPool(InThreadCount, InThreadCount, WorkerThreadBaseName, UseExplicitThreads) { } -WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName) +WorkerThreadPool::WorkerThreadPool(int InMinThreadCount, + int InMaxThreadCount, + std::string_view WorkerThreadBaseName, + bool UseExplicitThreads) { - if (InThreadCount > 0) + if (InMinThreadCount > 0) { - m_Impl = std::make_unique<Impl>(InThreadCount, WorkerThreadBaseName); +#if ZEN_PLATFORM_WINDOWS + if (!UseExplicitThreads) + { + m_Impl = std::make_unique<WinTpImpl>(InMinThreadCount, InMaxThreadCount, WorkerThreadBaseName); + } + else +#endif + { + ZEN_UNUSED(UseExplicitThreads); + m_Impl = std::make_unique<ExplicitImpl>(InMinThreadCount, InMaxThreadCount, WorkerThreadBaseName); + } } } |