aboutsummaryrefslogtreecommitdiff
path: root/src/zencore/workthreadpool.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zencore/workthreadpool.cpp')
-rw-r--r--src/zencore/workthreadpool.cpp126
1 files changed, 106 insertions, 20 deletions
diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp
index df24d4185..a43ce7115 100644
--- a/src/zencore/workthreadpool.cpp
+++ b/src/zencore/workthreadpool.cpp
@@ -52,9 +52,10 @@ namespace {
struct WinTpImpl : WorkerThreadPool::Impl
{
- const int m_ThreadCount = 0;
- PTP_POOL m_ThreadPool = nullptr;
- PTP_CLEANUP_GROUP m_CleanupGroup = nullptr;
+ const int m_MinThreadCount = 0;
+ const int m_MaxThreadCount = 0;
+ PTP_POOL m_ThreadPool = nullptr;
+ PTP_CLEANUP_GROUP m_CleanupGroup = nullptr;
TP_CALLBACK_ENVIRON m_CallbackEnvironment;
PTP_WORK m_Work = nullptr;
@@ -65,10 +66,11 @@ struct WinTpImpl : WorkerThreadPool::Impl
mutable RwLock m_QueueLock;
std::deque<Ref<IWork>> m_WorkQueue;
- WinTpImpl(int InThreadCount, std::string_view WorkerThreadBaseName)
- : m_ThreadCount(InThreadCount)
+ WinTpImpl(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName)
+ : m_MinThreadCount(InMinThreadCount)
+ , m_MaxThreadCount(InMaxThreadCount < InMinThreadCount ? InMinThreadCount : InMaxThreadCount)
, m_WorkerThreadBaseName(WorkerThreadBaseName)
- , m_FreeWorkerCount(m_ThreadCount)
+ , m_FreeWorkerCount(m_MinThreadCount)
{
// Thread pool setup
@@ -78,11 +80,11 @@ struct WinTpImpl : WorkerThreadPool::Impl
ThrowLastError("CreateThreadpool failed");
}
- if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_ThreadCount))
+ if (!SetThreadpoolThreadMinimum(m_ThreadPool, (DWORD)m_MinThreadCount))
{
ThrowLastError("SetThreadpoolThreadMinimum failed");
}
- SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_ThreadCount);
+ SetThreadpoolThreadMaximum(m_ThreadPool, (DWORD)m_MaxThreadCount);
InitializeThreadpoolEnvironment(&m_CallbackEnvironment);
@@ -193,26 +195,35 @@ struct WorkerThreadPool::ThreadStartInfo
struct ExplicitImpl : WorkerThreadPool::Impl
{
- const int m_ThreadCount = 0;
+ const int m_MinThreads;
+ const int m_MaxThreads;
+ std::atomic<int> m_TotalThreads{0};
+ std::atomic<int> m_ActiveCount{0};
void WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info);
+ void SpawnWorkerThread();
std::string m_WorkerThreadBaseName;
+ RwLock m_ThreadListLock;
std::vector<std::thread> m_WorkerThreads;
BlockingQueue<Ref<IWork>> m_WorkQueue;
std::atomic<int> m_FreeWorkerCount{0};
- ExplicitImpl(int InThreadCount, std::string_view WorkerThreadBaseName)
- : m_ThreadCount(InThreadCount)
+ bool ScalingEnabled() const { return m_MinThreads != m_MaxThreads; }
+
+ ExplicitImpl(int InMinThreadCount, int InMaxThreadCount, std::string_view WorkerThreadBaseName)
+ : m_MinThreads(InMinThreadCount)
+ , m_MaxThreads(InMaxThreadCount < InMinThreadCount ? InMinThreadCount : InMaxThreadCount)
, m_WorkerThreadBaseName(WorkerThreadBaseName)
- , m_FreeWorkerCount(m_ThreadCount)
+ , m_FreeWorkerCount(InMinThreadCount)
{
#if ZEN_WITH_TRACE
trace::ThreadGroupBegin(m_WorkerThreadBaseName.c_str());
#endif
- zen::Latch WorkerLatch{m_ThreadCount};
+ zen::Latch WorkerLatch{m_MinThreads};
- for (int i = 0; i < m_ThreadCount; ++i)
+ for (int i = 0; i < m_MinThreads; ++i)
{
+ m_TotalThreads.fetch_add(1, std::memory_order::relaxed);
m_WorkerThreads.emplace_back(&ExplicitImpl::WorkerThreadFunction, this, WorkerThreadPool::ThreadStartInfo{i + 1, &WorkerLatch});
}
@@ -227,6 +238,7 @@ struct ExplicitImpl : WorkerThreadPool::Impl
{
m_WorkQueue.CompleteAdding();
+ RwLock::ExclusiveLockScope _(m_ThreadListLock);
for (std::thread& Thread : m_WorkerThreads)
{
if (Thread.joinable())
@@ -253,24 +265,71 @@ struct ExplicitImpl : WorkerThreadPool::Impl
}
}
m_WorkQueue.Enqueue(std::move(Work));
+
+ // Scale up: if all workers are busy and we haven't hit the max, spawn a new thread
+ if (ScalingEnabled())
+ {
+ const int Active = m_ActiveCount.load(std::memory_order::acquire);
+ const int Total = m_TotalThreads.load(std::memory_order::acquire);
+ if (Active >= Total && Total < m_MaxThreads)
+ {
+ int Expected = Total;
+ if (m_TotalThreads.compare_exchange_strong(Expected, Total + 1, std::memory_order::acq_rel))
+ {
+ ZEN_DEBUG("scaling up worker thread pool '{}', {} -> {} threads", m_WorkerThreadBaseName, Total, Total + 1);
+ SpawnWorkerThread();
+ }
+ }
+ }
+
return {};
}
};
void
+ExplicitImpl::SpawnWorkerThread()
+{
+ static std::atomic<int> s_DynamicThreadIndex{0};
+ const int ThreadNumber = ++s_DynamicThreadIndex;
+
+ RwLock::ExclusiveLockScope _(m_ThreadListLock);
+ m_WorkerThreads.emplace_back(&ExplicitImpl::WorkerThreadFunction, this, WorkerThreadPool::ThreadStartInfo{ThreadNumber, nullptr});
+}
+
+void
ExplicitImpl::WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info)
{
SetCurrentThreadName(fmt::format("{}_{}", m_WorkerThreadBaseName, Info.ThreadNumber));
- Info.Latch->CountDown();
+ if (Info.Latch)
+ {
+ Info.Latch->CountDown();
+ }
+
+ static constexpr auto kIdleTimeout = std::chrono::seconds(15);
do
{
Ref<IWork> Work;
- if (m_WorkQueue.WaitAndDequeue(Work))
+
+ bool Dequeued;
+ if (ScalingEnabled())
+ {
+ Dequeued = m_WorkQueue.WaitAndDequeueFor(Work, kIdleTimeout);
+ }
+ else
+ {
+ Dequeued = m_WorkQueue.WaitAndDequeue(Work);
+ }
+
+ if (Dequeued)
{
m_FreeWorkerCount--;
- auto _ = MakeGuard([&]() { m_FreeWorkerCount++; });
+ m_ActiveCount.fetch_add(1, std::memory_order::acq_rel);
+ auto _ = MakeGuard([&]() {
+ m_ActiveCount.fetch_sub(1, std::memory_order::release);
+ m_FreeWorkerCount++;
+ });
try
{
@@ -289,8 +348,27 @@ ExplicitImpl::WorkerThreadFunction(WorkerThreadPool::ThreadStartInfo Info)
ZEN_ERROR("Caught exception in worker thread: {}", e.what());
}
}
+ else if (ScalingEnabled())
+ {
+ // Timed out - consider scaling down
+ const int CurrentTotal = m_TotalThreads.load(std::memory_order::acquire);
+ if (CurrentTotal > m_MinThreads)
+ {
+ int Expected = CurrentTotal;
+ if (m_TotalThreads.compare_exchange_strong(Expected, CurrentTotal - 1, std::memory_order::acq_rel))
+ {
+ ZEN_DEBUG("scaling down worker thread pool '{}' (idle timeout), {} threads remaining",
+ m_WorkerThreadBaseName,
+ CurrentTotal - 1);
+ m_FreeWorkerCount--;
+ return; // Thread exits
+ }
+ }
+ // CAS failed or at min threads - continue waiting
+ }
else
{
+ // CompleteAdding was called - exit
return;
}
} while (true);
@@ -303,19 +381,27 @@ WorkerThreadPool::WorkerThreadPool(int InThreadCount, bool UseExplicitThreads)
}
WorkerThreadPool::WorkerThreadPool(int InThreadCount, std::string_view WorkerThreadBaseName, bool UseExplicitThreads)
+: WorkerThreadPool(InThreadCount, InThreadCount, WorkerThreadBaseName, UseExplicitThreads)
+{
+}
+
+WorkerThreadPool::WorkerThreadPool(int InMinThreadCount,
+ int InMaxThreadCount,
+ std::string_view WorkerThreadBaseName,
+ bool UseExplicitThreads)
{
- if (InThreadCount > 0)
+ if (InMinThreadCount > 0)
{
#if ZEN_PLATFORM_WINDOWS
if (!UseExplicitThreads)
{
- m_Impl = std::make_unique<WinTpImpl>(InThreadCount, WorkerThreadBaseName);
+ m_Impl = std::make_unique<WinTpImpl>(InMinThreadCount, InMaxThreadCount, WorkerThreadBaseName);
}
else
#endif
{
ZEN_UNUSED(UseExplicitThreads);
- m_Impl = std::make_unique<ExplicitImpl>(InThreadCount, WorkerThreadBaseName);
+ m_Impl = std::make_unique<ExplicitImpl>(InMinThreadCount, InMaxThreadCount, WorkerThreadBaseName);
}
}
}