// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #define ZEN_USE_WINDOWS_THREADPOOL 1 #if ZEN_PLATFORM_WINDOWS && ZEN_USE_WINDOWS_THREADPOOL # include #endif namespace zen { namespace detail { struct LambdaWork : IWork { LambdaWork(auto Work) : WorkFunction(Work) {} virtual void Execute() override { WorkFunction(); } std::function WorkFunction; }; } // namespace detail ////////////////////////////////////////////////////////////////////////// #if ZEN_USE_WINDOWS_THREADPOOL && ZEN_PLATFORM_WINDOWS namespace { thread_local bool t_IsThreadNamed{false}; } struct WorkerThreadPool::Impl { const int m_ThreadCount = 0; PTP_POOL m_ThreadPool = nullptr; PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; TP_CALLBACK_ENVIRON m_CallbackEnvironment; PTP_WORK m_Work = nullptr; std::string m_WorkerThreadBaseName; std::atomic m_WorkerThreadCounter{0}; std::atomic m_FreeWorkerCount{0}; mutable RwLock m_QueueLock; std::deque> m_WorkQueue; Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_ThreadCount(InThreadCount) , m_WorkerThreadBaseName(WorkerThreadBaseName) , m_FreeWorkerCount(m_ThreadCount) { // Thread pool setup m_ThreadPool = CreateThreadpool(NULL); if (m_ThreadPool == NULL) { ThrowLastError("CreateThreadpool failed"); } if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_ThreadCount)) { ThrowLastError("SetThreadpoolThreadMinimum failed"); } SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_ThreadCount); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); m_CleanupGroup = CreateThreadpoolCleanupGroup(); if (m_CleanupGroup == NULL) { ThrowLastError("CreateThreadpoolCleanupGroup failed"); } SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); m_Work = CreateThreadpoolWork(&WorkCallback, this, &m_CallbackEnvironment); if (m_Work == NULL) { ThrowLastError("CreateThreadpoolWork failed"); } } ~Impl() { WaitForThreadpoolWorkCallbacks(m_Work, /* CancelPendingCallbacks */ TRUE); CloseThreadpoolWork(m_Work); CloseThreadpool(m_ThreadPool); } [[nodiscard]] Ref ScheduleWork(Ref Work, WorkerThreadPool::EMode Mode) { if (Mode == WorkerThreadPool::EMode::DisableBacklog) { if (m_FreeWorkerCount <= 0) { return Work; } RwLock::ExclusiveLockScope _(m_QueueLock); const int QueuedCount = gsl::narrow(m_WorkQueue.size()); if (QueuedCount >= m_FreeWorkerCount) { return Work; } m_WorkQueue.push_back(std::move(Work)); } else { m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); } SubmitThreadpoolWork(m_Work); return {}; } static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work) { ZEN_UNUSED(Instance, Work); Impl* ThisPtr = reinterpret_cast(Context); ThisPtr->DoWork(); } void DoWork() { m_FreeWorkerCount--; auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); if (!t_IsThreadNamed) { t_IsThreadNamed = true; const size_t ThreadIndex = ++m_WorkerThreadCounter; zen::ExtendableStringBuilder<128> ThreadName; ThreadName << m_WorkerThreadBaseName << "_" << ThreadIndex; SetCurrentThreadName(ThreadName); } Ref WorkFromQueue; { RwLock::ExclusiveLockScope __{m_QueueLock}; WorkFromQueue = std::move(m_WorkQueue.front()); m_WorkQueue.pop_front(); } try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); WorkFromQueue->Execute(); WorkFromQueue = {}; } catch (const AssertException& Ex) { WorkFromQueue = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { WorkFromQueue = {}; ZEN_ERROR("Caught exception when executing worker synchronously: {}", e.what()); } } }; #else struct WorkerThreadPool::ThreadStartInfo { int ThreadNumber; zen::Latch* Latch; }; struct WorkerThreadPool::Impl { const int m_ThreadCount = 0; void WorkerThreadFunction(ThreadStartInfo Info); std::string m_WorkerThreadBaseName; std::vector m_WorkerThreads; BlockingQueue> m_WorkQueue; std::atomic m_FreeWorkerCount{0}; Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_ThreadCount(InThreadCount) , m_WorkerThreadBaseName(WorkerThreadBaseName) , m_FreeWorkerCount(m_ThreadCount) { # if ZEN_WITH_TRACE trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str()); # endif zen::Latch WorkerLatch{m_ThreadCount}; for (int i = 0; i < m_ThreadCount; ++i) { m_WorkerThreads.emplace_back(&Impl::WorkerThreadFunction, this, ThreadStartInfo{i + 1, &WorkerLatch}); } WorkerLatch.Wait(); # if ZEN_WITH_TRACE trace::ThreadGroupEnd(); # endif } ~Impl() { m_WorkQueue.CompleteAdding(); for (std::thread& Thread : m_WorkerThreads) { if (Thread.joinable()) { Thread.join(); } } m_WorkerThreads.clear(); } [[nodiscard]] Ref ScheduleWork(Ref Work, WorkerThreadPool::EMode Mode) { if (Mode == WorkerThreadPool::EMode::DisableBacklog) { if (m_FreeWorkerCount <= 0) { return Work; } const int QueuedCount = gsl::narrow(m_WorkQueue.Size()); if (QueuedCount >= m_FreeWorkerCount) { return Work; } } m_WorkQueue.Enqueue(std::move(Work)); return {}; } }; void WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info) { SetCurrentThreadName(fmt::format("{}_{}", m_WorkerThreadBaseName, Info.ThreadNumber)); Info.Latch->CountDown(); do { Ref Work; if (m_WorkQueue.WaitAndDequeue(Work)) { m_FreeWorkerCount--; auto _ = MakeGuard([&]() { m_FreeWorkerCount++; }); try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); Work->Execute(); Work = {}; } catch (const AssertException& Ex) { Work = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { Work = {}; ZEN_ERROR("Caught exception in worker thread: {}", e.what()); } } else { return; } } while (true); } #endif ////////////////////////////////////////////////////////////////////////// WorkerThreadPool::WorkerThreadPool(int InThreadCount) : WorkerThreadPool(InThreadCount, "workerthread") { } WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName) { if (InThreadCount > 0) { m_Impl = std::make_unique(InThreadCount, WorkerThreadBaseName); } } WorkerThreadPool::~WorkerThreadPool() { m_Impl.reset(); } void WorkerThreadPool::ScheduleWork(Ref Work, EMode Mode) { if (m_Impl) { if (Work = m_Impl->ScheduleWork(std::move(Work), Mode); !Work) { return; } } try { ZEN_TRACE_CPU_FLUSH("SyncWork"); Work->Execute(); Work = {}; } catch (const AssertException& Ex) { Work = {}; ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { Work = {}; ZEN_ERROR("Caught exception when executing worker synchronously: {}", e.what()); } } void WorkerThreadPool::ScheduleWork(std::function&& Work, EMode Mode) { ScheduleWork(Ref(new detail::LambdaWork(std::move(Work))), Mode); } ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS void workthreadpool_forcelink() { } using namespace std::literals; TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; auto Future42 = Threadpool.EnqueueTask(std::packaged_task{[] { return 42; }}, WorkerThreadPool::EMode::EnableBacklog); auto Future99 = Threadpool.EnqueueTask(std::packaged_task{[] { return 99; }}, WorkerThreadPool::EMode::EnableBacklog); auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task{[] { throw std::runtime_error("meep!"); }}, WorkerThreadPool::EMode::EnableBacklog); CHECK_EQ(Future42.get(), 42); CHECK_EQ(Future99.get(), 99); CHECK_THROWS(FutureThrow.get()); } #endif } // namespace zen