// Copyright Epic Games, Inc. All Rights Reserved. #include "iothreadpool.h" #include #include #include #if ZEN_PLATFORM_WINDOWS # include namespace zen { ////////////////////////////////////////////////////////////////////////// // Factory std::unique_ptr WinIoThreadPool::Create(bool UseExplicitThreads, int MinThreads, int MaxThreads) { if (UseExplicitThreads) { return std::make_unique(MinThreads, MaxThreads); } else { return std::make_unique(MinThreads, MaxThreads); } } ////////////////////////////////////////////////////////////////////////// // WinTpIoThreadPool - Windows Thread Pool implementation WinTpIoThreadPool::WinTpIoThreadPool(int InThreadCount, int InMaxThreadCount) { ZEN_ASSERT(InThreadCount); if (InMaxThreadCount < InThreadCount) { InMaxThreadCount = InThreadCount; } m_ThreadPool = CreateThreadpool(NULL); SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); SetThreadpoolThreadMaximum(m_ThreadPool, InMaxThreadCount); InitializeThreadpoolEnvironment(&m_CallbackEnvironment); m_CleanupGroup = CreateThreadpoolCleanupGroup(); SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); } WinTpIoThreadPool::~WinTpIoThreadPool() { // this will wait for all callbacks to complete and tear down the `CreateThreadpoolIo` // object and release all related objects CloseThreadpoolCleanupGroupMembers(m_CleanupGroup, /* cancel pending callbacks */ TRUE, nullptr); CloseThreadpoolCleanupGroup(m_CleanupGroup); CloseThreadpool(m_ThreadPool); DestroyThreadpoolEnvironment(&m_CallbackEnvironment); } void WinTpIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) { ZEN_ASSERT(!m_ThreadPoolIo); m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); if (!m_ThreadPoolIo) { ErrorCode = MakeErrorCodeFromLastError(); } } void WinTpIoThreadPool::StartIo() { StartThreadpoolIo(m_ThreadPoolIo); } void WinTpIoThreadPool::CancelIo() { CancelThreadpoolIo(m_ThreadPoolIo); } ////////////////////////////////////////////////////////////////////////// // ExplicitIoThreadPool - Raw IOCP + std::thread with load-based scaling static LoggerRef ExplicitIoPoolLog() { static LoggerRef s_Log = logging::Get("iopool"); return s_Log; } ExplicitIoThreadPool::ExplicitIoThreadPool(int InMinThreadCount, int InMaxThreadCount) : m_MinThreads(InMinThreadCount) , m_MaxThreads(InMaxThreadCount) { ZEN_ASSERT(InMinThreadCount > 0); if (m_MaxThreads < m_MinThreads) { m_MaxThreads = m_MinThreads; } m_Iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); if (!m_Iocp) { ZEN_LOG_ERROR(ExplicitIoPoolLog(), "failed to create I/O completion port: {}", GetLastError()); } } ExplicitIoThreadPool::~ExplicitIoThreadPool() { m_ShuttingDown.store(true, std::memory_order::release); // Post poison-pill completions to wake all threads const int ThreadCount = m_TotalThreads.load(std::memory_order::acquire); for (int i = 0; i < ThreadCount; ++i) { PostQueuedCompletionStatus(m_Iocp, 0, 0, nullptr); } // Join all threads { RwLock::ExclusiveLockScope _(m_ThreadListLock); for (auto& Thread : m_Threads) { if (Thread.joinable()) { Thread.join(); } } m_Threads.clear(); } if (m_Iocp) { CloseHandle(m_Iocp); m_Iocp = nullptr; } } void ExplicitIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) { ZEN_ASSERT(m_Iocp); ZEN_ASSERT(!m_Callback); m_Callback = Callback; m_Context = Context; // Associate the I/O handle with our completion port HANDLE Result = CreateIoCompletionPort(IoHandle, m_Iocp, /* CompletionKey */ 0, 0); if (!Result) { ErrorCode = MakeErrorCodeFromLastError(); return; } // Now spawn the initial worker threads for (int i = 0; i < m_MinThreads; ++i) { SpawnWorkerThread(); } } void ExplicitIoThreadPool::StartIo() { // No-op for raw IOCP - completions are posted automatically } void ExplicitIoThreadPool::CancelIo() { // No-op for raw IOCP - completions are posted automatically } void ExplicitIoThreadPool::SpawnWorkerThread() { RwLock::ExclusiveLockScope _(m_ThreadListLock); ++m_TotalThreads; m_Threads.emplace_back([this] { WorkerThreadMain(); }); } void ExplicitIoThreadPool::WorkerThreadMain() { static std::atomic s_ThreadIndex{0}; const int ThreadIndex = ++s_ThreadIndex; ExtendableStringBuilder<16> ThreadName; ThreadName << "xpio_" << ThreadIndex; SetCurrentThreadName(ThreadName); static constexpr DWORD kIdleTimeoutMs = 15000; while (!m_ShuttingDown.load(std::memory_order::acquire)) { DWORD BytesTransferred = 0; ULONG_PTR CompletionKey = 0; OVERLAPPED* pOverlapped = nullptr; BOOL Success = GetQueuedCompletionStatus(m_Iocp, &BytesTransferred, &CompletionKey, &pOverlapped, kIdleTimeoutMs); if (m_ShuttingDown.load(std::memory_order::acquire)) { break; } if (!Success && !pOverlapped) { DWORD Error = GetLastError(); if (Error == WAIT_TIMEOUT) { // Timeout - consider scaling down const int CurrentTotal = m_TotalThreads.load(std::memory_order::acquire); if (CurrentTotal > m_MinThreads) { // Try to claim this thread for exit by decrementing the count. // Use CAS to avoid thundering herd of exits. int Expected = CurrentTotal; if (m_TotalThreads.compare_exchange_strong(Expected, CurrentTotal - 1, std::memory_order::acq_rel)) { ZEN_LOG_DEBUG(ExplicitIoPoolLog(), "scaling down I/O thread (idle timeout), {} threads remaining", CurrentTotal - 1); return; // Thread exits } } continue; } // Some other error with no overlapped - unexpected ZEN_LOG_WARN(ExplicitIoPoolLog(), "GetQueuedCompletionStatus failed with error {}", Error); continue; } if (!pOverlapped) { // Poison pill (null overlapped) - shutdown signal break; } // Got a real completion - determine the I/O result ULONG IoResult = NO_ERROR; if (!Success) { IoResult = GetLastError(); } // Track active threads for scale-up decisions const int ActiveBefore = m_ActiveCount.fetch_add(1, std::memory_order::acq_rel); const int TotalNow = m_TotalThreads.load(std::memory_order::acquire); // Scale up: if all threads are now busy and we haven't hit the max, spawn another if ((ActiveBefore + 1) >= TotalNow && TotalNow < m_MaxThreads) { // Use CAS to ensure only one thread triggers the scale-up int Expected = TotalNow; if (m_TotalThreads.compare_exchange_strong(Expected, TotalNow + 1, std::memory_order::acq_rel)) { ZEN_LOG_DEBUG(ExplicitIoPoolLog(), "scaling up I/O thread pool, {} -> {} threads", TotalNow, TotalNow + 1); // Spawn outside the hot path - but we need the thread list lock // We already incremented m_TotalThreads, so do the actual spawn { RwLock::ExclusiveLockScope _(m_ThreadListLock); m_Threads.emplace_back([this] { WorkerThreadMain(); }); } } } // Invoke the callback with the same signature as PTP_WIN32_IO_CALLBACK // Parameters: Instance, Context, Overlapped, IoResult, NumberOfBytesTransferred, Io m_Callback(nullptr, m_Context, pOverlapped, IoResult, BytesTransferred, nullptr); m_ActiveCount.fetch_sub(1, std::memory_order::release); } } } // namespace zen #endif