// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #define ZEN_USE_WINDOWS_THREADPOOL 1 #if ZEN_PLATFORM_WINDOWS && ZEN_USE_WINDOWS_THREADPOOL # include #endif namespace zen { namespace detail { struct LambdaWork : IWork { LambdaWork(auto Work) : WorkFunction(Work) {} virtual void Execute() override { WorkFunction(); } std::function WorkFunction; }; } // namespace detail ////////////////////////////////////////////////////////////////////////// #if ZEN_USE_WINDOWS_THREADPOOL && ZEN_PLATFORM_WINDOWS namespace { thread_local bool t_IsThreadNamed{false}; } struct WorkerThreadPool::Impl { PTP_POOL m_ThreadPool = nullptr; PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; TP_CALLBACK_ENVIRON m_CallbackEnvironment; PTP_WORK m_Work = nullptr; std::string m_WorkerThreadBaseName; std::atomic m_WorkerThreadCounter{0}; RwLock m_QueueLock; std::deque> m_WorkQueue; Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName) { // Thread pool setup m_ThreadPool = CreateThreadpool(NULL); SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); m_CleanupGroup = CreateThreadpoolCleanupGroup(); SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); m_Work = CreateThreadpoolWork(&WorkCallback, this, &m_CallbackEnvironment); } ~Impl() { WaitForThreadpoolWorkCallbacks(m_Work, /* CancelPendingCallbacks */ TRUE); CloseThreadpoolWork(m_Work); CloseThreadpool(m_ThreadPool); } void ScheduleWork(Ref Work) { m_QueueLock.WithExclusiveLock([&] { m_WorkQueue.push_back(std::move(Work)); }); SubmitThreadpoolWork(m_Work); } [[nodiscard]] size_t PendingWorkItemCount() const { return 0; } static VOID CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE Instance, _Inout_opt_ PVOID Context, _Inout_ PTP_WORK Work) { ZEN_UNUSED(Instance, Work); Impl* ThisPtr = reinterpret_cast(Context); ThisPtr->DoWork(); } void DoWork() { if (!t_IsThreadNamed) { t_IsThreadNamed = true; const int ThreadIndex = ++m_WorkerThreadCounter; zen::ExtendableStringBuilder<128> ThreadName; ThreadName << m_WorkerThreadBaseName << "_" << ThreadIndex; SetCurrentThreadName(ThreadName); } Ref WorkFromQueue; { RwLock::ExclusiveLockScope _{m_QueueLock}; WorkFromQueue = std::move(m_WorkQueue.front()); m_WorkQueue.pop_front(); } ZEN_TRACE_CPU_FLUSH("AsyncWork"); WorkFromQueue->Execute(); } }; #else struct WorkerThreadPool::ThreadStartInfo { int ThreadNumber; zen::Latch* Latch; }; struct WorkerThreadPool::Impl { void WorkerThreadFunction(ThreadStartInfo Info); std::string m_WorkerThreadBaseName; std::vector m_WorkerThreads; BlockingQueue> m_WorkQueue; Impl(int InThreadCount, std::string_view WorkerThreadBaseName) : m_WorkerThreadBaseName(WorkerThreadBaseName) { # if ZEN_WITH_TRACE trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str()); # endif zen::Latch WorkerLatch{InThreadCount}; for (int i = 0; i < InThreadCount; ++i) { m_WorkerThreads.emplace_back(&Impl::WorkerThreadFunction, this, ThreadStartInfo{i + 1, &WorkerLatch}); } WorkerLatch.Wait(); # if ZEN_WITH_TRACE trace::ThreadGroupEnd(); # endif } ~Impl() { m_WorkQueue.CompleteAdding(); for (std::thread& Thread : m_WorkerThreads) { if (Thread.joinable()) { Thread.join(); } } m_WorkerThreads.clear(); } void ScheduleWork(Ref Work) { m_WorkQueue.Enqueue(std::move(Work)); } [[nodiscard]] size_t PendingWorkItemCount() const { return m_WorkQueue.Size(); } }; void WorkerThreadPool::Impl::WorkerThreadFunction(ThreadStartInfo Info) { SetCurrentThreadName(fmt::format("{}_{}", m_WorkerThreadBaseName, Info.ThreadNumber)); Info.Latch->CountDown(); do { Ref Work; if (m_WorkQueue.WaitAndDequeue(Work)) { try { ZEN_TRACE_CPU_FLUSH("AsyncWork"); Work->Execute(); } catch (const AssertException& Ex) { Work->m_Exception = std::current_exception(); ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { Work->m_Exception = std::current_exception(); ZEN_WARN("Caught exception in worker thread: {}", e.what()); } } else { return; } } while (true); } #endif ////////////////////////////////////////////////////////////////////////// WorkerThreadPool::WorkerThreadPool(int InThreadCount) : WorkerThreadPool(InThreadCount, "workerthread") { } WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName) { if (InThreadCount > 0) { m_Impl = std::make_unique(InThreadCount, WorkerThreadBaseName); } } WorkerThreadPool::~WorkerThreadPool() { m_Impl.reset(); } void WorkerThreadPool::ScheduleWork(Ref Work) { if (m_Impl) { m_Impl->ScheduleWork(std::move(Work)); } else { try { ZEN_TRACE_CPU_FLUSH("SyncWork"); Work->Execute(); } catch (const AssertException& Ex) { Work->m_Exception = std::current_exception(); ZEN_WARN("Assert exception in worker thread: {}", Ex.FullDescription()); } catch (const std::exception& e) { Work->m_Exception = std::current_exception(); ZEN_WARN("Caught exception when executing worker synchronously: {}", e.what()); } } } void WorkerThreadPool::ScheduleWork(std::function&& Work) { ScheduleWork(Ref(new detail::LambdaWork(Work))); } [[nodiscard]] size_t WorkerThreadPool::PendingWorkItemCount() const { if (m_Impl) { return m_Impl->PendingWorkItemCount(); } return 0; } ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS void workthreadpool_forcelink() { } using namespace std::literals; TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; auto Future42 = Threadpool.EnqueueTask(std::packaged_task{[] { return 42; }}); auto Future99 = Threadpool.EnqueueTask(std::packaged_task{[] { return 99; }}); auto FutureThrow = Threadpool.EnqueueTask(std::packaged_task{[] { throw std::runtime_error("meep!"); }}); CHECK_EQ(Future42.get(), 42); CHECK_EQ(Future99.get(), 99); CHECK_THROWS(FutureThrow.get()); } #endif } // namespace zen