// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # include #endif namespace zen { namespace detail { struct LambdaWork : IWork { LambdaWork(auto Work) : WorkFunction(Work) {} virtual void Execute() override { WorkFunction(); } std::function WorkFunction; }; } // namespace detail ////////////////////////////////////////////////////////////////////////// struct WorkerThreadPool::Impl { virtual ~Impl() = default; [[nodiscard]] virtual Ref ScheduleWork(Ref Work, WorkerThreadPool::EMode Mode) = 0; }; ////////////////////////////////////////////////////////////////////////// #if ZEN_PLATFORM_WINDOWS namespace { thread_local bool t_IsThreadNamed{false}; } struct WinTpImpl : WorkerThreadPool::Impl { const int m_MinThreadCount = 0; const int m_MaxThreadCount = 0; PTP_POOL m_ThreadPool = nullptr; PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; TP_CALLBACK_ENVIRON m_CallbackEnvironment; PTP_WORK m_Work = nullptr; std::string m_WorkerThreadBaseName; std::atomic m_WorkerThreadCounter{0}; std::atomic m_FreeWorkerCount{0}; mutable RwLock m_QueueLock; std::deque> m_WorkQueue; WinTpImpl(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName) : m_MinThreadCount(InMinThreadCount) , m_MaxThreadCount(InMaxThreadCount < InMinThreadCount ? InMinThreadCount : InMaxThreadCount) , m_WorkerThreadBaseName(WorkerThreadBaseName) , m_FreeWorkerCount(m_MinThreadCount) { // Thread pool setup m_ThreadPool = CreateThreadpool(NULL); if (m_ThreadPool == NULL) { ThrowLastError("CreateThreadpool failed"); } if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_MinThreadCount)) { ThrowLastError("SetThreadpoolThreadMinimum failed"); } SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_MaxThreadCount); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); m_CleanupGroup = CreateThreadpoolCleanupGroup(); if (m_CleanupGroup == NULL) { ThrowLastError("CreateThreadpoolCleanupGroup failed"); } SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); m_Work = CreateThreadpoolWork(&WorkCallback, this, &m_CallbackEnvironment); if (m_Work == NULL) { ThrowLastError("CreateThreadpoolWork failed"); } } ~WinTpImpl() override { WaitForThreadpoolWorkCallbacks(m_Work, /* CancelPendingCallbacks */ TRUE); CloseThreadpoolWork(m_Work); CloseThreadpool(m_ThreadPool); } [[nodiscard]] Ref ScheduleWork(Ref Work, WorkerThreadPool::EMode Mode) override { if (Mode == WorkerThreadPool::EMode::DisableBacklog) { if (m_FreeWorkerCount <= 0) { return Work; } RwLock::ExclusiveLockScope _(m_QueueLock); const int QueuedCount = gsl::narrow(m_WorkQueue.size()); if (QueuedCount >= m_FreeWorkerCount) { return Work; } m_WorkQueue.push_back(std::move(Work)); } else { m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); } SubmitThreadpoolWork(m_Work); return {}; } static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work) { ZEN_UNUSED(Instance, Work); WinTpImpl* ThisPtr = reinterpret_cast(Context); ThisPtr->DoWork(); } void DoWork() { m_FreeWorkerCount--; auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); if (!t_IsThreadNamed) { t_IsThreadNamed = true; const size_t ThreadIndex = ++m_WorkerThreadCounter; zen::ExtendableStringBuilder<128> ThreadName; ThreadName << m_WorkerThreadBaseName << "_" << ThreadIndex; SetCurrentThreadName(ThreadName); } Ref WorkFromQueue; { RwLock::ExclusiveLockScope __{m_QueueLock}; WorkFromQueue = std::move(m_WorkQueue.front()); m_WorkQueue.pop_front(); } try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); WorkFromQueue->Execute(); WorkFromQueue = {}; } catch (const AssertException& Ex) { WorkFromQueue = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { WorkFromQueue = {}; ZEN_ERROR("Caught exception when executing worker synchronously: {}", e.what()); } } }; #endif ////////////////////////////////////////////////////////////////////////// struct WorkerThreadPool::ThreadStartInfo { int ThreadNumber; zen::Latch* Latch; }; struct ExplicitImpl : WorkerThreadPool::Impl { const int m_MinThreads; const int m_MaxThreads; std::atomic m_TotalThreads{0}; std::atomic m_ActiveCount{0}; void WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info); void SpawnWorkerThread(); std::string m_WorkerThreadBaseName; RwLock m_ThreadListLock; std::vector m_WorkerThreads; BlockingQueue> m_WorkQueue; std::atomic m_FreeWorkerCount{0}; bool ScalingEnabled() const { return m_MinThreads != m_MaxThreads; } ExplicitImpl(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName) : m_MinThreads(InMinThreadCount) , m_MaxThreads(InMaxThreadCount < InMinThreadCount ? InMinThreadCount : InMaxThreadCount) , m_WorkerThreadBaseName(WorkerThreadBaseName) , m_FreeWorkerCount(InMinThreadCount) { #if ZEN_WITH_TRACE trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str()); #endif zen::Latch WorkerLatch{m_MinThreads}; for (int i = 0; i < m_MinThreads; ++i) { m_TotalThreads.fetch_add(1, std::memory_order::relaxed); m_WorkerThreads.emplace_back(&ExplicitImpl::WorkerThreadFunction, this, WorkerThreadPool::ThreadStartInfo{i + 1, &WorkerLatch}); } WorkerLatch.Wait(); #if ZEN_WITH_TRACE trace::ThreadGroupEnd(); #endif } ~ExplicitImpl() override { m_WorkQueue.CompleteAdding(); RwLock::ExclusiveLockScope _(m_ThreadListLock); for (std::thread& Thread : m_WorkerThreads) { if (Thread.joinable()) { Thread.join(); } } m_WorkerThreads.clear(); } [[nodiscard]] Ref ScheduleWork(Ref Work, WorkerThreadPool::EMode Mode) override { if (Mode == WorkerThreadPool::EMode::DisableBacklog) { if (m_FreeWorkerCount <= 0) { return Work; } const int QueuedCount = gsl::narrow(m_WorkQueue.Size()); if (QueuedCount >= m_FreeWorkerCount) { return Work; } } m_WorkQueue.Enqueue(std::move(Work)); // Scale up: if all workers are busy and we haven't hit the max, spawn a new thread if (ScalingEnabled()) { const int Active = m_ActiveCount.load(std::memory_order::acquire); const int Total = m_TotalThreads.load(std::memory_order::acquire); if (Active >= Total && Total < m_MaxThreads) { int Expected = Total; if (m_TotalThreads.compare_exchange_strong(Expected, Total + 1, std::memory_order::acq_rel)) { ZEN_DEBUG("scaling up worker thread pool '{}', {} -> {} threads", m_WorkerThreadBaseName, Total, Total + 1); SpawnWorkerThread(); } } } return {}; } }; void ExplicitImpl::SpawnWorkerThread() { static std::atomic s_DynamicThreadIndex{0}; const int ThreadNumber = ++s_DynamicThreadIndex; RwLock::ExclusiveLockScope _(m_ThreadListLock); m_WorkerThreads.emplace_back(&ExplicitImpl::WorkerThreadFunction, this, WorkerThreadPool::ThreadStartInfo{ThreadNumber, nullptr}); } void ExplicitImpl::WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info) { SetCurrentThreadName(fmt::format("{}_{}", m_WorkerThreadBaseName, Info.ThreadNumber)); if (Info.Latch) { Info.Latch->CountDown(); } static constexpr auto kIdleTimeout = std::chrono::seconds(15); do { Ref Work; bool Dequeued; if (ScalingEnabled()) { Dequeued = m_WorkQueue.WaitAndDequeueFor(Work, kIdleTimeout); } else { Dequeued = m_WorkQueue.WaitAndDequeue(Work); } if (Dequeued) { m_FreeWorkerCount--; m_ActiveCount.fetch_add(1, std::memory_order::acq_rel); auto _ = MakeGuard([&]() { m_ActiveCount.fetch_sub(1, std::memory_order::release); m_FreeWorkerCount++; }); try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); Work->Execute(); Work = {}; } catch (const AssertException& Ex) { Work = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { Work = {}; ZEN_ERROR("Caught exception in worker thread: {}", e.what()); } } else if (ScalingEnabled()) { // Timed out - consider scaling down const int CurrentTotal = m_TotalThreads.load(std::memory_order::acquire); if (CurrentTotal > m_MinThreads) { int Expected = CurrentTotal; if (m_TotalThreads.compare_exchange_strong(Expected, CurrentTotal - 1, std::memory_order::acq_rel)) { ZEN_DEBUG("scaling down worker thread pool '{}' (idle timeout), {} threads remaining", m_WorkerThreadBaseName, CurrentTotal - 1); m_FreeWorkerCount--; return; // Thread exits } } // CAS failed or at min threads - continue waiting } else { // CompleteAdding was called - exit return; } } while (true); } ////////////////////////////////////////////////////////////////////////// WorkerThreadPool::WorkerThreadPool(int InThreadCount, bool UseExplicitThreads) : WorkerThreadPool(InThreadCount, "workerthread", UseExplicitThreads) { } WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName, bool UseExplicitThreads) : WorkerThreadPool(InThreadCount, InThreadCount, WorkerThreadBaseName, UseExplicitThreads) { } WorkerThreadPool::WorkerThreadPool(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName, bool UseExplicitThreads) { if (InMinThreadCount > 0) { #if ZEN_PLATFORM_WINDOWS if (!UseExplicitThreads) { m_Impl = std::make_unique(InMinThreadCount, InMaxThreadCount, WorkerThreadBaseName); } else #endif { ZEN_UNUSED(UseExplicitThreads); m_Impl = std::make_unique(InMinThreadCount, InMaxThreadCount, WorkerThreadBaseName); } } } WorkerThreadPool::~WorkerThreadPool() { m_Impl.reset(); } void WorkerThreadPool::ScheduleWork(Ref Work, EMode Mode) { if (m_Impl) { if (Work = m_Impl->ScheduleWork(std::move(Work), Mode); !Work) { return; } } try { ZEN_TRACE_CPU_FLUSH("SyncWork"); Work->Execute(); Work = {}; } catch (const AssertException& Ex) { Work = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { Work = {}; ZEN_ERROR("Caught exception when executing worker synchronously: {}", e.what()); } } void WorkerThreadPool::ScheduleWork(std::function&& Work, EMode Mode) { ScheduleWork(Ref(new detail::LambdaWork(std::move(Work))), Mode); } ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS void workthreadpool_forcelink() { } using namespace std::literals; TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; auto Future42 = Threadpool.EnqueueTask(std::packaged_task{[] { return 42; }}, WorkerThreadPool::EMode::EnableBacklog); auto Future99 = Threadpool.EnqueueTask(std::packaged_task{[] { return 99; }}, WorkerThreadPool::EMode::EnableBacklog); auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task{[] { throw std::runtime_error("meep!"); }}, WorkerThreadPool::EMode::EnableBacklog); CHECK_EQ(Future42.get(), 42); CHECK_EQ(Future99.get(), 99); CHECK_THROWS(FutureThrow.get()); } #endif } // namespace zen