diff options
Diffstat (limited to 'src/zenhttp/servers/iothreadpool.cpp')
| -rw-r--r-- | src/zenhttp/servers/iothreadpool.cpp | 243 |
1 file changed, 240 insertions, 3 deletions
diff --git a/src/zenhttp/servers/iothreadpool.cpp b/src/zenhttp/servers/iothreadpool.cpp index e941606e2..d180f17f8 100644 --- a/src/zenhttp/servers/iothreadpool.cpp +++ b/src/zenhttp/servers/iothreadpool.cpp @@ -3,12 +3,35 @@ #include "iothreadpool.h" #include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/thread.h> #if ZEN_PLATFORM_WINDOWS +# include <thread> + namespace zen { -WinIoThreadPool::WinIoThreadPool(int InThreadCount, int InMaxThreadCount) +////////////////////////////////////////////////////////////////////////// +// Factory + +std::unique_ptr<WinIoThreadPool> +WinIoThreadPool::Create(bool UseExplicitThreads, int MinThreads, int MaxThreads) +{ + if (UseExplicitThreads) + { + return std::make_unique<ExplicitIoThreadPool>(MinThreads, MaxThreads); + } + else + { + return std::make_unique<WinTpIoThreadPool>(MinThreads, MaxThreads); + } +} + +////////////////////////////////////////////////////////////////////////// +// WinTpIoThreadPool - Windows Thread Pool implementation + +WinTpIoThreadPool::WinTpIoThreadPool(int InThreadCount, int InMaxThreadCount) { ZEN_ASSERT(InThreadCount); @@ -31,7 +54,7 @@ WinIoThreadPool::WinIoThreadPool(int InThreadCount, int InMaxThreadCount) SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); } -WinIoThreadPool::~WinIoThreadPool() +WinTpIoThreadPool::~WinTpIoThreadPool() { // this will wait for all callbacks to complete and tear down the `CreateThreadpoolIo` // object and release all related objects @@ -42,7 +65,7 @@ WinIoThreadPool::~WinIoThreadPool() } void -WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) +WinTpIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) { ZEN_ASSERT(!m_ThreadPoolIo); @@ -54,6 +77,220 @@ WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, voi } } +void +WinTpIoThreadPool::StartIo() +{ + 
StartThreadpoolIo(m_ThreadPoolIo); +} + +void +WinTpIoThreadPool::CancelIo() +{ + CancelThreadpoolIo(m_ThreadPoolIo); +} + +////////////////////////////////////////////////////////////////////////// +// ExplicitIoThreadPool - Raw IOCP + std::thread with load-based scaling + +static LoggerRef +ExplicitIoPoolLog() +{ + static LoggerRef s_Log = logging::Get("iopool"); + return s_Log; +} + +ExplicitIoThreadPool::ExplicitIoThreadPool(int InMinThreadCount, int InMaxThreadCount) +: m_MinThreads(InMinThreadCount) +, m_MaxThreads(InMaxThreadCount) +{ + ZEN_ASSERT(InMinThreadCount > 0); + + if (m_MaxThreads < m_MinThreads) + { + m_MaxThreads = m_MinThreads; + } + + m_Iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); + + if (!m_Iocp) + { + ZEN_LOG_ERROR(ExplicitIoPoolLog(), "failed to create I/O completion port: {}", GetLastError()); + } +} + +ExplicitIoThreadPool::~ExplicitIoThreadPool() +{ + m_ShuttingDown.store(true, std::memory_order::release); + + // Post poison-pill completions to wake all threads + const int ThreadCount = m_TotalThreads.load(std::memory_order::acquire); + for (int i = 0; i < ThreadCount; ++i) + { + PostQueuedCompletionStatus(m_Iocp, 0, 0, nullptr); + } + + // Join all threads + { + RwLock::ExclusiveLockScope _(m_ThreadListLock); + for (auto& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + m_Threads.clear(); + } + + if (m_Iocp) + { + CloseHandle(m_Iocp); + m_Iocp = nullptr; + } +} + +void +ExplicitIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) +{ + ZEN_ASSERT(m_Iocp); + ZEN_ASSERT(!m_Callback); + + m_Callback = Callback; + m_Context = Context; + + // Associate the I/O handle with our completion port + HANDLE Result = CreateIoCompletionPort(IoHandle, m_Iocp, /* CompletionKey */ 0, 0); + + if (!Result) + { + ErrorCode = MakeErrorCodeFromLastError(); + return; + } + + // Now spawn the initial worker threads + for (int i = 0; i < 
m_MinThreads; ++i) + { + SpawnWorkerThread(); + } +} + +void +ExplicitIoThreadPool::StartIo() +{ + // No-op for raw IOCP - completions are posted automatically +} + +void +ExplicitIoThreadPool::CancelIo() +{ + // No-op for raw IOCP - completions are posted automatically +} + +void +ExplicitIoThreadPool::SpawnWorkerThread() +{ + RwLock::ExclusiveLockScope _(m_ThreadListLock); + + ++m_TotalThreads; + m_Threads.emplace_back([this] { WorkerThreadMain(); }); +} + +void +ExplicitIoThreadPool::WorkerThreadMain() +{ + static std::atomic<int> s_ThreadIndex{0}; + const int ThreadIndex = ++s_ThreadIndex; + ExtendableStringBuilder<16> ThreadName; + ThreadName << "xpio_" << ThreadIndex; + SetCurrentThreadName(ThreadName); + + static constexpr DWORD kIdleTimeoutMs = 15000; + + while (!m_ShuttingDown.load(std::memory_order::acquire)) + { + DWORD BytesTransferred = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED* pOverlapped = nullptr; + + BOOL Success = GetQueuedCompletionStatus(m_Iocp, &BytesTransferred, &CompletionKey, &pOverlapped, kIdleTimeoutMs); + + if (m_ShuttingDown.load(std::memory_order::acquire)) + { + break; + } + + if (!Success && !pOverlapped) + { + DWORD Error = GetLastError(); + + if (Error == WAIT_TIMEOUT) + { + // Timeout - consider scaling down + const int CurrentTotal = m_TotalThreads.load(std::memory_order::acquire); + if (CurrentTotal > m_MinThreads) + { + // Try to claim this thread for exit by decrementing the count. + // Use CAS to avoid thundering herd of exits. 
+ int Expected = CurrentTotal; + if (m_TotalThreads.compare_exchange_strong(Expected, CurrentTotal - 1, std::memory_order::acq_rel)) + { + ZEN_LOG_DEBUG(ExplicitIoPoolLog(), + "scaling down I/O thread (idle timeout), {} threads remaining", + CurrentTotal - 1); + return; // Thread exits + } + } + continue; + } + + // Some other error with no overlapped - unexpected + ZEN_LOG_WARN(ExplicitIoPoolLog(), "GetQueuedCompletionStatus failed with error {}", Error); + continue; + } + + if (!pOverlapped) + { + // Poison pill (null overlapped) - shutdown signal + break; + } + + // Got a real completion - determine the I/O result + ULONG IoResult = NO_ERROR; + if (!Success) + { + IoResult = GetLastError(); + } + + // Track active threads for scale-up decisions + const int ActiveBefore = m_ActiveCount.fetch_add(1, std::memory_order::acq_rel); + const int TotalNow = m_TotalThreads.load(std::memory_order::acquire); + + // Scale up: if all threads are now busy and we haven't hit the max, spawn another + if ((ActiveBefore + 1) >= TotalNow && TotalNow < m_MaxThreads) + { + // Use CAS to ensure only one thread triggers the scale-up + int Expected = TotalNow; + if (m_TotalThreads.compare_exchange_strong(Expected, TotalNow + 1, std::memory_order::acq_rel)) + { + ZEN_LOG_DEBUG(ExplicitIoPoolLog(), "scaling up I/O thread pool, {} -> {} threads", TotalNow, TotalNow + 1); + + // Spawn outside the hot path - but we need the thread list lock + // We already incremented m_TotalThreads, so do the actual spawn + { + RwLock::ExclusiveLockScope _(m_ThreadListLock); + m_Threads.emplace_back([this] { WorkerThreadMain(); }); + } + } + } + + // Invoke the callback with the same signature as PTP_WIN32_IO_CALLBACK + // Parameters: Instance, Context, Overlapped, IoResult, NumberOfBytesTransferred, Io + m_Callback(nullptr, m_Context, pOverlapped, IoResult, BytesTransferred, nullptr); + + m_ActiveCount.fetch_sub(1, std::memory_order::release); + } +} + } // namespace zen #endif |