diff options
Diffstat (limited to 'src/zencore/workthreadpool.cpp')
| -rw-r--r-- | src/zencore/workthreadpool.cpp | 151 |
1 files changed, 97 insertions, 54 deletions
diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp index 445fe939e..e241c0de8 100644 --- a/src/zencore/workthreadpool.cpp +++ b/src/zencore/workthreadpool.cpp @@ -5,6 +5,7 @@ #include <zencore/blockingqueue.h> #include <zencore/except.h> #include <zencore/logging.h> +#include <zencore/scopeguard.h> #include <zencore/string.h> #include <zencore/testing.h> #include <zencore/thread.h> @@ -13,6 +14,10 @@ #include <thread> #include <vector> +ZEN_THIRD_PARTY_INCLUDES_START +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #define ZEN_USE_WINDOWS_THREADPOOL 1 #if ZEN_PLATFORM_WINDOWS && ZEN_USE_WINDOWS_THREADPOOL @@ -41,18 +46,23 @@ namespace { struct WorkerThreadPool::Impl { + const int m_ThreadCount = 0; PTP_POOL m_ThreadPool = nullptr; PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; TP_CALLBACK_ENVIRON m_CallbackEnvironment; PTP_WORK m_Work = nullptr; - std::string m_WorkerThreadBaseName; - std::atomic<int> m_WorkerThreadCounter{0}; + std::string m_WorkerThreadBaseName; + std::atomic<size_t> m_WorkerThreadCounter{0}; + std::atomic<int> m_FreeWorkerCount{0}; - RwLock m_QueueLock; + mutable RwLock m_QueueLock; std::deque<Ref<IWork>> m_WorkQueue; - Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName) + Impl(int InThreadCount, std::string_view WorkerThreadBaseName) + : m_ThreadCount(InThreadCount) + , m_WorkerThreadBaseName(WorkerThreadBaseName) + , m_FreeWorkerCount(m_ThreadCount) { // Thread pool setup @@ -62,11 +72,11 @@ struct WorkerThreadPool::Impl ThrowLastError("CreateThreadpool failed"); } - if (!SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount)) + if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_ThreadCount)) { ThrowLastError("SetThreadpoolThreadMinimum failed"); } - SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2); + SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_ThreadCount); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); @@ -93,12 +103,29 @@ struct WorkerThreadPool::Impl CloseThreadpool(m_ThreadPool); } - void ScheduleWork(Ref<IWork> Work) + [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) { - m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); + if (Mode == WorkerThreadPool::EMode::DisableBacklog) + { + if (m_FreeWorkerCount <= 0) + { + return Work; + } + RwLock::ExclusiveLockScope _(m_QueueLock); + const int QueuedCount = gsl::narrow<int>(m_WorkQueue.size()); + if (QueuedCount >= m_FreeWorkerCount) + { + return Work; + } + m_WorkQueue.push_back(std::move(Work)); + } + else + { + m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); + } SubmitThreadpoolWork(m_Work); + return {}; } - [[nodiscard]] size_t PendingWorkItemCount() const { return 0; } static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work) { @@ -109,10 +136,13 @@ struct WorkerThreadPool::Impl void DoWork() { + m_FreeWorkerCount--; + auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); + if (!t_IsThreadNamed) { t_IsThreadNamed = true; - const int ThreadIndex = ++m_WorkerThreadCounter; + const size_t ThreadIndex = ++m_WorkerThreadCounter; zen::ExtendableStringBuilder<128> ThreadName; ThreadName << m_WorkerThreadBaseName << "_" << ThreadIndex; SetCurrentThreadName(ThreadName); @@ -121,7 +151,7 @@ struct WorkerThreadPool::Impl Ref<IWork> WorkFromQueue; { - RwLock::ExclusiveLockScope _{m_QueueLock}; + RwLock::ExclusiveLockScope __{m_QueueLock}; WorkFromQueue = std::move(m_WorkQueue.front()); m_WorkQueue.pop_front(); } @@ -141,20 +171,25 @@ struct WorkerThreadPool::ThreadStartInfo struct WorkerThreadPool::Impl { + const int m_ThreadCount = 0; void WorkerThreadFunction(ThreadStartInfo Info); std::string m_WorkerThreadBaseName; std::vector<std::thread> m_WorkerThreads; BlockingQueue<Ref<IWork>> m_WorkQueue; + std::atomic<int> m_FreeWorkerCount{0}; - Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName) + Impl(int InThreadCount, std::string_view WorkerThreadBaseName) + : m_ThreadCount(InThreadCount) + , m_WorkerThreadBaseName(WorkerThreadBaseName) + , m_FreeWorkerCount(m_ThreadCount) { # if ZEN_WITH_TRACE trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str()); # endif - zen::Latch WorkerLatch{InThreadCount}; + zen::Latch WorkerLatch{m_ThreadCount}; - for (int i = 0; i < InThreadCount; ++i) + for (int i = 0; i < m_ThreadCount; ++i) { m_WorkerThreads.emplace_back(&Impl::WorkerThreadFunction, this, ThreadStartInfo{i + 1, &WorkerLatch}); } @@ -181,8 +216,23 @@ struct WorkerThreadPool::Impl m_WorkerThreads.clear(); } - void ScheduleWork(Ref<IWork> Work) { m_WorkQueue.Enqueue(std::move(Work)); } - [[nodiscard]] size_t PendingWorkItemCount() const { return m_WorkQueue.Size(); } + [[nodiscard]] Ref<IWork> ScheduleWork(Ref<IWork> Work, WorkerThreadPool::EMode Mode) + { + if (Mode == WorkerThreadPool::EMode::DisableBacklog) + { + if (m_FreeWorkerCount <= 0) + { + return Work; + } + const int QueuedCount = gsl::narrow<int>(m_WorkQueue.Size()); + if (QueuedCount >= m_FreeWorkerCount) + { + return Work; + } + } + m_WorkQueue.Enqueue(std::move(Work)); + return {}; + } }; void @@ -197,21 +247,23 @@ WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info) Ref<IWork> Work; if (m_WorkQueue.WaitAndDequeue(Work)) { + m_FreeWorkerCount--; + auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); + try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); Work->Execute(); + Work = {}; } catch (const AssertException& Ex) { - Work->m_Exception = std::current_exception(); - + Work = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { - Work->m_Exception = std::current_exception(); - + Work = {}; ZEN_WARN("Caught exception in worker thread: {}", e.what()); } } @@ -243,48 +295,38 @@ WorkerThreadPool::~WorkerThreadPool() } void -WorkerThreadPool::ScheduleWork(Ref<IWork> Work) +WorkerThreadPool::ScheduleWork(Ref<IWork> Work, EMode Mode) { if (m_Impl) { - m_Impl->ScheduleWork(std::move(Work)); - } - else - { - try + if (Work = m_Impl->ScheduleWork(std::move(Work), Mode); !Work) { - ZEN_TRACE_CPU_FLUSH("SyncWork"); - Work->Execute(); - } - catch (const AssertException& Ex) - { - Work->m_Exception = std::current_exception(); - - ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); + return; } - catch (const std::exception& e) - { - Work->m_Exception = std::current_exception(); + } - ZEN_WARN("Caught exception when executing worker synchronously: {}", e.what()); - } + try + { + ZEN_TRACE_CPU_FLUSH("SyncWork"); + Work->Execute(); + Work = {}; + } + catch (const AssertException& Ex) + { + Work = {}; + ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); + } + catch (const std::exception& e) + { + Work = {}; + ZEN_WARN("Caught exception when executing worker synchronously: {}", e.what()); } } void -WorkerThreadPool::ScheduleWork(std::function<void()>&& Work) -{ - ScheduleWork(Ref<IWork>(new detail::LambdaWork(std::move(Work)))); -} - -[[nodiscard]] size_t -WorkerThreadPool::PendingWorkItemCount() const +WorkerThreadPool::ScheduleWork(std::function<void()>&& Work, EMode Mode) { - if (m_Impl) - { - return m_Impl->PendingWorkItemCount(); - } - return 0; + ScheduleWork(Ref<IWork>(new detail::LambdaWork(std::move(Work))), Mode); } ////////////////////////////////////////////////////////////////////////// @@ -302,9 +344,10 @@ TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; - auto Future42 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 42; }}); - auto Future99 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 99; }}); - auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task<void()>{[] { throw std::runtime_error("meep!"); }}); + auto Future42 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 42; }}, WorkerThreadPool::EMode::EnableBacklog); + auto Future99 = Threadpool.EnqueueTask(std::packaged_task<int()>{[] { return 99; }}, WorkerThreadPool::EMode::EnableBacklog); + auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task<void()>{[] { throw std::runtime_error("meep!"); }}, + WorkerThreadPool::EMode::EnableBacklog); CHECK_EQ(Future42.get(), 42); CHECK_EQ(Future99.get(), 99); |