aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute
diff options
context:
space:
mode:
Diffstat (limited to 'src/zencompute')
-rw-r--r--src/zencompute/CLAUDE.md17
-rw-r--r--src/zencompute/cloudmetadata.cpp8
-rw-r--r--src/zencompute/computeservice.cpp449
-rw-r--r--src/zencompute/httpcomputeservice.cpp274
-rw-r--r--src/zencompute/httporchestrator.cpp142
-rw-r--r--src/zencompute/include/zencompute/cloudmetadata.h2
-rw-r--r--src/zencompute/include/zencompute/computeservice.h8
-rw-r--r--src/zencompute/include/zencompute/httpcomputeservice.h11
-rw-r--r--src/zencompute/include/zencompute/httporchestrator.h23
-rw-r--r--src/zencompute/include/zencompute/mockimds.h2
-rw-r--r--src/zencompute/include/zencompute/orchestratorservice.h19
-rw-r--r--src/zencompute/include/zencompute/provisionerstate.h38
-rw-r--r--src/zencompute/orchestratorservice.cpp33
-rw-r--r--src/zencompute/pathvalidation.h118
-rw-r--r--src/zencompute/runners/functionrunner.cpp132
-rw-r--r--src/zencompute/runners/functionrunner.h27
-rw-r--r--src/zencompute/runners/linuxrunner.cpp22
-rw-r--r--src/zencompute/runners/localrunner.cpp40
-rw-r--r--src/zencompute/runners/localrunner.h4
-rw-r--r--src/zencompute/runners/macrunner.cpp18
-rw-r--r--src/zencompute/runners/managedrunner.cpp279
-rw-r--r--src/zencompute/runners/managedrunner.h64
-rw-r--r--src/zencompute/runners/remotehttprunner.cpp366
-rw-r--r--src/zencompute/runners/remotehttprunner.h14
-rw-r--r--src/zencompute/runners/windowsrunner.cpp105
-rw-r--r--src/zencompute/runners/windowsrunner.h1
-rw-r--r--src/zencompute/runners/winerunner.cpp6
27 files changed, 1716 insertions, 506 deletions
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
index a1a39fc3c..bb574edc2 100644
--- a/src/zencompute/CLAUDE.md
+++ b/src/zencompute/CLAUDE.md
@@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc
**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure.
-**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit.
+**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit.
**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent.
@@ -156,7 +156,7 @@ Queues group actions from a single client session. A `QueueEntry` (internal) tra
- `ActiveLsns` — for cancellation lookup (under `m_Lock`)
- `FinishedLsns` — moved here when actions complete
- `IdleSince` — used for 15-minute automatic expiry
-- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit
+- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default
**Queue state machine (`QueueState` enum):**
```
@@ -216,11 +216,14 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han
## Concurrency Model
-**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks:
-1. `m_ResultsLock`
-2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock")
-3. `m_PendingLock`
-4. `m_QueueLock`
+**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock.
+
+**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks:
+1. `m_ActionMapLock` (session action maps)
+2. `QueueEntry::m_Lock` (per-queue state)
+3. `m_ActionHistoryLock` (action history ring)
+
+Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`).
**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`.
diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp
index eb4c05f9f..f1df18e8e 100644
--- a/src/zencompute/cloudmetadata.cpp
+++ b/src/zencompute/cloudmetadata.cpp
@@ -183,7 +183,7 @@ CloudMetadata::TryDetectAWS()
m_Info.AvailabilityZone = std::string(AzResponse.AsText());
}
- // "spot" vs "on-demand" — determines whether the instance can be
+ // "spot" vs "on-demand" - determines whether the instance can be
// reclaimed by AWS with a 2-minute warning
HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders);
if (LifecycleResponse.IsSuccess())
@@ -273,7 +273,7 @@ CloudMetadata::TryDetectAzure()
std::string Priority = Compute["priority"].string_value();
m_Info.IsSpot = (Priority == "Spot");
- // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling
+ // Check if part of a VMSS (Virtual Machine Scale Set) - indicates autoscaling
std::string VmssName = Compute["vmScaleSetName"].string_value();
m_Info.IsAutoscaling = !VmssName.empty();
@@ -609,7 +609,7 @@ namespace zen::compute {
TEST_SUITE_BEGIN("compute.cloudmetadata");
// ---------------------------------------------------------------------------
-// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService
+// Test helper - spins up a local ASIO HTTP server hosting a MockImdsService
// ---------------------------------------------------------------------------
struct TestImdsServer
@@ -974,7 +974,7 @@ TEST_CASE("cloudmetadata.sentinel_files")
SUBCASE("only failed providers get sentinels")
{
- // Switch to AWS — Azure and GCP never probed, so no sentinels for them
+ // Switch to AWS - Azure and GCP never probed, so no sentinels for them
Imds.Mock.ActiveProvider = CloudProvider::AWS;
auto Cloud = Imds.CreateCloud();
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
index 92901de64..7f354a51c 100644
--- a/src/zencompute/computeservice.cpp
+++ b/src/zencompute/computeservice.cpp
@@ -8,6 +8,8 @@
# include "recording/actionrecorder.h"
# include "runners/localrunner.h"
# include "runners/remotehttprunner.h"
+# include "runners/managedrunner.h"
+# include "pathvalidation.h"
# if ZEN_PLATFORM_LINUX
# include "runners/linuxrunner.h"
# elif ZEN_PLATFORM_WINDOWS
@@ -119,6 +121,8 @@ struct ComputeServiceSession::Impl
, m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
, m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
{
+ m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool);
+
// Create a non-expiring, non-deletable implicit queue for legacy endpoints
auto Result = CreateQueue("implicit"sv, {}, {});
m_ImplicitQueueId = Result.QueueId;
@@ -195,13 +199,9 @@ struct ComputeServiceSession::Impl
std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr};
- RwLock m_PendingLock;
- std::map<int, Ref<RunnerAction>> m_PendingActions;
-
- RwLock m_RunningLock;
+ RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap
+ std::map<int, Ref<RunnerAction>> m_PendingActions;
std::unordered_map<int, Ref<RunnerAction>> m_RunningMap;
-
- RwLock m_ResultsLock;
std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap;
metrics::Meter m_ResultRate;
std::atomic<uint64_t> m_RetiredCount{0};
@@ -242,8 +242,9 @@ struct ComputeServiceSession::Impl
// Recording
- void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
- void StopRecording();
+ bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
+ bool StopRecording();
+ bool IsRecording() const;
std::unique_ptr<ActionRecorder> m_Recorder;
@@ -343,9 +344,12 @@ struct ComputeServiceSession::Impl
ActionCounts GetActionCounts()
{
ActionCounts Counts;
- Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
- Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load();
+ m_ActionMapLock.WithSharedLock([&] {
+ Counts.Pending = (int)m_PendingActions.size();
+ Counts.Running = (int)m_RunningMap.size();
+ Counts.Completed = (int)m_ResultsMap.size();
+ });
+ Counts.Completed += (int)m_RetiredCount.load();
Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] {
size_t Count = 0;
for (const auto& [Id, Queue] : m_Queues)
@@ -364,8 +368,10 @@ struct ComputeServiceSession::Impl
{
Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed));
m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); });
- m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); });
- m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); });
+ m_ActionMapLock.WithSharedLock([&] {
+ Cbo << "actions_complete"sv << m_ResultsMap.size();
+ Cbo << "actions_pending"sv << m_PendingActions.size();
+ });
Cbo << "actions_submitted"sv << GetSubmittedActionCount();
EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo);
EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo);
@@ -443,34 +449,24 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState)
return true;
}
- // CAS failed, Current was updated — retry with the new value
+ // CAS failed, Current was updated - retry with the new value
}
}
void
ComputeServiceSession::Impl::AbandonAllActions()
{
- // Collect all pending actions and mark them as Abandoned
+ // Collect all pending and running actions under a single lock scope
std::vector<Ref<RunnerAction>> PendingToAbandon;
+ std::vector<Ref<RunnerAction>> RunningToAbandon;
- m_PendingLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
PendingToAbandon.reserve(m_PendingActions.size());
for (auto& [Lsn, Action] : m_PendingActions)
{
PendingToAbandon.push_back(Action);
}
- });
-
- for (auto& Action : PendingToAbandon)
- {
- Action->SetActionState(RunnerAction::State::Abandoned);
- }
-
- // Collect all running actions and mark them as Abandoned, then
- // best-effort cancel via the local runner group
- std::vector<Ref<RunnerAction>> RunningToAbandon;
- m_RunningLock.WithSharedLock([&] {
RunningToAbandon.reserve(m_RunningMap.size());
for (auto& [Lsn, Action] : m_RunningMap)
{
@@ -478,6 +474,11 @@ ComputeServiceSession::Impl::AbandonAllActions()
}
});
+ for (auto& Action : PendingToAbandon)
+ {
+ Action->SetActionState(RunnerAction::State::Abandoned);
+ }
+
for (auto& Action : RunningToAbandon)
{
Action->SetActionState(RunnerAction::State::Abandoned);
@@ -617,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
m_KnownWorkerUris.insert(UriStr);
auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool);
+ NewRunner->SetRemoteHostname(Hostname);
SyncWorkersToRunner(*NewRunner);
m_RemoteRunnerGroup.AddRunner(NewRunner);
}
@@ -718,31 +720,51 @@ ComputeServiceSession::Impl::ShutdownRunners()
m_RemoteRunnerGroup.Shutdown();
}
-void
+bool
ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath)
{
+ if (m_Recorder)
+ {
+ ZEN_WARN("recording is already active");
+ return false;
+ }
+
ZEN_INFO("starting recording to '{}'", RecordingPath);
m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath);
ZEN_INFO("started recording to '{}'", RecordingPath);
+ return true;
}
-void
+bool
ComputeServiceSession::Impl::StopRecording()
{
+ if (!m_Recorder)
+ {
+ ZEN_WARN("no recording is active");
+ return false;
+ }
+
ZEN_INFO("stopping recording");
m_Recorder = nullptr;
ZEN_INFO("stopped recording");
+ return true;
+}
+
+bool
+ComputeServiceSession::Impl::IsRecording() const
+{
+ return m_Recorder != nullptr;
}
std::vector<ComputeServiceSession::RunningActionInfo>
ComputeServiceSession::Impl::GetRunningActions()
{
std::vector<ComputeServiceSession::RunningActionInfo> Result;
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
Result.reserve(m_RunningMap.size());
for (const auto& [Lsn, Action] : m_RunningMap)
{
@@ -810,6 +832,11 @@ void
ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
{
ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker");
+
+ // Validate all paths in the worker description upfront, before the worker is
+ // distributed to runners. This rejects malicious packages early at ingestion time.
+ ValidateWorkerDescriptionPaths(Worker.GetObject());
+
RwLock::ExclusiveLockScope _(m_WorkerLock);
const IoHash& WorkerId = Worker.GetObject().GetHash();
@@ -994,10 +1021,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke
Pending->ActionObj = ActionObj;
Pending->Priority = RequestPriority;
- // For now simply put action into pending state, so we can do batch scheduling
+ // Insert into the pending map immediately so the action is visible to
+ // FindActionResult/GetActionResult right away. SetActionState will call
+ // PostUpdate which adds the action to m_UpdatedActions and signals the
+ // scheduler, but the scheduler's HandleActionUpdates inserts with
+ // std::map::insert which is a no-op for existing keys.
ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); });
Pending->SetActionState(RunnerAction::State::Pending);
if (m_Recorder)
@@ -1043,11 +1075,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount()
HttpResponseCode
ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage)
{
- // This lock is held for the duration of the function since we need to
- // be sure that the action doesn't change state while we are checking the
- // different data structures
-
- RwLock::ExclusiveLockScope _(m_ResultsLock);
+ RwLock::ExclusiveLockScope _(m_ActionMapLock);
if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end())
{
@@ -1058,25 +1086,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult
return HttpResponseCode::OK;
}
+ if (m_PendingActions.find(ActionLsn) != m_PendingActions.end())
{
- RwLock::SharedLockScope __(m_PendingLock);
-
- if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end())
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
- // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
- // always be taken after m_ResultsLock if both are needed
-
+ if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
{
- RwLock::SharedLockScope __(m_RunningLock);
-
- if (m_RunningMap.find(ActionLsn) != m_RunningMap.end())
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
return HttpResponseCode::NotFound;
@@ -1085,11 +1102,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult
HttpResponseCode
ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage)
{
- // This lock is held for the duration of the function since we need to
- // be sure that the action doesn't change state while we are checking the
- // different data structures
-
- RwLock::ExclusiveLockScope _(m_ResultsLock);
+ RwLock::ExclusiveLockScope _(m_ActionMapLock);
for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It)
{
@@ -1103,30 +1116,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage&
}
}
+ for (const auto& [K, Pending] : m_PendingActions)
{
- RwLock::SharedLockScope __(m_PendingLock);
-
- for (const auto& [K, Pending] : m_PendingActions)
+ if (Pending->ActionId == ActionId)
{
- if (Pending->ActionId == ActionId)
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
}
- // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must
- // always be taken after m_ResultsLock if both are needed
-
+ for (const auto& [K, v] : m_RunningMap)
{
- RwLock::SharedLockScope __(m_RunningLock);
-
- for (const auto& [K, v] : m_RunningMap)
+ if (v->ActionId == ActionId)
{
- if (v->ActionId == ActionId)
- {
- return HttpResponseCode::Accepted;
- }
+ return HttpResponseCode::Accepted;
}
}
@@ -1144,12 +1146,16 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
{
Cbo.BeginArray("completed");
- m_ResultsLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (auto& [Lsn, Action] : m_ResultsMap)
{
Cbo.BeginObject();
Cbo << "lsn"sv << Lsn;
Cbo << "state"sv << RunnerAction::ToString(Action->ActionState());
+ if (!Action->FailureReason.empty())
+ {
+ Cbo << "reason"sv << Action->FailureReason;
+ }
Cbo.EndObject();
}
});
@@ -1275,20 +1281,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId)
std::vector<Ref<RunnerAction>> PendingActionsToCancel;
std::vector<int> RunningLsnsToCancel;
- m_PendingLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (int Lsn : LsnsToCancel)
{
if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end())
{
PendingActionsToCancel.push_back(It->second);
}
- }
- });
-
- m_RunningLock.WithSharedLock([&] {
- for (int Lsn : LsnsToCancel)
- {
- if (m_RunningMap.find(Lsn) != m_RunningMap.end())
+ else if (m_RunningMap.find(Lsn) != m_RunningMap.end())
{
RunningLsnsToCancel.push_back(Lsn);
}
@@ -1307,7 +1307,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId)
// transition from the runner is blocked (Cancelled > Failed in the enum).
for (int Lsn : RunningLsnsToCancel)
{
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end())
{
It->second->SetActionState(RunnerAction::State::Cancelled);
@@ -1444,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
if (Queue)
{
- Queue->m_Lock.WithSharedLock([&] {
- m_ResultsLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
+ Queue->m_Lock.WithSharedLock([&] {
for (int Lsn : Queue->FinishedLsns)
{
if (m_ResultsMap.contains(Lsn))
@@ -1475,15 +1475,19 @@ ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, Run
return;
}
+ bool WasActive = false;
Queue->m_Lock.WithExclusiveLock([&] {
- Queue->ActiveLsns.erase(Lsn);
+ WasActive = Queue->ActiveLsns.erase(Lsn) > 0;
Queue->FinishedLsns.insert(Lsn);
});
- const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
- if (PreviousActive == 1)
+ if (WasActive)
{
- Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+ const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed);
+ if (PreviousActive == 1)
+ {
+ Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed);
+ }
}
switch (ActionState)
@@ -1541,26 +1545,32 @@ ComputeServiceSession::Impl::SchedulePendingActions()
{
ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions");
int ScheduledCount = 0;
- size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
- size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); });
+ size_t RunningCount = 0;
+ size_t PendingCount = 0;
+ size_t ResultCount = 0;
+
+ m_ActionMapLock.WithSharedLock([&] {
+ RunningCount = m_RunningMap.size();
+ PendingCount = m_PendingActions.size();
+ ResultCount = m_ResultsMap.size();
+ });
static Stopwatch DumpRunningTimer;
auto _ = MakeGuard([&] {
- ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
- ScheduledCount,
- RunningCount,
- m_RetiredCount.load(),
- PendingCount,
- ResultCount);
+ ZEN_DEBUG("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
+ ScheduledCount,
+ RunningCount,
+ m_RetiredCount.load(),
+ PendingCount,
+ ResultCount);
if (DumpRunningTimer.GetElapsedTimeMs() > 30000)
{
DumpRunningTimer.Reset();
std::set<int> RunningList;
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
for (auto& [K, V] : m_RunningMap)
{
RunningList.insert(K);
@@ -1602,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions()
// Also note that the m_PendingActions list is not maintained
// here, that's done periodically in SchedulePendingActions()
- m_PendingLock.WithExclusiveLock([&] {
- if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
- {
- return;
- }
+ // Extract pending actions under a shared lock - we only need to read
+ // the map and take Ref copies. ActionState() is atomic so this is safe.
+ // Sorting and capacity trimming happen outside the lock to avoid
+ // blocking HTTP handlers on O(N log N) work with large pending queues.
- if (m_PendingActions.empty())
+ m_ActionMapLock.WithSharedLock([&] {
+ if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
{
return;
}
@@ -1628,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions()
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
case RunnerAction::State::Abandoned:
+ case RunnerAction::State::Rejected:
case RunnerAction::State::Cancelled:
break;
@@ -1638,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions()
}
}
- // Sort by priority descending, then by LSN ascending (FIFO within same priority)
- std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
- if (A->Priority != B->Priority)
- {
- return A->Priority > B->Priority;
- }
- return A->ActionLsn < B->ActionLsn;
- });
+ PendingCount = m_PendingActions.size();
+ });
- if (ActionsToSchedule.size() > Capacity)
+ // Sort by priority descending, then by LSN ascending (FIFO within same priority)
+ std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
+ if (A->Priority != B->Priority)
{
- ActionsToSchedule.resize(Capacity);
+ return A->Priority > B->Priority;
}
-
- PendingCount = m_PendingActions.size();
+ return A->ActionLsn < B->ActionLsn;
});
+ if (ActionsToSchedule.size() > Capacity)
+ {
+ ActionsToSchedule.resize(Capacity);
+ }
+
if (ActionsToSchedule.empty())
{
_.Dismiss();
return;
}
- ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());
+ ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size());
Stopwatch SubmitTimer;
std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);
@@ -1681,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions()
}
}
- ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
- ScheduledActionCount,
- NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
- NotAcceptedCount);
+ ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)",
+ ScheduledActionCount,
+ NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
+ NotAcceptedCount);
ScheduledCount += ScheduledActionCount;
PendingCount -= ScheduledActionCount;
@@ -1701,7 +1712,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction()
{
int TimeoutMs = 500;
- auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
+ auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); });
if (PendingCount)
{
@@ -1720,22 +1731,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction()
m_SchedulingThreadEvent.Reset();
}
- ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}",
- m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }),
- PendingCount,
- m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }),
- m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }),
- TimeoutMs);
+ m_ActionMapLock.WithSharedLock([&] {
+ ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}",
+ m_PendingActions.size(),
+ m_RunningMap.size(),
+ m_ResultsMap.size(),
+ TimeoutMs);
+ });
HandleActionUpdates();
- // Auto-transition Draining → Paused when all work is done
+ // Auto-transition Draining -> Paused when all work is done
if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining)
{
- size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); });
- size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); });
+ bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); });
- if (Pending == 0 && Running == 0)
+ if (AllDrained)
{
SessionState Expected = SessionState::Draining;
if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel))
@@ -1776,9 +1787,9 @@ ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId)
if (Config)
{
- int Value = Config["max_retries"].AsInt32(0);
+ int Value = Config["max_retries"].AsInt32(-1);
- if (Value > 0)
+ if (Value >= 0)
{
return Value;
}
@@ -1797,7 +1808,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
// Find, validate, and remove atomically under a single lock scope to prevent
// concurrent RescheduleAction calls from double-removing the same action.
- m_ResultsLock.WithExclusiveLock([&] {
+ m_ActionMapLock.WithExclusiveLock([&] {
auto It = m_ResultsMap.find(ActionLsn);
if (It == m_ResultsMap.end())
{
@@ -1855,7 +1866,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
}
}
- // Reset action state — this calls PostUpdate() internally
+ // Reset action state - this calls PostUpdate() internally
Action->ResetActionStateToPending();
int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed);
@@ -1871,26 +1882,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn)
bool WasRunning = false;
// Look for the action in pending or running maps
- m_RunningLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end())
{
Action = It->second;
WasRunning = true;
}
+ else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end())
+ {
+ Action = PIt->second;
+ }
});
if (!Action)
{
- m_PendingLock.WithSharedLock([&] {
- if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end())
- {
- Action = It->second;
- }
- });
- }
-
- if (!Action)
- {
return {.Success = false, .Error = "Action not found in pending or running maps"};
}
@@ -1912,18 +1917,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn)
void
ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn)
{
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ // Caller must hold m_ActionMapLock exclusively.
+ if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+ {
+ m_PendingActions.erase(ActionLsn);
+ }
+ else
+ {
+ m_RunningMap.erase(FindIt);
+ }
}
void
@@ -1946,7 +1948,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
// This is safe because state transitions are monotonically increasing by enum
// rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so
// SetActionState rejects any transition to a lower-ranked state. By the time
- // we read ActionState() here, it reflects the highest state reached — making
+ // we read ActionState() here, it reflects the highest state reached - making
// the first occurrence per LSN authoritative and duplicates redundant.
for (Ref<RunnerAction>& Action : UpdatedActions)
{
@@ -1956,7 +1958,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
{
switch (Action->ActionState())
{
- // Newly enqueued — add to pending map for scheduling
+ // Newly enqueued - add to pending map for scheduling
case RunnerAction::State::Pending:
// Guard against a race where the session is abandoned between
// EnqueueAction (which calls PostUpdate) and this scheduler
@@ -1973,35 +1975,44 @@ ComputeServiceSession::Impl::HandleActionUpdates()
}
else
{
- m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
}
break;
- // Async submission in progress — remains in pending map
+ // Async submission in progress - remains in pending map
case RunnerAction::State::Submitting:
break;
- // Dispatched to a runner — move from pending to running
+ // Dispatched to a runner - move from pending to running
case RunnerAction::State::Running:
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- m_RunningMap.insert({ActionLsn, Action});
- m_PendingActions.erase(ActionLsn);
- });
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.insert({ActionLsn, Action});
+ m_PendingActions.erase(ActionLsn);
});
ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
break;
- // Retracted — pull back for rescheduling without counting against retry limit
+ // Retracted - pull back for rescheduling without counting against retry limit
case RunnerAction::State::Retracted:
{
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.erase(ActionLsn);
+ m_PendingActions[ActionLsn] = Action;
+ });
Action->ResetActionStateToPending();
ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn);
break;
}
- // Terminal states — move to results, record history, notify queue
+ // Rejected - runner was at capacity, reschedule without retry cost
+ case RunnerAction::State::Rejected:
+ {
+ Action->ResetActionStateToPending();
+ ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn);
+ break;
+ }
+
+ // Terminal states - move to results, record history, notify queue
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
case RunnerAction::State::Abandoned:
@@ -2010,7 +2021,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
auto TerminalState = Action->ActionState();
// Automatic retry for Failed/Abandoned actions with retries remaining.
- // Skip retries when the session itself is abandoned — those actions
+ // Skip retries when the session itself is abandoned - those actions
// were intentionally abandoned and should not be rescheduled.
if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) &&
m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned)
@@ -2019,7 +2030,10 @@ ComputeServiceSession::Impl::HandleActionUpdates()
if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
{
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ m_RunningMap.erase(ActionLsn);
+ m_PendingActions[ActionLsn] = Action;
+ });
// Reset triggers PostUpdate() which re-enters the action as Pending
Action->ResetActionStateToPending();
@@ -2032,18 +2046,26 @@ ComputeServiceSession::Impl::HandleActionUpdates()
MaxRetries);
break;
}
+ else
+ {
+ ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling",
+ Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(TerminalState),
+ Action->RetryCount.load(std::memory_order_relaxed));
+ }
}
- RemoveActionFromActiveMaps(ActionLsn);
+ m_ActionMapLock.WithExclusiveLock([&] {
+ RemoveActionFromActiveMaps(ActionLsn);
- // Update queue counters BEFORE publishing the result into
- // m_ResultsMap. GetActionResult erases from m_ResultsMap
- // under m_ResultsLock, so if we updated counters after
- // releasing that lock, a caller could observe ActiveCount
- // still at 1 immediately after GetActionResult returned OK.
- NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
+ // Update queue counters BEFORE publishing the result into
+ // m_ResultsMap. GetActionResult erases from m_ResultsMap
+ // under m_ActionMapLock, so if we updated counters after
+ // releasing that lock, a caller could observe ActiveCount
+ // still at 1 immediately after GetActionResult returned OK.
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
- m_ResultsLock.WithExclusiveLock([&] {
m_ResultsMap[ActionLsn] = Action;
// Append to bounded action history ring
@@ -2124,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>&
ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions");
std::vector<SubmitResult> Results(Actions.size());
- // First try submitting the batch to local runners in parallel
+ // First try submitting the batch to local runners
std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions);
- std::vector<size_t> RemoteIndices;
std::vector<Ref<RunnerAction>> RemoteActions;
for (size_t i = 0; i < Actions.size(); ++i)
@@ -2138,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>&
}
else
{
- RemoteIndices.push_back(i);
RemoteActions.push_back(Actions[i]);
+ Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"};
}
}
- // Submit remaining actions to remote runners in parallel
+ // Dispatch remaining actions to remote runners asynchronously.
+ // Mark actions as Submitting so the scheduler won't re-pick them.
+ // The remote runner will transition them to Running on success, or
+ // we mark them Failed on rejection so HandleActionUpdates retries.
if (!RemoteActions.empty())
{
- std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
-
- for (size_t j = 0; j < RemoteIndices.size(); ++j)
+ for (const Ref<RunnerAction>& Action : RemoteActions)
{
- Results[RemoteIndices[j]] = std::move(RemoteResults[j]);
+ Action->SetActionState(RunnerAction::State::Submitting);
}
+
+ m_RemoteSubmitPool.ScheduleWork(
+ [this, RemoteActions = std::move(RemoteActions)]() {
+ std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
+
+ for (size_t j = 0; j < RemoteResults.size(); ++j)
+ {
+ if (!RemoteResults[j].IsAccepted)
+ {
+ ZEN_DEBUG("remote submission rejected for action {} ({}): {}",
+ RemoteActions[j]->ActionId,
+ RemoteActions[j]->ActionLsn,
+ RemoteResults[j].Reason);
+
+ RemoteActions[j]->SetActionState(RunnerAction::State::Rejected);
+ }
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
}
return Results;
@@ -2217,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged()
m_Impl->NotifyOrchestratorChanged();
}
-void
+bool
ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
{
- m_Impl->StartRecording(InResolver, RecordingPath);
+ return m_Impl->StartRecording(InResolver, RecordingPath);
}
-void
+bool
ComputeServiceSession::StopRecording()
{
- m_Impl->StopRecording();
+ return m_Impl->StopRecording();
+}
+
+bool
+ComputeServiceSession::IsRecording() const
+{
+ return m_Impl->IsRecording();
}
ComputeServiceSession::ActionCounts
@@ -2282,6 +2329,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files
}
void
+ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions)
+{
+ ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner");
+
+ auto* NewRunner =
+ new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, MaxConcurrentActions);
+
+ m_Impl->SyncWorkersToRunner(*NewRunner);
+ m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner);
+}
+
+void
ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName)
{
ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner");
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
index bdfd9d197..5ab189d89 100644
--- a/src/zencompute/httpcomputeservice.cpp
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -21,12 +21,14 @@
# include <zencore/thread.h>
# include <zencore/trace.h>
# include <zencore/uid.h>
-# include <zenstore/cidstore.h>
+# include <zenstore/hashkeyset.h>
+# include <zenstore/zenstore.h>
# include <zentelemetry/stats.h>
# include <algorithm>
# include <span>
# include <unordered_map>
+# include <utility>
# include <vector>
using namespace std::literals;
@@ -45,7 +47,9 @@ auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSe
struct HttpComputeService::Impl
{
HttpComputeService* m_Self;
- CidStore& m_CidStore;
+ ChunkStore& m_ActionStore;
+ ChunkStore& m_WorkerStore;
+ FallbackChunkResolver m_CombinedResolver;
IHttpStatsService& m_StatsService;
LoggerRef m_Log;
std::filesystem::path m_BaseDir;
@@ -58,6 +62,8 @@ struct HttpComputeService::Impl
RwLock m_WsConnectionsLock;
std::vector<Ref<WebSocketConnection>> m_WsConnections;
+ std::function<void()> m_ShutdownCallback;
+
// Metrics
metrics::OperationTiming m_HttpRequests;
@@ -72,13 +78,13 @@ struct HttpComputeService::Impl
std::string ClientHostname; // empty if no hostname was provided
};
- // Remote queue registry — all three maps share the same RemoteQueueInfo objects.
+ // Remote queue registry - all three maps share the same RemoteQueueInfo objects.
// All maps are guarded by m_RemoteQueueLock.
RwLock m_RemoteQueueLock;
- std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info
- std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info
- std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info
+ std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token -> info
+ std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId -> info
+ std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key -> info
LoggerRef Log() { return m_Log; }
@@ -93,34 +99,38 @@ struct HttpComputeService::Impl
uint64_t NewBytes = 0;
};
- IngestStats IngestPackageAttachments(const CbPackage& Package);
- bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
- void HandleWorkersGet(HttpServerRequest& HttpReq);
- void HandleWorkersAllGet(HttpServerRequest& HttpReq);
- void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
- void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
- void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
+ bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats);
+ bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList);
+ bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment);
+ void HandleWorkersGet(HttpServerRequest& HttpReq);
+ void HandleWorkersAllGet(HttpServerRequest& HttpReq);
+ void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
+ void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
+ void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
// WebSocket / observer
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection);
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri);
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code);
void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions);
void RegisterRoutes();
- Impl(HttpComputeService* Self,
- CidStore& InCidStore,
- IHttpStatsService& StatsService,
- const std::filesystem::path& BaseDir,
- int32_t MaxConcurrentActions)
+ Impl(HttpComputeService* Self,
+ ChunkStore& InActionStore,
+ ChunkStore& InWorkerStore,
+ IHttpStatsService& StatsService,
+ std::filesystem::path BaseDir,
+ int32_t MaxConcurrentActions)
: m_Self(Self)
- , m_CidStore(InCidStore)
+ , m_ActionStore(InActionStore)
+ , m_WorkerStore(InWorkerStore)
+ , m_CombinedResolver(InActionStore, InWorkerStore)
, m_StatsService(StatsService)
, m_Log(logging::Get("compute"))
- , m_BaseDir(BaseDir)
- , m_ComputeService(InCidStore)
+ , m_BaseDir(std::move(BaseDir))
+ , m_ComputeService(m_CombinedResolver)
{
- m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions);
+ m_ComputeService.AddLocalRunner(m_CombinedResolver, m_BaseDir / "local", MaxConcurrentActions);
m_ComputeService.WaitUntilReady();
m_StatsService.RegisterHandler("compute", *m_Self);
RegisterRoutes();
@@ -182,6 +192,65 @@ HttpComputeService::Impl::RegisterRoutes()
HttpVerb::kPost);
m_Router.RegisterRoute(
+ "session/drain",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Draining from current state"sv;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "session/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ auto Counts = m_ComputeService.GetActionCounts();
+ Cbo << "actions_pending"sv << Counts.Pending;
+ Cbo << "actions_running"sv << Counts.Running;
+ Cbo << "actions_completed"sv << Counts.Completed;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "session/sunset",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+
+ if (m_ShutdownCallback)
+ {
+ m_ShutdownCallback();
+ }
+ return;
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Sunset from current state"sv;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
"workers",
[this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); },
HttpVerb::kGet);
@@ -373,7 +442,7 @@ HttpComputeService::Impl::RegisterRoutes()
if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output);
ResponseCode != HttpResponseCode::OK)
{
- ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
+ ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode))
if (ResponseCode == HttpResponseCode::NotFound)
{
@@ -498,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording");
+ std::filesystem::path RecordingPath = m_BaseDir / "recording";
+
+ if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "error"
+ << "recording is already active";
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
- return HttpReq.WriteResponse(HttpResponseCode::OK);
+ CbObjectWriter Cbo;
+ Cbo << "path" << RecordingPath.string();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kPost);
@@ -514,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- m_ComputeService.StopRecording();
+ std::filesystem::path RecordingPath = m_BaseDir / "recording";
- return HttpReq.WriteResponse(HttpResponseCode::OK);
+ if (!m_ComputeService.StopRecording())
+ {
+ CbObjectWriter Cbo;
+ Cbo << "error"
+ << "no recording is active";
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "path" << RecordingPath.string();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kPost);
@@ -583,7 +672,7 @@ HttpComputeService::Impl::RegisterRoutes()
},
HttpVerb::kGet | HttpVerb::kPost);
- // Queue creation routes — these remain separate since local creates a plain queue
+ // Queue creation routes - these remain separate since local creates a plain queue
// while remote additionally generates an OID token for external access.
m_Router.RegisterRoute(
@@ -637,7 +726,7 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}
- // Queue has since expired — clean up stale entries and fall through to create a new one
+ // Queue has since expired - clean up stale entries and fall through to create a new one
m_RemoteQueuesByToken.erase(Existing->Token);
m_RemoteQueuesByQueueId.erase(Existing->QueueId);
m_RemoteQueuesByTag.erase(It);
@@ -666,7 +755,7 @@ HttpComputeService::Impl::RegisterRoutes()
},
HttpVerb::kPost);
- // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens.
+ // Unified queue routes - {queueref} accepts both local integer IDs and remote OID tokens.
// ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution.
m_Router.RegisterRoute(
@@ -1016,7 +1105,7 @@ HttpComputeService::Impl::RegisterRoutes()
},
HttpVerb::kPost);
- // WebSocket upgrade endpoint — the handler logic lives in
+ // WebSocket upgrade endpoint - the handler logic lives in
// HttpComputeService::OnWebSocket* methods; this route merely
// satisfies the router so the upgrade request isn't rejected.
m_Router.RegisterRoute(
@@ -1027,11 +1116,12 @@ HttpComputeService::Impl::RegisterRoutes()
//////////////////////////////////////////////////////////////////////////
-HttpComputeService::HttpComputeService(CidStore& InCidStore,
+HttpComputeService::HttpComputeService(ChunkStore& InActionStore,
+ ChunkStore& InWorkerStore,
IHttpStatsService& StatsService,
const std::filesystem::path& BaseDir,
int32_t MaxConcurrentActions)
-: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions))
+: m_Impl(std::make_unique<Impl>(this, InActionStore, InWorkerStore, StatsService, BaseDir, MaxConcurrentActions))
{
}
@@ -1057,6 +1147,12 @@ HttpComputeService::GetActionCounts()
return m_Impl->m_ComputeService.GetActionCounts();
}
+void
+HttpComputeService::SetShutdownCallback(std::function<void()> Callback)
+{
+ m_Impl->m_ShutdownCallback = std::move(Callback);
+}
+
const char*
HttpComputeService::BaseUri() const
{
@@ -1145,7 +1241,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
{
if (OidMatcher(Capture))
{
- // Remote OID token — accessible from any client
+ // Remote OID token - accessible from any client
const Oid Token = Oid::FromHexString(Capture);
const int QueueId = ResolveQueueToken(Token);
@@ -1157,7 +1253,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
return QueueId;
}
- // Local integer queue ID — restricted to local machine requests
+ // Local integer queue ID - restricted to local machine requests
if (!HttpReq.IsLocalMachineRequest())
{
HttpReq.WriteResponse(HttpResponseCode::Forbidden);
@@ -1167,35 +1263,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin
return ParseInt<int>(Capture).value_or(0);
}
-HttpComputeService::Impl::IngestStats
-HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package)
+bool
+HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment)
{
- IngestStats Stats;
+ const IoHash ClaimedHash = Attachment.GetHash();
+ CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+ const IoHash HeaderHash = Buffer.DecodeRawHash();
+ if (HeaderHash != ClaimedHash)
+ {
+ ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ IoHashStream Hasher;
+
+ bool DecompressOk = Buffer.DecompressToStream(
+ 0,
+ Buffer.DecodeRawSize(),
+ [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool {
+ for (const SharedBuffer& Segment : Range.GetSegments())
+ {
+ Hasher.Append(Segment.GetView());
+ }
+ return true;
+ });
+
+ if (!DecompressOk)
+ {
+ ZEN_WARN("attachment {}: failed to decompress", ClaimedHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ const IoHash ActualHash = Hasher.GetHash();
+
+ if (ActualHash != ClaimedHash)
+ {
+ ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash);
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return false;
+ }
+
+ return true;
+}
+
+bool
+HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats)
+{
for (const CbAttachment& Attachment : Package.GetAttachments())
{
ZEN_ASSERT(Attachment.IsCompressedBinary());
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
+ if (!ValidateAttachmentHash(HttpReq, Attachment))
+ {
+ return false;
+ }
- const uint64_t CompressedSize = DataView.GetCompressedSize();
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
- Stats.Bytes += CompressedSize;
- ++Stats.Count;
+ OutStats.Bytes += CompressedSize;
+ ++OutStats.Count;
- const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+ const ChunkStore::InsertResult InsertResult = m_ActionStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
if (InsertResult.New)
{
- Stats.NewBytes += CompressedSize;
- ++Stats.NewCount;
+ OutStats.NewBytes += CompressedSize;
+ ++OutStats.NewCount;
}
}
- return Stats;
+ return true;
}
bool
@@ -1204,7 +1346,7 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto
ActionObj.IterateAttachments([&](CbFieldView Field) {
const IoHash FileHash = Field.AsHash();
- if (!m_CidStore.ContainsChunk(FileHash))
+ if (!m_ActionStore.ContainsChunk(FileHash))
{
NeedList.push_back(FileHash);
}
@@ -1253,7 +1395,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
{
CbPackage Package = HttpReq.ReadPayloadPackage();
Body = Package.GetObject();
- Stats = IngestPackageAttachments(Package);
+ if (!IngestPackageAttachments(HttpReq, Package, Stats))
+ {
+ return; // validation failed, response already written
+ }
break;
}
@@ -1268,8 +1413,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
{
// --- Batch path ---
- // For CbObject payloads, check all attachments upfront before enqueuing anything
- if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ // Verify all action attachment references exist in the store
{
std::vector<IoHash> NeedList;
@@ -1345,7 +1489,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que
// --- Single-action path: Body is the action itself ---
- if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
{
std::vector<IoHash> NeedList;
@@ -1453,7 +1596,7 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
CbPackage WorkerPackage;
WorkerPackage.SetObject(WorkerSpec);
- m_CidStore.FilterChunks(ChunkSet);
+ m_WorkerStore.FilterChunks(ChunkSet);
if (ChunkSet.IsEmpty())
{
@@ -1491,15 +1634,19 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
{
ZEN_ASSERT(Attachment.IsCompressedBinary());
+ if (!ValidateAttachmentHash(HttpReq, Attachment))
+ {
+ return;
+ }
+
const IoHash DataHash = Attachment.GetHash();
CompressedBuffer Buffer = Attachment.AsCompressedBinary();
- ZEN_UNUSED(DataHash);
TotalAttachmentBytes += Buffer.GetCompressedSize();
++AttachmentCount;
- const CidStore::InsertResult InsertResult =
- m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+ const ChunkStore::InsertResult InsertResult =
+ m_WorkerStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);
if (InsertResult.New)
{
@@ -1537,9 +1684,9 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
//
void
-HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
- m_Impl->OnWebSocketOpen(std::move(Connection));
+ m_Impl->OnWebSocketOpen(std::move(Connection), RelativeUri);
}
void
@@ -1562,12 +1709,13 @@ HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotificati
//////////////////////////////////////////////////////////////////////////
//
-// Impl — WebSocket / observer
+// Impl - WebSocket / observer
//
void
-HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_INFO("compute WebSocket client connected");
m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
}
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
index 6cbe01e04..56eadcd57 100644
--- a/src/zencompute/httporchestrator.cpp
+++ b/src/zencompute/httporchestrator.cpp
@@ -7,6 +7,7 @@
# include <zencompute/orchestratorservice.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/logging.h>
+# include <zencore/session.h>
# include <zencore/string.h>
# include <zencore/system.h>
@@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn
return Ann.Id;
}
+static OrchestratorService::WorkerAnnotator
+MakeWorkerAnnotator(IProvisionerStateProvider* Prov)
+{
+ if (!Prov)
+ {
+ return {};
+ }
+ return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) {
+ AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId);
+ if (Status != AgentProvisioningStatus::Unknown)
+ {
+ const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active";
+ Cbo << "provisioner_status" << std::string_view(StatusStr);
+ }
+ };
+}
+
+bool
+HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId)
+{
+ std::string_view SessionStr = Data["coordinator_session"].AsString("");
+ if (SessionStr.empty())
+ {
+ return true; // backwards compatibility: accept announcements without a session
+ }
+ Oid Session = Oid::TryFromHexString(SessionStr);
+ if (Session == m_SessionId)
+ {
+ return true;
+ }
+ ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString());
+ return false;
+}
+
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket))
, m_Hostname(GetMachineName())
{
+ m_SessionId = zen::GetSessionId();
+ ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString());
+
m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
@@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
[this](HttpRouterRequest& Req) {
CbObjectWriter Cbo;
Cbo << "hostname" << std::string_view(m_Hostname);
+ Cbo << "session_id" << m_SessionId.ToString();
Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kGet);
m_Router.RegisterRoute(
"provision",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kPost);
m_Router.RegisterRoute(
@@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
"characters and uri must start with http:// or https://");
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session");
+ }
+
m_Service->AnnounceWorker(Ann);
HttpReq.WriteResponse(HttpResponseCode::OK);
@@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
m_Router.RegisterRoute(
"agents",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kGet);
m_Router.RegisterRoute(
@@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
},
HttpVerb::kGet);
+ // Provisioner endpoints
+
+ m_Router.RegisterRoute(
+ "provisioner/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ Cbo << "name" << Prov->GetName();
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount();
+ Cbo << "active_cores" << Prov->GetActiveCoreCount();
+ Cbo << "agents" << Prov->GetAgentCount();
+ Cbo << "agents_draining" << Prov->GetDrainingAgentCount();
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "provisioner/target",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObject Data = HttpReq.ReadPayloadObject();
+ int32_t Cores = Data["target_cores"].AsInt32(-1);
+
+ ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false);
+
+ if (Cores < 0)
+ {
+ ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field");
+ }
+
+ IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire);
+ if (!Prov)
+ {
+ ZEN_WARN("provisioner/target: no provisioner configured");
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured");
+ }
+
+ ZEN_INFO("provisioner/target: setting target to {} cores", Cores);
+ Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores));
+
+ CbObjectWriter Cbo;
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
// Client tracking endpoints
m_Router.RegisterRoute(
@@ -375,7 +478,7 @@ HttpOrchestratorService::Shutdown()
m_PushThread.join();
}
- // Clean up worker WebSocket connections — collect IDs under lock, then
+ // Clean up worker WebSocket connections - collect IDs under lock, then
// notify the service outside the lock to avoid lock-order inversions.
std::vector<std::string> WorkerIds;
m_WorkerWsLock.WithExclusiveLock([&] {
@@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
}
}
+void
+HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release);
+ m_Service->SetProvisionerStateProvider(Provider);
+}
+
//////////////////////////////////////////////////////////////////////////
//
// IWebSocketHandler
@@ -418,8 +528,9 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
# if ZEN_WITH_WEBSOCKETS
void
-HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
if (!m_PushEnabled.load())
{
return;
@@ -471,7 +582,7 @@ std::string
HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg)
{
// Workers send CbObject in native binary format over the WebSocket to
- // avoid the lossy CbObject↔JSON round-trip.
+ // avoid the lossy CbObject<->JSON round-trip.
CbObject Data = CbObject::MakeView(Msg.Payload.GetData());
if (!Data)
{
@@ -487,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms
return {};
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return {};
+ }
+
m_Service->AnnounceWorker(Ann);
return std::string(WorkerId);
}
@@ -562,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction()
}
// Build combined JSON with worker list, provisioning history, clients, and client history
- CbObject WorkerList = m_Service->GetWorkerList();
+ CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)));
CbObject History = m_Service->GetProvisioningHistory(50);
CbObject ClientList = m_Service->GetClientList();
CbObject ClientHistory = m_Service->GetClientHistory(50);
@@ -614,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction()
JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2));
}
+ // Emit provisioner stats if available
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ JsonBuilder.Append(
+ fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}"
+ ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}",
+ Prov->GetName(),
+ Prov->GetTargetCoreCount(),
+ Prov->GetEstimatedCoreCount(),
+ Prov->GetActiveCoreCount(),
+ Prov->GetAgentCount(),
+ Prov->GetDrainingAgentCount()));
+ }
+
JsonBuilder.Append("}");
std::string_view Json = JsonBuilder.ToView();
diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h
index 3b9642ac3..280d794e7 100644
--- a/src/zencompute/include/zencompute/cloudmetadata.h
+++ b/src/zencompute/include/zencompute/cloudmetadata.h
@@ -64,7 +64,7 @@ public:
explicit CloudMetadata(std::filesystem::path DataDir);
/** Synchronously probes cloud providers at the given IMDS endpoint.
- * Intended for testing — allows redirecting all IMDS queries to a local
+ * Intended for testing - allows redirecting all IMDS queries to a local
* mock HTTP server instead of the real 169.254.169.254 endpoint.
*/
CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint);
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
index 1ca78738a..97de4321a 100644
--- a/src/zencompute/include/zencompute/computeservice.h
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -167,6 +167,7 @@ public:
// Action runners
void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
+ void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0);
void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName);
// Action submission
@@ -278,7 +279,7 @@ public:
// sized to match RunnerAction::State::_Count but we can't use the enum here
// for dependency reasons, so just use a fixed size array and static assert in
// the implementation file
- uint64_t Timestamps[9] = {};
+ uint64_t Timestamps[10] = {};
};
[[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
@@ -304,8 +305,9 @@ public:
// Recording
- void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
- void StopRecording();
+ bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
+ bool StopRecording();
+ bool IsRecording() const;
private:
void PostUpdate(RunnerAction* Action);
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
index b58e73a0d..32f54f293 100644
--- a/src/zencompute/include/zencompute/httpcomputeservice.h
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -15,7 +15,7 @@
# include <memory>
namespace zen {
-class CidStore;
+class ChunkStore;
}
namespace zen::compute {
@@ -26,7 +26,8 @@ namespace zen::compute {
class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver
{
public:
- HttpComputeService(CidStore& InCidStore,
+ HttpComputeService(ChunkStore& InActionStore,
+ ChunkStore& InWorkerStore,
IHttpStatsService& StatsService,
const std::filesystem::path& BaseDir,
int32_t MaxConcurrentActions = 0);
@@ -34,6 +35,10 @@ public:
void Shutdown();
+ /** Set a callback to be invoked when the session/sunset endpoint is hit.
+ * Typically wired to HttpServer::RequestExit() to shut down the process. */
+ void SetShutdownCallback(std::function<void()> Callback);
+
[[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts();
const char* BaseUri() const override;
@@ -45,7 +50,7 @@ public:
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h
index da5c5dfc3..4e4f5f0f8 100644
--- a/src/zencompute/include/zencompute/httporchestrator.h
+++ b/src/zencompute/include/zencompute/httporchestrator.h
@@ -2,10 +2,12 @@
#pragma once
+#include <zencompute/provisionerstate.h>
#include <zencompute/zencompute.h>
#include <zencore/logging.h>
#include <zencore/thread.h>
+#include <zencore/uid.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/websocket.h>
@@ -65,12 +67,22 @@ public:
*/
void Shutdown();
+ /** Return the session ID generated at construction time. Provisioners
+ * pass this to spawned workers so the orchestrator can reject stale
+ * announcements from previous sessions. */
+ Oid GetSessionId() const { return m_SessionId; }
+
+ /** Register a provisioner whose target core count can be read and changed
+ * via the orchestrator HTTP API and dashboard. Caller retains ownership;
+ * the provider must outlive this service. */
+ void SetProvisionerStateProvider(IProvisionerStateProvider* Provider);
+
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
// IWebSocketHandler
#if ZEN_WITH_WEBSOCKETS
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
#endif
@@ -81,6 +93,11 @@ private:
std::unique_ptr<OrchestratorService> m_Service;
std::string m_Hostname;
+ Oid m_SessionId;
+ bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId);
+
+ std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr};
+
// WebSocket push
#if ZEN_WITH_WEBSOCKETS
@@ -91,9 +108,9 @@ private:
Event m_PushEvent;
void PushThreadFunction();
- // Worker WebSocket connections (worker→orchestrator persistent links)
+ // Worker WebSocket connections (worker->orchestrator persistent links)
RwLock m_WorkerWsLock;
- std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID
+ std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr -> worker ID
std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg);
#endif
};
diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h
index 704306913..6074240b9 100644
--- a/src/zencompute/include/zencompute/mockimds.h
+++ b/src/zencompute/include/zencompute/mockimds.h
@@ -1,5 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-// Moved to zenutil — this header is kept for backward compatibility.
+// Moved to zenutil - this header is kept for backward compatibility.
#pragma once
diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h
index 071e902b3..2c49e22df 100644
--- a/src/zencompute/include/zencompute/orchestratorservice.h
+++ b/src/zencompute/include/zencompute/orchestratorservice.h
@@ -6,7 +6,10 @@
#if ZEN_WITH_COMPUTE_SERVICES
+# include <zencompute/provisionerstate.h>
# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/logbase.h>
# include <zencore/thread.h>
# include <zencore/timer.h>
# include <zencore/uid.h>
@@ -88,9 +91,16 @@ public:
std::string Hostname;
};
- CbObject GetWorkerList();
+ /** Per-worker callback invoked during GetWorkerList serialization.
+ * The callback receives the worker ID and a CbObjectWriter positioned
+ * inside the worker's object, allowing the caller to append extra fields. */
+ using WorkerAnnotator = std::function<void(std::string_view WorkerId, CbObjectWriter& Cbo)>;
+
+ CbObject GetWorkerList(const WorkerAnnotator& Annotate = {});
void AnnounceWorker(const WorkerAnnouncement& Announcement);
+ void SetProvisionerStateProvider(IProvisionerStateProvider* Provider);
+
bool IsWorkerWebSocketEnabled() const;
void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected);
@@ -164,7 +174,12 @@ private:
void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname);
- bool m_EnableWorkerWebSocket = false;
+ LoggerRef Log() { return m_Log; }
+
+ LoggerRef m_Log{"compute.orchestrator"};
+ bool m_EnableWorkerWebSocket = false;
+
+ std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr};
std::thread m_ProbeThread;
std::atomic<bool> m_ProbeThreadEnabled{true};
diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h
new file mode 100644
index 000000000..e9af8a635
--- /dev/null
+++ b/src/zencompute/include/zencompute/provisionerstate.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <cstdint>
+#include <string_view>
+
+namespace zen::compute {
+
+/** Per-agent provisioning status as seen by the provisioner (queried via IProvisionerStateProvider::GetAgentStatus). */
+enum class AgentProvisioningStatus
+{
+ Unknown, ///< Not known to the provisioner (default for unrecognized worker IDs)
+ Active, ///< Running and allocated
+ Draining, ///< Being gracefully deprovisioned; the orchestrator probe treats unreachability of draining agents as expected
+};
+
+/** Abstract interface for querying and controlling a provisioner from the HTTP layer.
+ * This decouples the orchestrator service from specific provisioner implementations. The provider is registered as a raw pointer held in an atomic and read from multiple orchestrator threads (probe and push), so implementations must be thread-safe; the registering caller retains ownership and must keep the provider alive for the lifetime of the service. */
+class IProvisionerStateProvider
+{
+public:
+ virtual ~IProvisionerStateProvider() = default;
+
+ virtual std::string_view GetName() const = 0; ///< e.g. "horde", "nomad"
+ virtual uint32_t GetTargetCoreCount() const = 0; ///< desired core capacity; adjustable via SetTargetCoreCount
+ virtual uint32_t GetEstimatedCoreCount() const = 0; ///< presumably cores expected once pending agents come up - TODO confirm against implementations
+ virtual uint32_t GetActiveCoreCount() const = 0; ///< cores currently allocated/active
+ virtual uint32_t GetAgentCount() const = 0; ///< number of agents known to the provisioner
+ virtual uint32_t GetDrainingAgentCount() const { return 0; } ///< optional; default reports no draining agents
+ virtual void SetTargetCoreCount(uint32_t Count) = 0;
+
+ /** Return the provisioning status for a worker by its orchestrator ID
+ * (e.g. "horde-{LeaseId}"). Returns Unknown if the ID is not recognized. */
+ virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; }
+};
+
+} // namespace zen::compute
diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp
index 9ea695305..68199ab3c 100644
--- a/src/zencompute/orchestratorservice.cpp
+++ b/src/zencompute/orchestratorservice.cpp
@@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService()
}
CbObject
-OrchestratorService::GetWorkerList()
+OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate)
{
ZEN_TRACE_CPU("OrchestratorService::GetWorkerList");
CbObjectWriter Cbo;
@@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList()
Cbo << "ws_connected" << true;
}
Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs();
+ if (Annotate)
+ {
+ Annotate(WorkerId, Cbo);
+ }
Cbo.EndObject();
}
});
@@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann)
}
}
+void
+OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release); // release pairs with the acquire loads in readers (e.g. ProbeThreadFunction); caller retains ownership of Provider
+}
+
bool
OrchestratorService::IsWorkerWebSocketEnabled() const
{
@@ -170,11 +180,11 @@ OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool
if (Connected)
{
- ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId);
+ ZEN_INFO("worker {} WebSocket connected - marking reachable", WorkerId);
}
else
{
- ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId);
+ ZEN_WARN("worker {} WebSocket disconnected - marking unreachable", WorkerId);
}
});
@@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction()
continue;
}
+ // Check if the provisioner knows this worker is draining - if so,
+ // unreachability is expected and should not be logged as a warning.
+ bool IsDraining = false;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining;
+ }
+
ReachableState NewState = ReachableState::Unreachable;
try
@@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction()
}
catch (const std::exception& Ex)
{
- ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+ if (!IsDraining)
+ {
+ ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+ }
}
ReachableState PrevState = ReachableState::Unknown;
@@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction()
{
ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri);
}
+ else if (IsDraining)
+ {
+ ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri);
+ }
else if (PrevState == ReachableState::Reachable)
{
ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri);
diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h
new file mode 100644
index 000000000..d50ad4a2a
--- /dev/null
+++ b/src/zencompute/pathvalidation.h
@@ -0,0 +1,118 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/except_fmt.h>
+#include <zencore/string.h>
+
+#include <filesystem>
+#include <string_view>
+
+namespace zen::compute {
+
+// Validate that a single path component contains only characters that are valid
+// file/directory names on all supported platforms. Uses Windows rules as the most
+// restrictive superset, since packages may be built on one platform and consumed
+// on another. Throws zen::invalid_argument on the first violation found.
+inline void
+ValidatePathComponent(std::string_view Component, std::string_view FullPath)
+{
+ // Reject control chars (0x00-0x1F), Windows-forbidden chars, and both separators: on POSIX '\' is an ordinary filename byte, so a component like "..\.." would otherwise pass here yet act as a traversal once interpreted on Windows.
+ for (char Ch : Component)
+ {
+ if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' ||
+ Ch == '*' || Ch == '\\' || Ch == '/')
+ {
+ throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath);
+ }
+ }
+
+ // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion)
+ if (Component.empty() || Component.back() == '.' || Component.back() == ' ')
+ {
+ throw zen::invalid_argument("path component '{}' of '{}' is empty or has trailing dot or space", Component, FullPath);
+ }
+
+ // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9)
+ // These are reserved with or without an extension (e.g. "CON.txt" is still reserved).
+ std::string_view Stem = Component.substr(0, Component.find('.'));
+
+ static constexpr std::string_view ReservedNames[] = {
+ "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
+ "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
+ };
+
+ for (std::string_view Reserved : ReservedNames)
+ {
+ if (zen::StrCaseCompare(Stem, Reserved) == 0)
+ {
+ throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved);
+ }
+ }
+}
+
+// Validate that a path extracted from a package is a safe relative path.
+// Rejects rooted paths (is_absolute(), plus Windows root-relative "\foo" and drive-relative "C:foo" which are NOT is_absolute() on Windows), ".." components, and invalid platform filenames.
+inline void
+ValidateSandboxRelativePath(std::string_view Name)
+{
+ if (Name.empty())
+ {
+ throw zen::invalid_argument("path traversal detected: empty path name");
+ }
+
+ std::filesystem::path Parsed(Name);
+
+ if (Parsed.is_absolute() || Parsed.has_root_path())
+ {
+ throw zen::invalid_argument("path traversal detected: '{}' is an absolute or rooted path", Name);
+ }
+
+ for (const auto& Component : Parsed)
+ {
+ std::string ComponentStr = Component.string();
+
+ if (ComponentStr == "..")
+ {
+ throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name);
+ }
+
+ // Skip "." (current directory) - harmless in relative paths
+ if (ComponentStr != ".")
+ {
+ ValidatePathComponent(ComponentStr, Name);
+ }
+ }
+}
+
+// Validate all path entries in a worker description CbObject before any of them
+// are used to create files or directories: checks path, executables[].name,
+// dirs[], and files[].name. Throws zen::invalid_argument on the first invalid path.
+inline void
+ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription)
+{
+ using namespace std::literals;
+
+ if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue())
+ {
+ ValidateSandboxRelativePath(PathField.AsString());
+ }
+
+ for (auto& It : WorkerDescription["executables"sv])
+ {
+ ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString());
+ }
+
+ for (auto& It : WorkerDescription["dirs"sv])
+ {
+ ValidateSandboxRelativePath(It.AsString());
+ }
+
+ for (auto& It : WorkerDescription["files"sv])
+ {
+ ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString());
+ }
+}
+
+} // namespace zen::compute
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
index 4f116e7d8..34bf065b4 100644
--- a/src/zencompute/runners/functionrunner.cpp
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -6,9 +6,15 @@
# include <zencore/compactbinary.h>
# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/logging.h>
+# include <zencore/string.h>
+# include <zencore/timer.h>
# include <zencore/trace.h>
+# include <zencore/workthreadpool.h>
# include <fmt/format.h>
+# include <future>
# include <vector>
namespace zen::compute {
@@ -118,23 +124,34 @@ std::vector<SubmitResult>
BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions");
- RwLock::SharedLockScope _(m_RunnersLock);
- const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+ // Snapshot runners and query capacity under the lock, then release
+ // before submitting - HTTP submissions to remote runners can take
+ // hundreds of milliseconds and we must not hold m_RunnersLock during I/O.
- if (RunnerCount == 0)
- {
- return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
- }
+ std::vector<Ref<FunctionRunner>> Runners;
+ std::vector<size_t> Capacities;
+ std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions;
+ size_t TotalCapacity = 0;
- // Query capacity per runner and compute total
- std::vector<size_t> Capacities(RunnerCount);
- size_t TotalCapacity = 0;
+ m_RunnersLock.WithSharedLock([&] {
+ const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+ Runners.assign(m_Runners.begin(), m_Runners.end());
+ Capacities.resize(RunnerCount);
+ PerRunnerActions.resize(RunnerCount);
- for (int i = 0; i < RunnerCount; ++i)
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ Capacities[i] = Runners[i]->QueryCapacity();
+ TotalCapacity += Capacities[i];
+ }
+ });
+
+ const int RunnerCount = gsl::narrow<int>(Runners.size());
+
+ if (RunnerCount == 0)
{
- Capacities[i] = m_Runners[i]->QueryCapacity();
- TotalCapacity += Capacities[i];
+ return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
}
if (TotalCapacity == 0)
@@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
// Distribute actions across runners proportionally to their available capacity
- std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount);
- std::vector<size_t> ActionRunnerIndex(Actions.size());
- size_t ActionIdx = 0;
+ std::vector<size_t> ActionRunnerIndex(Actions.size());
+ size_t ActionIdx = 0;
for (int i = 0; i < RunnerCount; ++i)
{
@@ -164,8 +180,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Assign any remaining actions to runners with capacity (round-robin)
- for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount)
+ // Assign any remaining actions to runners with capacity (round-robin).
+ // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept.
+ for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount)
{
if (Capacities[i] > PerRunnerActions[i].size())
{
@@ -175,22 +192,83 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Submit batches per runner
+ // Submit batches per runner - in parallel when a worker pool is available
+
std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount);
+ int ActiveRunnerCount = 0;
for (int i = 0; i < RunnerCount; ++i)
{
if (!PerRunnerActions[i].empty())
{
- PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]);
+ ++ActiveRunnerCount;
+ }
+ }
+
+ static constexpr uint64_t SubmitWarnThresholdMs = 500;
+
+ auto SubmitToRunner = [&](int RunnerIndex) {
+ auto& Runner = Runners[RunnerIndex];
+ Runner->m_LastSubmitStats.Reset();
+
+ Stopwatch Timer;
+
+ PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]);
+
+ uint64_t ElapsedMs = Timer.GetElapsedTimeMs();
+ if (ElapsedMs >= SubmitWarnThresholdMs)
+ {
+ size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed);
+ uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed);
+
+ ZEN_WARN("submit of {} actions ({} attachments, {}) to '{}' took {}ms",
+ PerRunnerActions[RunnerIndex].size(),
+ Attachments,
+ NiceBytes(AttachmentBytes),
+ Runner->GetDisplayName(),
+ ElapsedMs);
+ }
+ };
+
+ if (m_WorkerPool && ActiveRunnerCount > 1)
+ {
+ std::vector<std::future<void>> Futures(RunnerCount);
+
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (!PerRunnerActions[i].empty())
+ {
+ std::packaged_task<void()> Task([&SubmitToRunner, i]() { SubmitToRunner(i); });
+
+ Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog);
+ }
+ }
+
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (Futures[i].valid())
+ {
+ Futures[i].get();
+ }
+ }
+ }
+ else
+ {
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (!PerRunnerActions[i].empty())
+ {
+ SubmitToRunner(i);
+ }
}
}
- // Reassemble results in original action order
- std::vector<SubmitResult> Results(Actions.size());
+ // Reassemble results in original action order.
+ // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity).
+ std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"});
std::vector<size_t> PerRunnerIdx(RunnerCount, 0);
- for (size_t i = 0; i < Actions.size(); ++i)
+ for (size_t i = 0; i < ActionIdx; ++i)
{
size_t RunnerIdx = ActionRunnerIndex[i];
size_t Idx = PerRunnerIdx[RunnerIdx]++;
@@ -307,10 +385,11 @@ RunnerAction::RetractAction()
bool
RunnerAction::ResetActionStateToPending()
{
- // Only allow reset from Failed, Abandoned, or Retracted states
+ // Only allow reset from Failed, Abandoned, Rejected, or Retracted states
State CurrentState = m_ActionState.load();
- if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted)
+ if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected &&
+ CurrentState != State::Retracted)
{
return false;
}
@@ -331,11 +410,12 @@ RunnerAction::ResetActionStateToPending()
// Clear execution fields
ExecutionLocation.clear();
+ FailureReason.clear();
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed);
CpuSeconds.store(0.0f, std::memory_order_relaxed);
- // Increment retry count (skip for Retracted — nothing failed)
- if (CurrentState != State::Retracted)
+ // Increment retry count (skip for Retracted/Rejected - nothing failed)
+ if (CurrentState != State::Retracted && CurrentState != State::Rejected)
{
RetryCount.fetch_add(1, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
index 56c3f3af0..371a60b7a 100644
--- a/src/zencompute/runners/functionrunner.h
+++ b/src/zencompute/runners/functionrunner.h
@@ -10,6 +10,10 @@
# include <filesystem>
# include <vector>
+namespace zen {
+class WorkerThreadPool;
+}
+
namespace zen::compute {
struct SubmitResult
@@ -37,6 +41,22 @@ public:
[[nodiscard]] virtual bool IsHealthy() = 0;
[[nodiscard]] virtual size_t QueryCapacity();
[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+ [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; }
+
+ // Accumulated stats from the most recent SubmitActions call.
+ // Reset before each call, populated by the runner implementation.
+ struct SubmitStats
+ {
+ std::atomic<size_t> TotalAttachments{0};
+ std::atomic<uint64_t> TotalAttachmentBytes{0};
+
+ void Reset()
+ {
+ TotalAttachments.store(0, std::memory_order_relaxed);
+ TotalAttachmentBytes.store(0, std::memory_order_relaxed);
+ }
+ };
+ SubmitStats m_LastSubmitStats;
// Best-effort cancellation of a specific in-flight action. Returns true if the
// cancellation signal was successfully sent. The action will transition to Cancelled
@@ -68,6 +88,8 @@ public:
bool CancelAction(int ActionLsn);
void CancelRemoteQueue(int QueueId);
+ void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; }
+
size_t GetRunnerCount()
{
return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); });
@@ -79,6 +101,7 @@ protected:
RwLock m_RunnersLock;
std::vector<Ref<FunctionRunner>> m_Runners;
std::atomic<int> m_NextSubmitIndex{0};
+ WorkerThreadPool* m_WorkerPool = nullptr;
};
/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
@@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted
CbObject ActionObj;
int Priority = 0;
std::string ExecutionLocation; // "local" or remote hostname
+ std::string FailureReason; // human-readable reason when action fails (empty on success)
// CPU usage and total CPU time of the running process, sampled periodically by the local runner.
// CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
@@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted
Completed, // Finished successfully with results available
Failed, // Execution failed (transient error, eligible for retry)
Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon)
+ Rejected, // Runner declined (e.g. at capacity) - rescheduled without retry cost
Cancelled, // Intentional user cancellation (never retried)
Retracted, // Pulled back for rescheduling on a different runner (no retry cost)
_Count
@@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted
return "Failed";
case State::Abandoned:
return "Abandoned";
+ case State::Rejected:
+ return "Rejected";
case State::Cancelled:
return "Cancelled";
case State::Retracted:
diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp
index e79a6c90f..be4274823 100644
--- a/src/zencompute/runners/linuxrunner.cpp
+++ b/src/zencompute/runners/linuxrunner.cpp
@@ -195,7 +195,7 @@ namespace {
WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno);
}
- // /lib64 (optional — not all distros have it)
+ // /lib64 (optional - not all distros have it)
{
struct stat St;
if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode))
@@ -208,7 +208,7 @@ namespace {
}
}
- // /etc (required — for resolv.conf, ld.so.cache, etc.)
+ // /etc (required - for resolv.conf, ld.so.cache, etc.)
if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0)
{
WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno);
@@ -218,7 +218,7 @@ namespace {
WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno);
}
- // /worker — bind-mount worker directory (contains the executable)
+ // /worker - bind-mount worker directory (contains the executable)
if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0)
{
WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno);
@@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver,
{
ZEN_INFO("namespace sandboxing enabled for child processes");
}
+
+ StartMonitorThread();
}
SubmitResult
@@ -428,11 +430,12 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process - lower priority so workers don't starve the main server
+ nice(5);
if (m_Sandboxed)
{
- // Close read end of error pipe — child only writes
+ // Close read end of error pipe - child only writes
close(ErrorPipe[0]);
SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]);
@@ -459,7 +462,7 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (m_Sandboxed)
{
- // Close write end of error pipe — parent only reads
+ // Close write end of error pipe - parent only reads
close(ErrorPipe[1]);
// Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
@@ -479,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
// Clean up the sandbox in the background
m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
- ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+ Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf);
+ ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason);
Action->SetActionState(RunnerAction::State::Failed);
return SubmitResult{.IsAccepted = false};
@@ -675,7 +679,7 @@ ReadProcStatCpuTicks(pid_t Pid)
Buf[Len] = '\0';
- // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens
+ // Skip past "pid (name) " - find last ')' to handle names containing spaces or parens
const char* P = strrchr(Buf, ')');
if (!P)
{
@@ -705,7 +709,7 @@ LinuxProcessRunner::SampleProcessCpu(RunningAction& Running)
if (CurrentOsTicks == 0)
{
- // Process gone or /proc entry unreadable — record timestamp without updating usage
+ // Process gone or /proc entry unreadable - record timestamp without updating usage
Running.LastCpuSampleTicks = NowTicks;
Running.LastCpuOsTicks = 0;
return;
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
index b61e0a46f..259965e23 100644
--- a/src/zencompute/runners/localrunner.cpp
+++ b/src/zencompute/runners/localrunner.cpp
@@ -4,6 +4,8 @@
#if ZEN_WITH_COMPUTE_SERVICES
+# include "pathvalidation.h"
+
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
@@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver,
ZEN_INFO("Cleanup complete");
}
- m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
-
# if ZEN_PLATFORM_WINDOWS
// Suppress any error dialogs caused by missing dependencies
UINT OldMode = ::SetErrorMode(0);
@@ -337,7 +337,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action)
SubmitResult
LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action)
{
- // Base class is not directly usable — platform subclasses override this
+ // Base class is not directly usable - platform subclasses override this
ZEN_UNUSED(Action);
return SubmitResult{.IsAccepted = false};
}
@@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker)
std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId);
- if (!std::filesystem::exists(WorkerDir))
+ // worker.zcb is written as the last step of ManifestWorker, so its presence
+ // indicates a complete manifest. If the directory exists but the marker is
+ // missing, a previous manifest was interrupted and we need to start over.
+ bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb");
+
+ if (NeedsManifest)
{
_.ReleaseNow();
RwLock::ExclusiveLockScope $(m_WorkerLock);
- if (!std::filesystem::exists(WorkerDir))
+ if (!std::filesystem::exists(WorkerDir / "worker.zcb"))
{
+ std::error_code Ec;
+ std::filesystem::remove_all(WorkerDir, Ec);
ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {});
}
}
@@ -382,6 +389,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP
const IoHash ChunkHash = FileEntry["hash"sv].AsHash();
const uint64_t Size = FileEntry["size"sv].AsUInt64();
+ ValidateSandboxRelativePath(Name);
+
CompressedBuffer Compressed;
if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash))
@@ -457,7 +466,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
for (auto& It : WorkerDescription["dirs"sv])
{
- std::string_view Name = It.AsString();
+ std::string_view Name = It.AsString();
+ ValidateSandboxRelativePath(Name);
std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()};
// Validate dir path stays within sandbox
@@ -482,6 +492,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage,
}
WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer());
+
+ ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath);
}
CbPackage
@@ -540,6 +552,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath)
}
void
+LocalProcessRunner::StartMonitorThread()
+{
+ m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this};
+}
+
+void
LocalProcessRunner::MonitorThreadFunction()
{
SetCurrentThreadName("LocalProcessRunner_Monitor");
@@ -602,7 +620,7 @@ LocalProcessRunner::MonitorThreadFunction()
void
LocalProcessRunner::CancelRunningActions()
{
- // Base class is not directly usable — platform subclasses override this
+ // Base class is not directly usable - platform subclasses override this
}
void
@@ -662,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& Com
}
catch (std::exception& Ex)
{
- ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what());
+ Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what());
+ ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason);
}
}
+ else
+ {
+ Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode);
+ ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason);
+ }
// Failed - clean up the sandbox in the background.
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h
index b8cff6826..d6589db43 100644
--- a/src/zencompute/runners/localrunner.h
+++ b/src/zencompute/runners/localrunner.h
@@ -67,6 +67,7 @@ protected:
{
Ref<RunnerAction> Action;
void* ProcessHandle = nullptr;
+ int Pid = 0;
int ExitCode = 0;
std::filesystem::path SandboxPath;
@@ -83,8 +84,6 @@ protected:
std::filesystem::path m_SandboxPath;
int32_t m_MaxRunningActions = 64; // arbitrary limit for testing
- // if used in conjuction with m_ResultsLock, this lock must be taken *after*
- // m_ResultsLock to avoid deadlocks
RwLock m_RunningLock;
std::unordered_map<int, Ref<RunningAction>> m_RunningMap;
@@ -95,6 +94,7 @@ protected:
std::thread m_MonitorThread;
std::atomic<bool> m_MonitorThreadEnabled{true};
Event m_MonitorThreadEvent;
+ void StartMonitorThread();
void MonitorThreadFunction();
virtual void SweepRunningActions();
virtual void CancelRunningActions();
diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp
index 5cec90699..ab24d4672 100644
--- a/src/zencompute/runners/macrunner.cpp
+++ b/src/zencompute/runners/macrunner.cpp
@@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver,
{
ZEN_INFO("Seatbelt sandboxing enabled for child processes");
}
+
+ StartMonitorThread();
}
SubmitResult
@@ -209,18 +211,19 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process - lower priority so workers don't starve the main server
+ nice(5);
if (m_Sandboxed)
{
- // Close read end of error pipe — child only writes
+ // Close read end of error pipe - child only writes
close(ErrorPipe[0]);
// Apply Seatbelt sandbox profile
char* ErrorBuf = nullptr;
if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0)
{
- // sandbox_init failed — write error to pipe and exit
+ // sandbox_init failed - write error to pipe and exit
if (ErrorBuf)
{
WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0);
@@ -259,7 +262,7 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (m_Sandboxed)
{
- // Close write end of error pipe — parent only reads
+ // Close write end of error pipe - parent only reads
close(ErrorPipe[1]);
// Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
@@ -279,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
// Clean up the sandbox in the background
m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
- ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+ Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf);
+ ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason);
Action->SetActionState(RunnerAction::State::Failed);
return SubmitResult{.IsAccepted = false};
@@ -467,7 +471,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running)
const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system;
const uint64_t NowTicks = GetHifreqTimerValue();
- // Cumulative CPU seconds (absolute, available from first sample): ns → seconds
+ // Cumulative CPU seconds (absolute, available from first sample): ns -> seconds
Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed);
if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
@@ -476,7 +480,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running)
if (ElapsedMs > 0)
{
const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
- // ns → ms: divide by 1,000,000; then as percent of elapsed ms
+ // ns -> ms: divide by 1,000,000; then as percent of elapsed ms
const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0);
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp
new file mode 100644
index 000000000..a4f586852
--- /dev/null
+++ b/src/zencompute/runners/managedrunner.cpp
@@ -0,0 +1,279 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "managedrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/except_fmt.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/scopeguard.h>
+# include <zencore/timer.h>
+# include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio/io_context.hpp>
+# include <asio/executor_work_guard.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions)
+: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
+, m_IoContext(std::make_unique<asio::io_context>())
+, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext))
+{
+ m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers");
+
+ // Run the io_context on a small thread pool so that exit callbacks and
+ // metrics sampling are dispatched without blocking each other.
+ for (int i = 0; i < kIoThreadCount; ++i)
+ {
+ m_IoThreads.emplace_back([this, i] {
+ SetCurrentThreadName(fmt::format("mrunner_{}", i));
+
+ // work_guard keeps run() alive even when there is no pending work yet
+ auto WorkGuard = asio::make_work_guard(*m_IoContext);
+
+ m_IoContext->run();
+ });
+ }
+}
+
+ManagedProcessRunner::~ManagedProcessRunner()
+{
+ try
+ {
+ Shutdown();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what());
+ }
+}
+
+void
+ManagedProcessRunner::Shutdown()
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown");
+ m_AcceptNewActions = false;
+
+ CancelRunningActions();
+
+ // Tear down the SubprocessManager before stopping the io_context so that
+ // any in-flight callbacks are drained cleanly.
+ if (m_SubprocessManager)
+ {
+ m_SubprocessManager->DestroyGroup("compute-workers");
+ m_ProcessGroup = nullptr;
+ m_SubprocessManager.reset();
+ }
+
+ if (m_IoContext)
+ {
+ m_IoContext->stop();
+ }
+
+ for (std::thread& Thread : m_IoThreads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+ m_IoThreads.clear();
+}
+
+SubmitResult
+ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction");
+ std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);
+
+ if (!Prepared)
+ {
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();
+
+ // Parse environment variables from worker descriptor ("KEY=VALUE" strings)
+ // into the key-value pairs expected by CreateProcOptions.
+ std::vector<std::pair<std::string, std::string>> EnvPairs;
+ for (auto& It : WorkerDescription["environment"sv])
+ {
+ std::string_view Str = It.AsString();
+ size_t Eq = Str.find('=');
+ if (Eq != std::string_view::npos)
+ {
+ EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1)));
+ }
+ }
+
+ // Build command line
+ std::string_view ExecPath = WorkerDescription["path"sv].AsString();
+ std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred();
+
+ std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string());
+
+ ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath);
+
+ CreateProcOptions Options;
+ Options.WorkingDirectory = &Prepared->SandboxPath;
+ Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority;
+ Options.Environment = std::move(EnvPairs);
+
+ const int32_t ActionLsn = Prepared->ActionLsn;
+
+ ManagedProcess* Proc = nullptr;
+
+ try
+ {
+ Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) {
+ OnProcessExit(ActionLsn, ExitCode);
+ });
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what());
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath));
+ return SubmitResult{.IsAccepted = false};
+ }
+
+ {
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = static_cast<void*>(Proc);
+ NewAction->Pid = Proc->Pid();
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+ m_RunningMap[ActionLsn] = std::move(NewAction);
+ }
+
+ Action->SetActionState(RunnerAction::State::Running);
+
+ ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid());
+
+ return SubmitResult{.IsAccepted = true};
+}
+
+void
+ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit");
+
+ Ref<RunningAction> Running;
+
+ m_RunningLock.WithExclusiveLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It != m_RunningMap.end())
+ {
+ Running = std::move(It->second);
+ m_RunningMap.erase(It);
+ }
+ });
+
+ if (!Running)
+ {
+ return;
+ }
+
+ ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode);
+
+ Running->ExitCode = ExitCode;
+
+ // Capture final CPU metrics from the managed process before it is removed.
+ auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle);
+ if (Proc)
+ {
+ ProcessMetrics Metrics = Proc->GetLatestMetrics();
+ float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs);
+ Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed);
+
+ float CpuPct = Proc->GetCpuUsagePercent();
+ if (CpuPct >= 0.0f)
+ {
+ Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
+ }
+ }
+
+ Running->ProcessHandle = nullptr;
+
+ std::vector<Ref<RunningAction>> CompletedActions;
+ CompletedActions.push_back(std::move(Running));
+ ProcessCompletedActions(CompletedActions);
+}
+
+void
+ManagedProcessRunner::CancelRunningActions()
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions");
+
+ std::unordered_map<int, Ref<RunningAction>> RunningMap;
+ m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });
+
+ if (RunningMap.empty())
+ {
+ return;
+ }
+
+ ZEN_INFO("cancelling {} running actions via process group", RunningMap.size());
+
+ Stopwatch Timer;
+
+ // Kill all processes in the group atomically (TerminateJobObject on Windows,
+ // SIGTERM+SIGKILL on POSIX).
+ if (m_ProcessGroup)
+ {
+ m_ProcessGroup->KillAll();
+ }
+
+ for (auto& [Lsn, Running] : RunningMap)
+ {
+ m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
+ Running->Action->SetActionState(RunnerAction::State::Failed);
+ }
+
+ ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+}
+
+bool
+ManagedProcessRunner::CancelAction(int ActionLsn)
+{
+ ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction");
+
+ ManagedProcess* Proc = nullptr;
+
+ m_RunningLock.WithSharedLock([&] {
+ auto It = m_RunningMap.find(ActionLsn);
+ if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr)
+ {
+ Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle);
+ }
+ });
+
+ if (!Proc)
+ {
+ return false;
+ }
+
+ // Terminate the process. The exit callback will handle the rest
+ // (remove from running map, gather outputs or mark failed).
+ Proc->Terminate(222);
+
+ ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn);
+ return true;
+}
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h
new file mode 100644
index 000000000..21a44d43c
--- /dev/null
+++ b/src/zencompute/runners/managedrunner.h
@@ -0,0 +1,64 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "localrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zenutil/process/subprocessmanager.h>
+
+# include <memory>
+
+namespace asio {
+class io_context;
+}
+
+namespace zen::compute {
+
+/** Cross-platform process runner backed by SubprocessManager.
+
+ Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
+ input/output handling, and shared action preparation. Replaces the polling-based
+ monitor thread with async exit callbacks driven by SubprocessManager, and
+ delegates CPU/memory metrics sampling to the manager's built-in round-robin
+ sampler.
+
+ A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is
+ used for bulk cancellation on shutdown.
+
+ This runner does not perform any platform-specific sandboxing (AppContainer,
+ namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative
+ to the platform-specific runners for non-sandboxed workloads.
+ */
+class ManagedProcessRunner : public LocalProcessRunner
+{
+public:
+ ManagedProcessRunner(ChunkResolver& Resolver,
+ const std::filesystem::path& BaseDir,
+ DeferredDirectoryDeleter& Deleter,
+ WorkerThreadPool& WorkerPool,
+ int32_t MaxConcurrentActions = 0);
+ ~ManagedProcessRunner();
+
+ void Shutdown() override;
+ [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
+ void CancelRunningActions() override;
+ bool CancelAction(int ActionLsn) override;
+ [[nodiscard]] bool IsHealthy() override { return true; }
+
+private:
+ static constexpr int kIoThreadCount = 4;
+
+ // Exit callback posted on an io_context thread.
+ void OnProcessExit(int ActionLsn, int ExitCode);
+
+ std::unique_ptr<asio::io_context> m_IoContext;
+ std::unique_ptr<SubprocessManager> m_SubprocessManager;
+ ProcessGroup* m_ProcessGroup = nullptr;
+ std::vector<std::thread> m_IoThreads;
+};
+
+} // namespace zen::compute
+
+#endif
diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp
index ce6a81173..08f381b7f 100644
--- a/src/zencompute/runners/remotehttprunner.cpp
+++ b/src/zencompute/runners/remotehttprunner.cpp
@@ -20,6 +20,7 @@
# include <zenstore/cidstore.h>
# include <span>
+# include <unordered_set>
//////////////////////////////////////////////////////////////////////////
@@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
, m_ChunkResolver{InChunkResolver}
, m_WorkerPool{InWorkerPool}
, m_HostName{HostName}
+, m_DisplayName{HostName}
, m_BaseUrl{fmt::format("{}/compute", HostName)}
, m_Http(m_BaseUrl)
, m_InstanceId(Oid::NewOid())
@@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this};
}
+void
+RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname)
+{
+ if (!Hostname.empty())
+ {
+ m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname);
+ }
+}
+
RemoteHttpRunner::~RemoteHttpRunner()
{
Shutdown();
@@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown()
for (auto& [RemoteLsn, HttpAction] : Remaining)
{
ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn);
+ HttpAction.Action->FailureReason = "remote runner shutdown";
HttpAction.Action->SetActionState(RunnerAction::State::Failed);
}
}
@@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity()
return 0;
}
- // Estimate how much more work we're ready to accept
+ // Estimate how much more work we're ready to accept.
+ // Include actions currently being submitted over HTTP so we don't
+ // keep queueing new submissions while previous ones are still in flight.
RwLock::SharedLockScope _{m_RunningLock};
- size_t RunningCount = m_RemoteRunningMap.size();
+ size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed);
if (RunningCount >= size_t(m_MaxRunningActions))
{
@@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions");
+ m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed);
+ auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); });
+
if (Actions.size() <= 1)
{
std::vector<SubmitResult> Results;
@@ -246,7 +263,7 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
// Collect distinct QueueIds and ensure remote queues exist once per queue
- std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero)
+ std::unordered_map<int, Oid> QueueTokens; // QueueId -> remote token (0 stays as Zero)
for (const Ref<RunnerAction>& Action : Actions)
{
@@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
}
}
- // Enqueue job. If the remote returns FailedDependency (424), it means it
- // cannot resolve the worker/function — re-register the worker and retry once.
+ // Submit the action to the remote. In eager-attach mode we build a
+ // CbPackage with all referenced attachments upfront to avoid the 404
+ // round-trip. In the default mode we POST the bare object first and
+ // only upload missing attachments if the remote requests them.
+ //
+ // In both modes, FailedDependency (424) triggers a worker re-register
+ // and a single retry.
CbObject Result;
HttpClient::Response WorkResponse;
HttpResponseCode WorkResponseCode{};
- for (int Attempt = 0; Attempt < 2; ++Attempt)
- {
- WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
- WorkResponseCode = WorkResponse.StatusCode;
-
- if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
- {
- ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying",
- m_Http.GetBaseUri(),
- ActionId);
-
- (void)RegisterWorker(Action->Worker.Descriptor);
- }
- else
- {
- break;
- }
- }
-
- if (WorkResponseCode == HttpResponseCode::OK)
- {
- Result = WorkResponse.AsObject();
- }
- else if (WorkResponseCode == HttpResponseCode::NotFound)
+ if (m_EagerAttach)
{
- // Not all attachments are present
-
- // Build response package including all required attachments
-
CbPackage Pkg;
Pkg.SetObject(ActionObj);
- CbObject Response = WorkResponse.AsObject();
+ ActionObj.IterateAttachments([&](CbFieldView Field) {
+ const IoHash AttachHash = Field.AsHash();
- for (auto& Item : Response["need"sv])
- {
- const IoHash NeedHash = Item.AsHash();
-
- if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash))
{
uint64_t DataRawSize = 0;
IoHash DataRawHash;
CompressedBuffer Compressed =
CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
- ZEN_ASSERT(DataRawHash == NeedHash);
+ Pkg.AddAttachment(CbAttachment(Compressed, AttachHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ });
+
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, Pkg);
+ WorkResponseCode = WorkResponse.StatusCode;
+
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} - re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
- Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ (void)RegisterWorker(Action->Worker.Descriptor);
}
else
{
- // No such attachment
-
- return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ break;
}
}
+ }
+ else
+ {
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
+ WorkResponseCode = WorkResponse.StatusCode;
- // Post resulting package
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} - re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
- HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+ (void)RegisterWorker(Action->Worker.Descriptor);
+ }
+ else
+ {
+ break;
+ }
+ }
- if (!PayloadResponse)
+ if (WorkResponseCode == HttpResponseCode::NotFound)
{
- ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ // Remote needs attachments - resolve them and retry with a CbPackage
- // TODO: include more information about the failure in the response
+ CbPackage Pkg;
+ Pkg.SetObject(ActionObj);
- return {.IsAccepted = false, .Reason = "HTTP request failed"};
- }
- else if (PayloadResponse.StatusCode == HttpResponseCode::OK)
- {
- Result = PayloadResponse.AsObject();
- }
- else
- {
- // Unexpected response
-
- const int ResponseStatusCode = (int)PayloadResponse.StatusCode;
-
- ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})",
- ActionId,
- m_Http.GetBaseUri(),
- SubmitUrl,
- ResponseStatusCode,
- ToString(ResponseStatusCode));
-
- return {.IsAccepted = false,
- .Reason = fmt::format("unexpected response code {} {} from {}{}",
- ResponseStatusCode,
- ToString(ResponseStatusCode),
- m_Http.GetBaseUri(),
- SubmitUrl)};
+ CbObject Response = WorkResponse.AsObject();
+
+ for (auto& Item : Response["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ else
+ {
+ return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ }
+ }
+
+ HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (!PayloadResponse)
+ {
+ ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+
+ WorkResponse = std::move(PayloadResponse);
+ WorkResponseCode = WorkResponse.StatusCode;
}
}
+ if (WorkResponseCode == HttpResponseCode::OK)
+ {
+ Result = WorkResponse.AsObject();
+ }
+ else if (!WorkResponse)
+ {
+ ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+ else if (!IsHttpSuccessCode(WorkResponseCode))
+ {
+ const int Code = static_cast<int>(WorkResponseCode);
+ ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code));
+ return {.IsAccepted = false,
+ .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)};
+ }
+
if (Result)
{
if (const int32_t LsnField = Result["lsn"].AsInt32(0))
@@ -512,83 +562,111 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec
CbObjectWriter Body;
Body.BeginArray("actions"sv);
+ std::unordered_set<IoHash, IoHash::Hasher> AttachmentsSeen;
+
for (const Ref<RunnerAction>& Action : Actions)
{
Action->ExecutionLocation = m_HostName;
MaybeDumpAction(Action->ActionLsn, Action->ActionObj);
Body.AddObject(Action->ActionObj);
+
+ if (m_EagerAttach)
+ {
+ Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); });
+ }
}
Body.EndArray();
- // POST the batch
-
- HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
-
- if (Response.StatusCode == HttpResponseCode::OK)
- {
- return ParseBatchResponse(Response, Actions);
- }
+ // In eager-attach mode, build a CbPackage with all referenced attachments
+ // so the remote can accept in a single round-trip. Otherwise POST a bare
+ // CbObject and handle the 404 need-list flow.
- if (Response.StatusCode == HttpResponseCode::NotFound)
+ if (m_EagerAttach)
{
- // Server needs attachments — resolve them and retry with a CbPackage
-
- CbObject NeedObj = Response.AsObject();
-
CbPackage Pkg;
Pkg.SetObject(Body.Save());
- for (auto& Item : NeedObj["need"sv])
+ for (const IoHash& AttachHash : AttachmentsSeen)
{
- const IoHash NeedHash = Item.AsHash();
-
- if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash))
{
uint64_t DataRawSize = 0;
IoHash DataRawHash;
CompressedBuffer Compressed =
CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
- ZEN_ASSERT(DataRawHash == NeedHash);
-
- Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
- }
- else
- {
- ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash);
- return FallbackToIndividualSubmit(Actions);
+ Pkg.AddAttachment(CbAttachment(Compressed, AttachHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
}
}
- HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg);
- if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ if (Response.StatusCode == HttpResponseCode::OK)
{
- return ParseBatchResponse(RetryResponse, Actions);
+ return ParseBatchResponse(Response, Actions);
}
-
- ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit",
- (int)RetryResponse.StatusCode,
- ToString(RetryResponse.StatusCode));
- return FallbackToIndividualSubmit(Actions);
- }
-
- // Unexpected status or connection error — fall back to individual submission
-
- if (Response)
- {
- ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit",
- m_Http.GetBaseUri(),
- SubmitUrl,
- (int)Response.StatusCode,
- ToString(Response.StatusCode));
}
else
{
- ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
+
+ if (Response.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(Response, Actions);
+ }
+
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ CbObject NeedObj = Response.AsObject();
+
+ CbPackage Pkg;
+ Pkg.SetObject(Body.Save());
+
+ for (auto& Item : NeedObj["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ else
+ {
+ ZEN_WARN("batch submit: missing attachment {} - falling back to individual submit", NeedHash);
+ return FallbackToIndividualSubmit(Actions);
+ }
+ }
+
+ HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(RetryResponse, Actions);
+ }
+
+ ZEN_WARN("batch submit retry failed with {} {} - falling back to individual submit",
+ (int)RetryResponse.StatusCode,
+ ToString(RetryResponse.StatusCode));
+ return FallbackToIndividualSubmit(Actions);
+ }
}
+ // Unexpected status or connection error - fall back to individual submission
+
+ ZEN_WARN("batch submit to {}{} failed - falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
+
return FallbackToIndividualSubmit(Actions);
}
@@ -849,7 +927,7 @@ RemoteHttpRunner::MonitorThreadFunction()
SweepOnce();
}
- // Signal received — may be a WS wakeup or a quit signal
+ // Signal received - may be a WS wakeup or a quit signal
SweepOnce();
} while (m_MonitorThreadEnabled);
@@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions()
{
for (auto& FieldIt : Completed["completed"sv])
{
- CbObjectView EntryObj = FieldIt.AsObjectView();
- const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
- std::string_view StateName = EntryObj["state"sv].AsString();
+ CbObjectView EntryObj = FieldIt.AsObjectView();
+ const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
+ std::string_view StateName = EntryObj["state"sv].AsString();
+ std::string_view FailureReason = EntryObj["reason"sv].AsString();
RunnerAction::State RemoteState = RunnerAction::FromString(StateName);
@@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions()
{
HttpRunningAction CompletedAction = std::move(CompleteIt->second);
CompletedAction.RemoteState = RemoteState;
+ CompletedAction.FailureReason = std::string(FailureReason);
if (RemoteState == RunnerAction::State::Completed && ResponseJob)
{
@@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions()
{
const int ActionLsn = HttpAction.Action->ActionLsn;
- ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
- HttpAction.Action->ActionId,
- ActionLsn,
- HttpAction.RemoteActionLsn,
- RunnerAction::ToString(HttpAction.RemoteState));
-
if (HttpAction.RemoteState == RunnerAction::State::Completed)
{
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ m_HostName);
HttpAction.Action->SetResult(std::move(HttpAction.ActionResults));
}
+ else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned)
+ {
+ HttpAction.Action->FailureReason = HttpAction.FailureReason;
+ if (HttpAction.FailureReason.empty())
+ {
+ ZEN_WARN("action {} ({}) {} on remote {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState),
+ m_HostName);
+ }
+ else
+ {
+ ZEN_WARN("action {} ({}) {} on remote {}: {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState),
+ m_HostName,
+ HttpAction.FailureReason);
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState));
+ }
HttpAction.Action->SetActionState(HttpAction.RemoteState);
}
diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h
index c17d0cf2a..521bf2f82 100644
--- a/src/zencompute/runners/remotehttprunner.h
+++ b/src/zencompute/runners/remotehttprunner.h
@@ -54,8 +54,10 @@ public:
[[nodiscard]] virtual size_t QueryCapacity() override;
[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override;
virtual void CancelRemoteQueue(int QueueId) override;
+ [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; }
std::string_view GetHostName() const { return m_HostName; }
+ void SetRemoteHostname(std::string_view Hostname);
protected:
LoggerRef Log() { return m_Log; }
@@ -65,12 +67,15 @@ private:
ChunkResolver& m_ChunkResolver;
WorkerThreadPool& m_WorkerPool;
std::string m_HostName;
+ std::string m_DisplayName;
std::string m_BaseUrl;
HttpClient m_Http;
- std::atomic<bool> m_AcceptNewActions{true};
- int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
- int32_t m_MaxBatchSize = 50;
+ std::atomic<bool> m_AcceptNewActions{true};
+ int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ int32_t m_MaxBatchSize = 50;
+ bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry
+ std::atomic<size_t> m_InFlightSubmissions{0}; // actions currently being submitted over HTTP
struct HttpRunningAction
{
@@ -78,6 +83,7 @@ private:
int RemoteActionLsn = 0; // Remote LSN
RunnerAction::State RemoteState = RunnerAction::State::Failed;
CbPackage ActionResults;
+ std::string FailureReason;
};
RwLock m_RunningLock;
@@ -90,7 +96,7 @@ private:
size_t SweepRunningActions();
RwLock m_QueueTokenLock;
- std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token
+ std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId -> remote queue token
// Stable identity for this runner instance, used as part of the idempotency key when
// creating remote queues. Generated once at construction and never changes.
diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp
index e9a1ae8b6..c6b3e82ea 100644
--- a/src/zencompute/runners/windowsrunner.cpp
+++ b/src/zencompute/runners/windowsrunner.cpp
@@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
# include <sddl.h>
ZEN_THIRD_PARTY_INCLUDES_END
+// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be
+// excluded by WIN32_LEAN_AND_MEAN.
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE)
+# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400
+# endif
+
namespace zen::compute {
using namespace std::literals;
@@ -34,38 +40,67 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver,
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
- if (!m_Sandboxed)
+ // Create a job object shared by all child processes. Restricting the
+ // error-mode UI prevents crash dialogs (WER / Dr. Watson) from
+ // blocking the monitor thread when a worker process terminates
+ // abnormally.
+ m_JobObject = CreateJobObjectW(nullptr, nullptr);
+ if (m_JobObject)
{
- return;
+ JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{};
+ ExtLimits.BasicLimitInformation.LimitFlags =
+ JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS;
+ ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS;
+ SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits));
+
+ JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{};
+ UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE;
+ SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions));
+
+ // Set error mode on this process so children inherit it. The
+ // UILIMIT_ERRORMODE restriction above prevents them from clearing
+ // SEM_NOGPFAULTERRORBOX.
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX);
}
- // Build a unique profile name per process to avoid collisions
- m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
+ if (m_Sandboxed)
+ {
+ // Build a unique profile name per process to avoid collisions
+ m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId());
- // Clean up any stale profile from a previous crash
- DeleteAppContainerProfile(m_AppContainerName.c_str());
+ // Clean up any stale profile from a previous crash
+ DeleteAppContainerProfile(m_AppContainerName.c_str());
- PSID Sid = nullptr;
+ PSID Sid = nullptr;
- HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
- m_AppContainerName.c_str(), // display name
- m_AppContainerName.c_str(), // description
- nullptr, // no capabilities
- 0, // capability count
- &Sid);
+ HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(),
+ m_AppContainerName.c_str(), // display name
+ m_AppContainerName.c_str(), // description
+ nullptr, // no capabilities
+ 0, // capability count
+ &Sid);
- if (FAILED(Hr))
- {
- throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
- }
+ if (FAILED(Hr))
+ {
+ throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr));
+ }
- m_AppContainerSid = Sid;
+ m_AppContainerSid = Sid;
+
+ ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+ }
- ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName));
+ StartMonitorThread();
}
WindowsProcessRunner::~WindowsProcessRunner()
{
+ if (m_JobObject)
+ {
+ CloseHandle(m_JobObject);
+ m_JobObject = nullptr;
+ }
+
if (m_AppContainerSid)
{
FreeSid(m_AppContainerSid);
@@ -172,9 +207,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr;
LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr;
BOOL bInheritHandles = FALSE;
- DWORD dwCreationFlags = 0;
+ DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS;
- ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed);
+ ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath);
CommandLine.EnsureNulTerminated();
@@ -260,14 +295,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
}
}
- CloseHandle(ProcessInformation.hThread);
+ if (m_JobObject)
+ {
+ AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess);
+ }
- Ref<RunningAction> NewAction{new RunningAction()};
- NewAction->Action = Action;
- NewAction->ProcessHandle = ProcessInformation.hProcess;
- NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+ ResumeThread(ProcessInformation.hThread);
+ CloseHandle(ProcessInformation.hThread);
{
+ Ref<RunningAction> NewAction{new RunningAction()};
+ NewAction->Action = Action;
+ NewAction->ProcessHandle = ProcessInformation.hProcess;
+ NewAction->Pid = ProcessInformation.dwProcessId;
+ NewAction->SandboxPath = std::move(Prepared->SandboxPath);
+
RwLock::ExclusiveLockScope _(m_RunningLock);
m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
@@ -275,6 +317,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action)
Action->SetActionState(RunnerAction::State::Running);
+ ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId);
+
return SubmitResult{.IsAccepted = true};
}
@@ -294,6 +338,11 @@ WindowsProcessRunner::SweepRunningActions()
if (IsSuccess && ExitCode != STILL_ACTIVE)
{
+ ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"),
+ Running->Action->ActionLsn,
+ Running->Pid,
+ ExitCode);
+
CloseHandle(Running->ProcessHandle);
Running->ProcessHandle = INVALID_HANDLE_VALUE;
Running->ExitCode = ExitCode;
@@ -436,7 +485,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running)
const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime);
const uint64_t NowTicks = GetHifreqTimerValue();
- // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds
+ // Cumulative CPU seconds (absolute, available from first sample): 100ns -> seconds
Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed);
if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
@@ -445,7 +494,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running)
if (ElapsedMs > 0)
{
const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
- // 100ns → ms: divide by 10000; then as percent of elapsed ms
+ // 100ns -> ms: divide by 10000; then as percent of elapsed ms
const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0);
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h
index 9f2385cc4..adeaf02fc 100644
--- a/src/zencompute/runners/windowsrunner.h
+++ b/src/zencompute/runners/windowsrunner.h
@@ -46,6 +46,7 @@ private:
bool m_Sandboxed = false;
PSID m_AppContainerSid = nullptr;
std::wstring m_AppContainerName;
+ HANDLE m_JobObject = nullptr;
};
} // namespace zen::compute
diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp
index 506bec73b..29ab93663 100644
--- a/src/zencompute/runners/winerunner.cpp
+++ b/src/zencompute/runners/winerunner.cpp
@@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver,
sigemptyset(&Action.sa_mask);
Action.sa_handler = SIG_DFL;
sigaction(SIGCHLD, &Action, nullptr);
+
+ StartMonitorThread();
}
SubmitResult
@@ -94,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process - lower priority so workers don't starve the main server
+ nice(5);
+
if (chdir(SandboxPathStr.c_str()) != 0)
{
_exit(127);