diff options
Diffstat (limited to 'src/zencompute')
27 files changed, 1716 insertions, 506 deletions
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index a1a39fc3c..bb574edc2 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -141,7 +141,7 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. -**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Setting `max_retries=0` disables automatic retry entirely; omitting the field (or setting it to a negative value) uses the default of 3. Both automatic and manual paths respect this limit. **Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. @@ -156,7 +156,7 @@ Queues group actions from a single client session. 
A `QueueEntry` (internal) tra - `ActiveLsns` — for cancellation lookup (under `m_Lock`) - `FinishedLsns` — moved here when actions complete - `IdleSince` — used for 15-minute automatic expiry -- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit +- `Config` — CbObject set at creation; supports `max_retries` (int, default 3) to override the default retry limit. `0` = no retries, negative or absent = use default **Queue state machine (`QueueState` enum):** ``` @@ -216,11 +216,14 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han ## Concurrency Model -**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: -1. `m_ResultsLock` -2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") -3. `m_PendingLock` -4. `m_QueueLock` +**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock. + +**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks: +1. `m_ActionMapLock` (session action maps) +2. `QueueEntry::m_Lock` (per-queue state) +3. `m_ActionHistoryLock` (action history ring) + +Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`). **Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. 
diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp index eb4c05f9f..f1df18e8e 100644 --- a/src/zencompute/cloudmetadata.cpp +++ b/src/zencompute/cloudmetadata.cpp @@ -183,7 +183,7 @@ CloudMetadata::TryDetectAWS() m_Info.AvailabilityZone = std::string(AzResponse.AsText()); } - // "spot" vs "on-demand" — determines whether the instance can be + // "spot" vs "on-demand" - determines whether the instance can be // reclaimed by AWS with a 2-minute warning HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders); if (LifecycleResponse.IsSuccess()) @@ -273,7 +273,7 @@ CloudMetadata::TryDetectAzure() std::string Priority = Compute["priority"].string_value(); m_Info.IsSpot = (Priority == "Spot"); - // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling + // Check if part of a VMSS (Virtual Machine Scale Set) - indicates autoscaling std::string VmssName = Compute["vmScaleSetName"].string_value(); m_Info.IsAutoscaling = !VmssName.empty(); @@ -609,7 +609,7 @@ namespace zen::compute { TEST_SUITE_BEGIN("compute.cloudmetadata"); // --------------------------------------------------------------------------- -// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService +// Test helper - spins up a local ASIO HTTP server hosting a MockImdsService // --------------------------------------------------------------------------- struct TestImdsServer @@ -974,7 +974,7 @@ TEST_CASE("cloudmetadata.sentinel_files") SUBCASE("only failed providers get sentinels") { - // Switch to AWS — Azure and GCP never probed, so no sentinels for them + // Switch to AWS - Azure and GCP never probed, so no sentinels for them Imds.Mock.ActiveProvider = CloudProvider::AWS; auto Cloud = Imds.CreateCloud(); diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 92901de64..7f354a51c 100644 --- a/src/zencompute/computeservice.cpp +++ 
b/src/zencompute/computeservice.cpp @@ -8,6 +8,8 @@ # include "recording/actionrecorder.h" # include "runners/localrunner.h" # include "runners/remotehttprunner.h" +# include "runners/managedrunner.h" +# include "pathvalidation.h" # if ZEN_PLATFORM_LINUX # include "runners/linuxrunner.h" # elif ZEN_PLATFORM_WINDOWS @@ -119,6 +121,8 @@ struct ComputeServiceSession::Impl , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) { + m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool); + // Create a non-expiring, non-deletable implicit queue for legacy endpoints auto Result = CreateQueue("implicit"sv, {}, {}); m_ImplicitQueueId = Result.QueueId; @@ -195,13 +199,9 @@ struct ComputeServiceSession::Impl std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; - RwLock m_PendingLock; - std::map<int, Ref<RunnerAction>> m_PendingActions; - - RwLock m_RunningLock; + RwLock m_ActionMapLock; // Guards m_PendingActions, m_RunningMap, m_ResultsMap + std::map<int, Ref<RunnerAction>> m_PendingActions; std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; - - RwLock m_ResultsLock; std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; metrics::Meter m_ResultRate; std::atomic<uint64_t> m_RetiredCount{0}; @@ -242,8 +242,9 @@ struct ComputeServiceSession::Impl // Recording - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; std::unique_ptr<ActionRecorder> m_Recorder; @@ -343,9 +344,12 @@ struct ComputeServiceSession::Impl ActionCounts GetActionCounts() { ActionCounts Counts; - Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - Counts.Completed = 
(int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + m_ActionMapLock.WithSharedLock([&] { + Counts.Pending = (int)m_PendingActions.size(); + Counts.Running = (int)m_RunningMap.size(); + Counts.Completed = (int)m_ResultsMap.size(); + }); + Counts.Completed += (int)m_RetiredCount.load(); Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { size_t Count = 0; for (const auto& [Id, Queue] : m_Queues) @@ -364,8 +368,10 @@ struct ComputeServiceSession::Impl { Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + m_ActionMapLock.WithSharedLock([&] { + Cbo << "actions_complete"sv << m_ResultsMap.size(); + Cbo << "actions_pending"sv << m_PendingActions.size(); + }); Cbo << "actions_submitted"sv << GetSubmittedActionCount(); EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); @@ -443,34 +449,24 @@ ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) return true; } - // CAS failed, Current was updated — retry with the new value + // CAS failed, Current was updated - retry with the new value } } void ComputeServiceSession::Impl::AbandonAllActions() { - // Collect all pending actions and mark them as Abandoned + // Collect all pending and running actions under a single lock scope std::vector<Ref<RunnerAction>> PendingToAbandon; + std::vector<Ref<RunnerAction>> RunningToAbandon; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { PendingToAbandon.reserve(m_PendingActions.size()); for (auto& [Lsn, Action] : m_PendingActions) { PendingToAbandon.push_back(Action); } - }); - - for (auto& Action : PendingToAbandon) - { - 
Action->SetActionState(RunnerAction::State::Abandoned); - } - - // Collect all running actions and mark them as Abandoned, then - // best-effort cancel via the local runner group - std::vector<Ref<RunnerAction>> RunningToAbandon; - m_RunningLock.WithSharedLock([&] { RunningToAbandon.reserve(m_RunningMap.size()); for (auto& [Lsn, Action] : m_RunningMap) { @@ -478,6 +474,11 @@ ComputeServiceSession::Impl::AbandonAllActions() } }); + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + for (auto& Action : RunningToAbandon) { Action->SetActionState(RunnerAction::State::Abandoned); @@ -617,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() m_KnownWorkerUris.insert(UriStr); auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + NewRunner->SetRemoteHostname(Hostname); SyncWorkersToRunner(*NewRunner); m_RemoteRunnerGroup.AddRunner(NewRunner); } @@ -718,31 +720,51 @@ ComputeServiceSession::Impl::ShutdownRunners() m_RemoteRunnerGroup.Shutdown(); } -void +bool ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) { + if (m_Recorder) + { + ZEN_WARN("recording is already active"); + return false; + } + ZEN_INFO("starting recording to '{}'", RecordingPath); m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); ZEN_INFO("started recording to '{}'", RecordingPath); + return true; } -void +bool ComputeServiceSession::Impl::StopRecording() { + if (!m_Recorder) + { + ZEN_WARN("no recording is active"); + return false; + } + ZEN_INFO("stopping recording"); m_Recorder = nullptr; ZEN_INFO("stopped recording"); + return true; +} + +bool +ComputeServiceSession::Impl::IsRecording() const +{ + return m_Recorder != nullptr; } std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::Impl::GetRunningActions() { std::vector<ComputeServiceSession::RunningActionInfo> Result; 
- m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { Result.reserve(m_RunningMap.size()); for (const auto& [Lsn, Action] : m_RunningMap) { @@ -810,6 +832,11 @@ void ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) { ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + + // Validate all paths in the worker description upfront, before the worker is + // distributed to runners. This rejects malicious packages early at ingestion time. + ValidateWorkerDescriptionPaths(Worker.GetObject()); + RwLock::ExclusiveLockScope _(m_WorkerLock); const IoHash& WorkerId = Worker.GetObject().GetHash(); @@ -994,10 +1021,15 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke Pending->ActionObj = ActionObj; Pending->Priority = RequestPriority; - // For now simply put action into pending state, so we can do batch scheduling + // Insert into the pending map immediately so the action is visible to + // FindActionResult/GetActionResult right away. SetActionState will call + // PostUpdate which adds the action to m_UpdatedActions and signals the + // scheduler, but the scheduler's HandleActionUpdates inserts with + // std::map::insert which is a no-op for existing keys. 
ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Pending}); }); Pending->SetActionState(RunnerAction::State::Pending); if (m_Recorder) @@ -1043,11 +1075,7 @@ ComputeServiceSession::Impl::GetSubmittedActionCount() HttpResponseCode ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) { @@ -1058,25 +1086,14 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult return HttpResponseCode::OK; } + if (m_PendingActions.find(ActionLsn) != m_PendingActions.end()) { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } return HttpResponseCode::NotFound; @@ -1085,11 +1102,7 @@ ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResult HttpResponseCode ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) { - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - 
// different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); + RwLock::ExclusiveLockScope _(m_ActionMapLock); for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) { @@ -1103,30 +1116,19 @@ ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& } } + for (const auto& [K, Pending] : m_PendingActions) { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) + if (Pending->ActionId == ActionId) { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - + for (const auto& [K, v] : m_RunningMap) { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) + if (v->ActionId == ActionId) { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } + return HttpResponseCode::Accepted; } } @@ -1144,12 +1146,16 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) { Cbo.BeginArray("completed"); - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [Lsn, Action] : m_ResultsMap) { Cbo.BeginObject(); Cbo << "lsn"sv << Lsn; Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + if (!Action->FailureReason.empty()) + { + Cbo << "reason"sv << Action->FailureReason; + } Cbo.EndObject(); } }); @@ -1275,20 +1281,14 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) std::vector<Ref<RunnerAction>> PendingActionsToCancel; std::vector<int> RunningLsnsToCancel; - m_PendingLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (int Lsn : LsnsToCancel) { if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) { PendingActionsToCancel.push_back(It->second); } - } - }); - - m_RunningLock.WithSharedLock([&] { - for (int Lsn : LsnsToCancel) - { - if 
(m_RunningMap.find(Lsn) != m_RunningMap.end()) + else if (m_RunningMap.find(Lsn) != m_RunningMap.end()) { RunningLsnsToCancel.push_back(Lsn); } @@ -1307,7 +1307,7 @@ ComputeServiceSession::Impl::CancelQueue(int QueueId) // transition from the runner is blocked (Cancelled > Failed in the enum). for (int Lsn : RunningLsnsToCancel) { - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) { It->second->SetActionState(RunnerAction::State::Cancelled); @@ -1444,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) if (Queue) { - Queue->m_Lock.WithSharedLock([&] { - m_ResultsLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { + Queue->m_Lock.WithSharedLock([&] { for (int Lsn : Queue->FinishedLsns) { if (m_ResultsMap.contains(Lsn)) @@ -1475,15 +1475,19 @@ ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, Run return; } + bool WasActive = false; Queue->m_Lock.WithExclusiveLock([&] { - Queue->ActiveLsns.erase(Lsn); + WasActive = Queue->ActiveLsns.erase(Lsn) > 0; Queue->FinishedLsns.insert(Lsn); }); - const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); - if (PreviousActive == 1) + if (WasActive) { - Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); + if (PreviousActive == 1) + { + Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + } } switch (ActionState) @@ -1541,26 +1545,32 @@ ComputeServiceSession::Impl::SchedulePendingActions() { ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = 
m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + size_t RunningCount = 0; + size_t PendingCount = 0; + size_t ResultCount = 0; + + m_ActionMapLock.WithSharedLock([&] { + RunningCount = m_RunningMap.size(); + PendingCount = m_PendingActions.size(); + ResultCount = m_ResultsMap.size(); + }); static Stopwatch DumpRunningTimer; auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); + ZEN_DEBUG("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); if (DumpRunningTimer.GetElapsedTimeMs() > 30000) { DumpRunningTimer.Reset(); std::set<int> RunningList; - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { for (auto& [K, V] : m_RunningMap) { RunningList.insert(K); @@ -1602,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions() // Also note that the m_PendingActions list is not maintained // here, that's done periodically in SchedulePendingActions() - m_PendingLock.WithExclusiveLock([&] { - if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) - { - return; - } + // Extract pending actions under a shared lock - we only need to read + // the map and take Ref copies. ActionState() is atomic so this is safe. + // Sorting and capacity trimming happen outside the lock to avoid + // blocking HTTP handlers on O(N log N) work with large pending queues. 
- if (m_PendingActions.empty()) + m_ActionMapLock.WithSharedLock([&] { + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) { return; } @@ -1628,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() case RunnerAction::State::Completed: case RunnerAction::State::Failed: case RunnerAction::State::Abandoned: + case RunnerAction::State::Rejected: case RunnerAction::State::Cancelled: break; @@ -1638,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - // Sort by priority descending, then by LSN ascending (FIFO within same priority) - std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { - if (A->Priority != B->Priority) - { - return A->Priority > B->Priority; - } - return A->ActionLsn < B->ActionLsn; - }); + PendingCount = m_PendingActions.size(); + }); - if (ActionsToSchedule.size() > Capacity) + // Sort by priority descending, then by LSN ascending (FIFO within same priority) + std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { + if (A->Priority != B->Priority) { - ActionsToSchedule.resize(Capacity); + return A->Priority > B->Priority; } - - PendingCount = m_PendingActions.size(); + return A->ActionLsn < B->ActionLsn; }); + if (ActionsToSchedule.size() > Capacity) + { + ActionsToSchedule.resize(Capacity); + } + if (ActionsToSchedule.empty()) { _.Dismiss(); return; } - ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size()); Stopwatch SubmitTimer; std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule); @@ -1681,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - ZEN_INFO("scheduled {} pending actions in {} ({} rejected)", - ScheduledActionCount, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - NotAcceptedCount); + 
ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)", + ScheduledActionCount, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); ScheduledCount += ScheduledActionCount; PendingCount -= ScheduledActionCount; @@ -1701,7 +1712,7 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() { int TimeoutMs = 500; - auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + auto PendingCount = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.size(); }); if (PendingCount) { @@ -1720,22 +1731,22 @@ ComputeServiceSession::Impl::SchedulerThreadFunction() m_SchedulingThreadEvent.Reset(); } - ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", - m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), - PendingCount, - m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), - m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), - TimeoutMs); + m_ActionMapLock.WithSharedLock([&] { + ZEN_DEBUG("compute scheduler TICK (Pending: {}, Running: {}, Results: {}) timeout: {}", + m_PendingActions.size(), + m_RunningMap.size(), + m_ResultsMap.size(), + TimeoutMs); + }); HandleActionUpdates(); - // Auto-transition Draining → Paused when all work is done + // Auto-transition Draining -> Paused when all work is done if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) { - size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + bool AllDrained = m_ActionMapLock.WithSharedLock([&] { return m_PendingActions.empty() && m_RunningMap.empty(); }); - if (Pending == 0 && Running == 0) + if (AllDrained) { SessionState Expected = SessionState::Draining; if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) @@ -1776,9 +1787,9 @@ 
ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) if (Config) { - int Value = Config["max_retries"].AsInt32(0); + int Value = Config["max_retries"].AsInt32(-1); - if (Value > 0) + if (Value >= 0) { return Value; } @@ -1797,7 +1808,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) // Find, validate, and remove atomically under a single lock scope to prevent // concurrent RescheduleAction calls from double-removing the same action. - m_ResultsLock.WithExclusiveLock([&] { + m_ActionMapLock.WithExclusiveLock([&] { auto It = m_ResultsMap.find(ActionLsn); if (It == m_ResultsMap.end()) { @@ -1855,7 +1866,7 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) } } - // Reset action state — this calls PostUpdate() internally + // Reset action state - this calls PostUpdate() internally Action->ResetActionStateToPending(); int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); @@ -1871,26 +1882,20 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) bool WasRunning = false; // Look for the action in pending or running maps - m_RunningLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) { Action = It->second; WasRunning = true; } + else if (auto PIt = m_PendingActions.find(ActionLsn); PIt != m_PendingActions.end()) + { + Action = PIt->second; + } }); if (!Action) { - m_PendingLock.WithSharedLock([&] { - if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) - { - Action = It->second; - } - }); - } - - if (!Action) - { return {.Success = false, .Error = "Action not found in pending or running maps"}; } @@ -1912,18 +1917,15 @@ ComputeServiceSession::Impl::RetractAction(int ActionLsn) void ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) { - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == 
m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + // Caller must hold m_ActionMapLock exclusively. + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } } void @@ -1946,7 +1948,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() // This is safe because state transitions are monotonically increasing by enum // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so // SetActionState rejects any transition to a lower-ranked state. By the time - // we read ActionState() here, it reflects the highest state reached — making + // we read ActionState() here, it reflects the highest state reached - making // the first occurrence per LSN authoritative and duplicates redundant. for (Ref<RunnerAction>& Action : UpdatedActions) { @@ -1956,7 +1958,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() { switch (Action->ActionState()) { - // Newly enqueued — add to pending map for scheduling + // Newly enqueued - add to pending map for scheduling case RunnerAction::State::Pending: // Guard against a race where the session is abandoned between // EnqueueAction (which calls PostUpdate) and this scheduler @@ -1973,35 +1975,44 @@ ComputeServiceSession::Impl::HandleActionUpdates() } else { - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + m_ActionMapLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); } break; - // Async submission in progress — remains in pending map + // Async submission in progress - remains in pending map case RunnerAction::State::Submitting: break; - // Dispatched to a runner — move from pending to running + // Dispatched to a runner - move from pending to running case RunnerAction::State::Running: - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - 
m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); }); ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; - // Retracted — pull back for rescheduling without counting against retry limit + // Retracted - pull back for rescheduling without counting against retry limit case RunnerAction::State::Retracted: { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); Action->ResetActionStateToPending(); ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); break; } - // Terminal states — move to results, record history, notify queue + // Rejected - runner was at capacity, reschedule without retry cost + case RunnerAction::State::Rejected: + { + Action->ResetActionStateToPending(); + ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn); + break; + } + + // Terminal states - move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: case RunnerAction::State::Abandoned: @@ -2010,7 +2021,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() auto TerminalState = Action->ActionState(); // Automatic retry for Failed/Abandoned actions with retries remaining. - // Skip retries when the session itself is abandoned — those actions + // Skip retries when the session itself is abandoned - those actions // were intentionally abandoned and should not be rescheduled. 
if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) && m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned) @@ -2019,7 +2030,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + m_RunningMap.erase(ActionLsn); + m_PendingActions[ActionLsn] = Action; + }); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -2032,18 +2046,26 @@ ComputeServiceSession::Impl::HandleActionUpdates() MaxRetries); break; } + else + { + ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling", + Action->ActionId, + ActionLsn, + RunnerAction::ToString(TerminalState), + Action->RetryCount.load(std::memory_order_relaxed)); + } } - RemoveActionFromActiveMaps(ActionLsn); + m_ActionMapLock.WithExclusiveLock([&] { + RemoveActionFromActiveMaps(ActionLsn); - // Update queue counters BEFORE publishing the result into - // m_ResultsMap. GetActionResult erases from m_ResultsMap - // under m_ResultsLock, so if we updated counters after - // releasing that lock, a caller could observe ActiveCount - // still at 1 immediately after GetActionResult returned OK. - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. GetActionResult erases from m_ResultsMap + // under m_ActionMapLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. 
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); - m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; // Append to bounded action history ring @@ -2124,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); std::vector<SubmitResult> Results(Actions.size()); - // First try submitting the batch to local runners in parallel + // First try submitting the batch to local runners std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); - std::vector<size_t> RemoteIndices; std::vector<Ref<RunnerAction>> RemoteActions; for (size_t i = 0; i < Actions.size(); ++i) @@ -2138,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& } else { - RemoteIndices.push_back(i); RemoteActions.push_back(Actions[i]); + Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"}; } } - // Submit remaining actions to remote runners in parallel + // Dispatch remaining actions to remote runners asynchronously. + // Mark actions as Submitting so the scheduler won't re-pick them. + // The remote runner will transition them to Running on success, or + // we mark them Failed on rejection so HandleActionUpdates retries. 
if (!RemoteActions.empty()) { - std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); - - for (size_t j = 0; j < RemoteIndices.size(); ++j) + for (const Ref<RunnerAction>& Action : RemoteActions) { - Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + Action->SetActionState(RunnerAction::State::Submitting); } + + m_RemoteSubmitPool.ScheduleWork( + [this, RemoteActions = std::move(RemoteActions)]() { + std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteResults.size(); ++j) + { + if (!RemoteResults[j].IsAccepted) + { + ZEN_DEBUG("remote submission rejected for action {} ({}): {}", + RemoteActions[j]->ActionId, + RemoteActions[j]->ActionLsn, + RemoteResults[j].Reason); + + RemoteActions[j]->SetActionState(RunnerAction::State::Rejected); + } + } + }, + WorkerThreadPool::EMode::EnableBacklog); } return Results; @@ -2217,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged() m_Impl->NotifyOrchestratorChanged(); } -void +bool ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) { - m_Impl->StartRecording(InResolver, RecordingPath); + return m_Impl->StartRecording(InResolver, RecordingPath); } -void +bool ComputeServiceSession::StopRecording() { - m_Impl->StopRecording(); + return m_Impl->StopRecording(); +} + +bool +ComputeServiceSession::IsRecording() const +{ + return m_Impl->IsRecording(); } ComputeServiceSession::ActionCounts @@ -2282,6 +2329,18 @@ ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::files } void +ComputeServiceSession::AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddManagedLocalRunner"); + + auto* NewRunner = + new ManagedProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, 
MaxConcurrentActions); + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) { ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index bdfd9d197..5ab189d89 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -21,12 +21,14 @@ # include <zencore/thread.h> # include <zencore/trace.h> # include <zencore/uid.h> -# include <zenstore/cidstore.h> +# include <zenstore/hashkeyset.h> +# include <zenstore/zenstore.h> # include <zentelemetry/stats.h> # include <algorithm> # include <span> # include <unordered_map> +# include <utility> # include <vector> using namespace std::literals; @@ -45,7 +47,9 @@ auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSe struct HttpComputeService::Impl { HttpComputeService* m_Self; - CidStore& m_CidStore; + ChunkStore& m_ActionStore; + ChunkStore& m_WorkerStore; + FallbackChunkResolver m_CombinedResolver; IHttpStatsService& m_StatsService; LoggerRef m_Log; std::filesystem::path m_BaseDir; @@ -58,6 +62,8 @@ struct HttpComputeService::Impl RwLock m_WsConnectionsLock; std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::function<void()> m_ShutdownCallback; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -72,13 +78,13 @@ struct HttpComputeService::Impl std::string ClientHostname; // empty if no hostname was provided }; - // Remote queue registry — all three maps share the same RemoteQueueInfo objects. + // Remote queue registry - all three maps share the same RemoteQueueInfo objects. // All maps are guarded by m_RemoteQueueLock. 
RwLock m_RemoteQueueLock; - std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info - std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info - std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info + std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token -> info + std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId -> info + std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key -> info LoggerRef Log() { return m_Log; } @@ -93,34 +99,38 @@ struct HttpComputeService::Impl uint64_t NewBytes = 0; }; - IngestStats IngestPackageAttachments(const CbPackage& Package); - bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); - void HandleWorkersGet(HttpServerRequest& HttpReq); - void HandleWorkersAllGet(HttpServerRequest& HttpReq); - void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); - void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); - void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + bool IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + bool ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& Attachment); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); // WebSocket / observer - 
void OnWebSocketOpen(Ref<WebSocketConnection> Connection); + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri); void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code); void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions); void RegisterRoutes(); - Impl(HttpComputeService* Self, - CidStore& InCidStore, - IHttpStatsService& StatsService, - const std::filesystem::path& BaseDir, - int32_t MaxConcurrentActions) + Impl(HttpComputeService* Self, + ChunkStore& InActionStore, + ChunkStore& InWorkerStore, + IHttpStatsService& StatsService, + std::filesystem::path BaseDir, + int32_t MaxConcurrentActions) : m_Self(Self) - , m_CidStore(InCidStore) + , m_ActionStore(InActionStore) + , m_WorkerStore(InWorkerStore) + , m_CombinedResolver(InActionStore, InWorkerStore) , m_StatsService(StatsService) , m_Log(logging::Get("compute")) - , m_BaseDir(BaseDir) - , m_ComputeService(InCidStore) + , m_BaseDir(std::move(BaseDir)) + , m_ComputeService(m_CombinedResolver) { - m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.AddLocalRunner(m_CombinedResolver, m_BaseDir / "local", MaxConcurrentActions); m_ComputeService.WaitUntilReady(); m_StatsService.RegisterHandler("compute", *m_Self); RegisterRoutes(); @@ -182,6 +192,65 @@ HttpComputeService::Impl::RegisterRoutes() HttpVerb::kPost); m_Router.RegisterRoute( + "session/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Draining from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + 
HttpVerb::kPost); + + m_Router.RegisterRoute( + "session/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + auto Counts = m_ComputeService.GetActionCounts(); + Cbo << "actions_pending"sv << Counts.Pending; + Cbo << "actions_running"sv << Counts.Running; + Cbo << "actions_completed"sv << Counts.Completed; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "session/sunset", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + + if (m_ShutdownCallback) + { + m_ShutdownCallback(); + } + return; + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Sunset from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "workers", [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, HttpVerb::kGet); @@ -373,7 +442,7 @@ HttpComputeService::Impl::RegisterRoutes() if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); ResponseCode != HttpResponseCode::OK) { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + ZEN_DEBUG("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) if (ResponseCode == HttpResponseCode::NotFound) { @@ -498,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + std::filesystem::path RecordingPath = m_BaseDir / 
"recording"; + + if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath)) + { + CbObjectWriter Cbo; + Cbo << "error" + << "recording is already active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } - return HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -514,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StopRecording(); + std::filesystem::path RecordingPath = m_BaseDir / "recording"; - return HttpReq.WriteResponse(HttpResponseCode::OK); + if (!m_ComputeService.StopRecording()) + { + CbObjectWriter Cbo; + Cbo << "error" + << "no recording is active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -583,7 +672,7 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kGet | HttpVerb::kPost); - // Queue creation routes — these remain separate since local creates a plain queue + // Queue creation routes - these remain separate since local creates a plain queue // while remote additionally generates an OID token for external access. 
m_Router.RegisterRoute( @@ -637,7 +726,7 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } - // Queue has since expired — clean up stale entries and fall through to create a new one + // Queue has since expired - clean up stale entries and fall through to create a new one m_RemoteQueuesByToken.erase(Existing->Token); m_RemoteQueuesByQueueId.erase(Existing->QueueId); m_RemoteQueuesByTag.erase(It); @@ -666,7 +755,7 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kPost); - // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // Unified queue routes - {queueref} accepts both local integer IDs and remote OID tokens. // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. m_Router.RegisterRoute( @@ -1016,7 +1105,7 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kPost); - // WebSocket upgrade endpoint — the handler logic lives in + // WebSocket upgrade endpoint - the handler logic lives in // HttpComputeService::OnWebSocket* methods; this route merely // satisfies the router so the upgrade request isn't rejected. 
m_Router.RegisterRoute( @@ -1027,11 +1116,12 @@ HttpComputeService::Impl::RegisterRoutes() ////////////////////////////////////////////////////////////////////////// -HttpComputeService::HttpComputeService(CidStore& InCidStore, +HttpComputeService::HttpComputeService(ChunkStore& InActionStore, + ChunkStore& InWorkerStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir, int32_t MaxConcurrentActions) -: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions)) +: m_Impl(std::make_unique<Impl>(this, InActionStore, InWorkerStore, StatsService, BaseDir, MaxConcurrentActions)) { } @@ -1057,6 +1147,12 @@ HttpComputeService::GetActionCounts() return m_Impl->m_ComputeService.GetActionCounts(); } +void +HttpComputeService::SetShutdownCallback(std::function<void()> Callback) +{ + m_Impl->m_ShutdownCallback = std::move(Callback); +} + const char* HttpComputeService::BaseUri() const { @@ -1145,7 +1241,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin { if (OidMatcher(Capture)) { - // Remote OID token — accessible from any client + // Remote OID token - accessible from any client const Oid Token = Oid::FromHexString(Capture); const int QueueId = ResolveQueueToken(Token); @@ -1157,7 +1253,7 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin return QueueId; } - // Local integer queue ID — restricted to local machine requests + // Local integer queue ID - restricted to local machine requests if (!HttpReq.IsLocalMachineRequest()) { HttpReq.WriteResponse(HttpResponseCode::Forbidden); @@ -1167,35 +1263,81 @@ HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::strin return ParseInt<int>(Capture).value_or(0); } -HttpComputeService::Impl::IngestStats -HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +bool +HttpComputeService::Impl::ValidateAttachmentHash(HttpServerRequest& HttpReq, const CbAttachment& 
Attachment) { - IngestStats Stats; + const IoHash ClaimedHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + const IoHash HeaderHash = Buffer.DecodeRawHash(); + if (HeaderHash != ClaimedHash) + { + ZEN_WARN("attachment header hash mismatch: claimed {} but header contains {}", ClaimedHash, HeaderHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + IoHashStream Hasher; + + bool DecompressOk = Buffer.DecompressToStream( + 0, + Buffer.DecodeRawSize(), + [&](uint64_t /*SourceOffset*/, uint64_t /*SourceSize*/, uint64_t /*Offset*/, const CompositeBuffer& Range) -> bool { + for (const SharedBuffer& Segment : Range.GetSegments()) + { + Hasher.Append(Segment.GetView()); + } + return true; + }); + + if (!DecompressOk) + { + ZEN_WARN("attachment {}: failed to decompress", ClaimedHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + const IoHash ActualHash = Hasher.GetHash(); + + if (ActualHash != ClaimedHash) + { + ZEN_WARN("attachment hash mismatch: claimed {} but decompressed data hashes to {}", ClaimedHash, ActualHash); + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return false; + } + + return true; +} + +bool +HttpComputeService::Impl::IngestPackageAttachments(HttpServerRequest& HttpReq, const CbPackage& Package, IngestStats& OutStats) +{ for (const CbAttachment& Attachment : Package.GetAttachments()) { ZEN_ASSERT(Attachment.IsCompressedBinary()); - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return false; + } - const uint64_t CompressedSize = DataView.GetCompressedSize(); + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + const uint64_t CompressedSize = DataView.GetCompressedSize(); - Stats.Bytes += CompressedSize; - ++Stats.Count; + OutStats.Bytes += 
CompressedSize; + ++OutStats.Count; - const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + const ChunkStore::InsertResult InsertResult = m_ActionStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); if (InsertResult.New) { - Stats.NewBytes += CompressedSize; - ++Stats.NewCount; + OutStats.NewBytes += CompressedSize; + ++OutStats.NewCount; } } - return Stats; + return true; } bool @@ -1204,7 +1346,7 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto ActionObj.IterateAttachments([&](CbFieldView Field) { const IoHash FileHash = Field.AsHash(); - if (!m_CidStore.ContainsChunk(FileHash)) + if (!m_ActionStore.ContainsChunk(FileHash)) { NeedList.push_back(FileHash); } @@ -1253,7 +1395,10 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { CbPackage Package = HttpReq.ReadPayloadPackage(); Body = Package.GetObject(); - Stats = IngestPackageAttachments(Package); + if (!IngestPackageAttachments(HttpReq, Package, Stats)) + { + return; // validation failed, response already written + } break; } @@ -1268,8 +1413,7 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que { // --- Batch path --- - // For CbObject payloads, check all attachments upfront before enqueuing anything - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + // Verify all action attachment references exist in the store { std::vector<IoHash> NeedList; @@ -1345,7 +1489,6 @@ HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int Que // --- Single-action path: Body is the action itself --- - if (HttpReq.RequestContentType() == HttpContentType::kCbObject) { std::vector<IoHash> NeedList; @@ -1453,7 +1596,7 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const CbPackage WorkerPackage; WorkerPackage.SetObject(WorkerSpec); - m_CidStore.FilterChunks(ChunkSet); + 
m_WorkerStore.FilterChunks(ChunkSet); if (ChunkSet.IsEmpty()) { @@ -1491,15 +1634,19 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const { ZEN_ASSERT(Attachment.IsCompressedBinary()); + if (!ValidateAttachmentHash(HttpReq, Attachment)) + { + return; + } + const IoHash DataHash = Attachment.GetHash(); CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - ZEN_UNUSED(DataHash); TotalAttachmentBytes += Buffer.GetCompressedSize(); ++AttachmentCount; - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + const ChunkStore::InsertResult InsertResult = + m_WorkerStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); if (InsertResult.New) { @@ -1537,9 +1684,9 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const // void -HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { - m_Impl->OnWebSocketOpen(std::move(Connection)); + m_Impl->OnWebSocketOpen(std::move(Connection), RelativeUri); } void @@ -1562,12 +1709,13 @@ HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotificati ////////////////////////////////////////////////////////////////////////// // -// Impl — WebSocket / observer +// Impl - WebSocket / observer // void -HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_INFO("compute WebSocket client connected"); m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); } diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index 6cbe01e04..56eadcd57 100644 --- a/src/zencompute/httporchestrator.cpp +++ 
b/src/zencompute/httporchestrator.cpp @@ -7,6 +7,7 @@ # include <zencompute/orchestratorservice.h> # include <zencore/compactbinarybuilder.h> # include <zencore/logging.h> +# include <zencore/session.h> # include <zencore/string.h> # include <zencore/system.h> @@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn return Ann.Id; } +static OrchestratorService::WorkerAnnotator +MakeWorkerAnnotator(IProvisionerStateProvider* Prov) +{ + if (!Prov) + { + return {}; + } + return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) { + AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId); + if (Status != AgentProvisioningStatus::Unknown) + { + const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active"; + Cbo << "provisioner_status" << std::string_view(StatusStr); + } + }; +} + +bool +HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId) +{ + std::string_view SessionStr = Data["coordinator_session"].AsString(""); + if (SessionStr.empty()) + { + return true; // backwards compatibility: accept announcements without a session + } + Oid Session = Oid::TryFromHexString(SessionStr); + if (Session == m_SessionId) + { + return true; + } + ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString()); + return false; +} + HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) : m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) , m_Hostname(GetMachineName()) { + m_SessionId = zen::GetSessionId(); + ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString()); + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); @@ -95,13 +133,17 @@ 
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, [this](HttpRouterRequest& Req) { CbObjectWriter Cbo; Cbo << "hostname" << std::string_view(m_Hostname); + Cbo << "session_id" << m_SessionId.ToString(); Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kGet); m_Router.RegisterRoute( "provision", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kPost); m_Router.RegisterRoute( @@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, "characters and uri must start with http:// or https://"); } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session"); + } + m_Service->AnnounceWorker(Ann); HttpReq.WriteResponse(HttpResponseCode::OK); @@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, m_Router.RegisterRoute( "agents", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kGet); m_Router.RegisterRoute( @@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, }, HttpVerb::kGet); + // Provisioner endpoints + + m_Router.RegisterRoute( + "provisioner/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + if (IProvisionerStateProvider* Prov = 
m_Provisioner.load(std::memory_order_acquire)) + { + Cbo << "name" << Prov->GetName(); + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount(); + Cbo << "active_cores" << Prov->GetActiveCoreCount(); + Cbo << "agents" << Prov->GetAgentCount(); + Cbo << "agents_draining" << Prov->GetDrainingAgentCount(); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provisioner/target", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + int32_t Cores = Data["target_cores"].AsInt32(-1); + + ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false); + + if (Cores < 0) + { + ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field"); + } + + IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire); + if (!Prov) + { + ZEN_WARN("provisioner/target: no provisioner configured"); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured"); + } + + ZEN_INFO("provisioner/target: setting target to {} cores", Cores); + Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores)); + + CbObjectWriter Cbo; + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + // Client tracking endpoints m_Router.RegisterRoute( @@ -375,7 +478,7 @@ HttpOrchestratorService::Shutdown() m_PushThread.join(); } - // Clean up worker WebSocket connections — collect IDs under lock, then + // Clean up worker WebSocket connections - collect IDs under lock, then // notify the service outside the lock to avoid lock-order inversions. 
std::vector<std::string> WorkerIds; m_WorkerWsLock.WithExclusiveLock([&] { @@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +void +HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); + m_Service->SetProvisionerStateProvider(Provider); +} + ////////////////////////////////////////////////////////////////////////// // // IWebSocketHandler @@ -418,8 +528,9 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) # if ZEN_WITH_WEBSOCKETS void -HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); if (!m_PushEnabled.load()) { return; @@ -471,7 +582,7 @@ std::string HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) { // Workers send CbObject in native binary format over the WebSocket to - // avoid the lossy CbObject↔JSON round-trip. + // avoid the lossy CbObject<->JSON round-trip. 
CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); if (!Data) { @@ -487,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms return {}; } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return {}; + } + m_Service->AnnounceWorker(Ann); return std::string(WorkerId); } @@ -562,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction() } // Build combined JSON with worker list, provisioning history, clients, and client history - CbObject WorkerList = m_Service->GetWorkerList(); + CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))); CbObject History = m_Service->GetProvisioningHistory(50); CbObject ClientList = m_Service->GetClientList(); CbObject ClientHistory = m_Service->GetClientHistory(50); @@ -614,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction() JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); } + // Emit provisioner stats if available + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + JsonBuilder.Append( + fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}" + ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}", + Prov->GetName(), + Prov->GetTargetCoreCount(), + Prov->GetEstimatedCoreCount(), + Prov->GetActiveCoreCount(), + Prov->GetAgentCount(), + Prov->GetDrainingAgentCount())); + } + JsonBuilder.Append("}"); std::string_view Json = JsonBuilder.ToView(); diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h index 3b9642ac3..280d794e7 100644 --- a/src/zencompute/include/zencompute/cloudmetadata.h +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -64,7 +64,7 @@ public: explicit CloudMetadata(std::filesystem::path DataDir); /** Synchronously probes cloud providers at the given IMDS endpoint. 
- * Intended for testing — allows redirecting all IMDS queries to a local + * Intended for testing - allows redirecting all IMDS queries to a local * mock HTTP server instead of the real 169.254.169.254 endpoint. */ CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 1ca78738a..97de4321a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -167,6 +167,7 @@ public: // Action runners void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddManagedLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); // Action submission @@ -278,7 +279,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[9] = {}; + uint64_t Timestamps[10] = {}; }; [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); @@ -304,8 +305,9 @@ public: // Recording - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; private: void PostUpdate(RunnerAction* Action); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index b58e73a0d..32f54f293 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -15,7 +15,7 @@ # include 
<memory> namespace zen { -class CidStore; +class ChunkStore; } namespace zen::compute { @@ -26,7 +26,8 @@ namespace zen::compute { class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver { public: - HttpComputeService(CidStore& InCidStore, + HttpComputeService(ChunkStore& InActionStore, + ChunkStore& InWorkerStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir, int32_t MaxConcurrentActions = 0); @@ -34,6 +35,10 @@ public: void Shutdown(); + /** Set a callback to be invoked when the session/sunset endpoint is hit. + * Typically wired to HttpServer::RequestExit() to shut down the process. */ + void SetShutdownCallback(std::function<void()> Callback); + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); const char* BaseUri() const override; @@ -45,7 +50,7 @@ public: // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index da5c5dfc3..4e4f5f0f8 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,10 +2,12 @@ #pragma once +#include <zencompute/provisionerstate.h> #include <zencompute/zencompute.h> #include <zencore/logging.h> #include <zencore/thread.h> +#include <zencore/uid.h> #include <zenhttp/httpserver.h> #include <zenhttp/websocket.h> @@ -65,12 +67,22 @@ public: */ void Shutdown(); + /** Return the session ID generated at construction time. 
Provisioners + * pass this to spawned workers so the orchestrator can reject stale + * announcements from previous sessions. */ + Oid GetSessionId() const { return m_SessionId; } + + /** Register a provisioner whose target core count can be read and changed + * via the orchestrator HTTP API and dashboard. Caller retains ownership; + * the provider must outlive this service. */ + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; // IWebSocketHandler #if ZEN_WITH_WEBSOCKETS - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; #endif @@ -81,6 +93,11 @@ private: std::unique_ptr<OrchestratorService> m_Service; std::string m_Hostname; + Oid m_SessionId; + bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId); + + std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr}; + // WebSocket push #if ZEN_WITH_WEBSOCKETS @@ -91,9 +108,9 @@ private: Event m_PushEvent; void PushThreadFunction(); - // Worker WebSocket connections (worker→orchestrator persistent links) + // Worker WebSocket connections (worker->orchestrator persistent links) RwLock m_WorkerWsLock; - std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID + std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr -> worker ID std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); #endif }; diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h index 704306913..6074240b9 100644 --- 
a/src/zencompute/include/zencompute/mockimds.h +++ b/src/zencompute/include/zencompute/mockimds.h @@ -1,5 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. -// Moved to zenutil — this header is kept for backward compatibility. +// Moved to zenutil - this header is kept for backward compatibility. #pragma once diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h index 071e902b3..2c49e22df 100644 --- a/src/zencompute/include/zencompute/orchestratorservice.h +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -6,7 +6,10 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include <zencompute/provisionerstate.h> # include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/logbase.h> # include <zencore/thread.h> # include <zencore/timer.h> # include <zencore/uid.h> @@ -88,9 +91,16 @@ public: std::string Hostname; }; - CbObject GetWorkerList(); + /** Per-worker callback invoked during GetWorkerList serialization. + * The callback receives the worker ID and a CbObjectWriter positioned + * inside the worker's object, allowing the caller to append extra fields. 
*/ + using WorkerAnnotator = std::function<void(std::string_view WorkerId, CbObjectWriter& Cbo)>; + + CbObject GetWorkerList(const WorkerAnnotator& Annotate = {}); void AnnounceWorker(const WorkerAnnouncement& Announcement); + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + bool IsWorkerWebSocketEnabled() const; void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); @@ -164,7 +174,12 @@ private: void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); - bool m_EnableWorkerWebSocket = false; + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log{"compute.orchestrator"}; + bool m_EnableWorkerWebSocket = false; + + std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr}; std::thread m_ProbeThread; std::atomic<bool> m_ProbeThreadEnabled{true}; diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h new file mode 100644 index 000000000..e9af8a635 --- /dev/null +++ b/src/zencompute/include/zencompute/provisionerstate.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstdint> +#include <string_view> + +namespace zen::compute { + +/** Per-agent provisioning status as seen by the provisioner. */ +enum class AgentProvisioningStatus +{ + Unknown, ///< Not known to the provisioner + Active, ///< Running and allocated + Draining, ///< Being gracefully deprovisioned +}; + +/** Abstract interface for querying and controlling a provisioner from the HTTP layer. + * This decouples the orchestrator service from specific provisioner implementations. */ +class IProvisionerStateProvider +{ +public: + virtual ~IProvisionerStateProvider() = default; + + virtual std::string_view GetName() const = 0; ///< e.g. 
"horde", "nomad" + virtual uint32_t GetTargetCoreCount() const = 0; + virtual uint32_t GetEstimatedCoreCount() const = 0; + virtual uint32_t GetActiveCoreCount() const = 0; + virtual uint32_t GetAgentCount() const = 0; + virtual uint32_t GetDrainingAgentCount() const { return 0; } + virtual void SetTargetCoreCount(uint32_t Count) = 0; + + /** Return the provisioning status for a worker by its orchestrator ID + * (e.g. "horde-{LeaseId}"). Returns Unknown if the ID is not recognized. */ + virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; } +}; + +} // namespace zen::compute diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp index 9ea695305..68199ab3c 100644 --- a/src/zencompute/orchestratorservice.cpp +++ b/src/zencompute/orchestratorservice.cpp @@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService() } CbObject -OrchestratorService::GetWorkerList() +OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate) { ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); CbObjectWriter Cbo; @@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList() Cbo << "ws_connected" << true; } Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + if (Annotate) + { + Annotate(WorkerId, Cbo); + } Cbo.EndObject(); } }); @@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) } } +void +OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); +} + bool OrchestratorService::IsWorkerWebSocketEnabled() const { @@ -170,11 +180,11 @@ OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool if (Connected) { - ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + ZEN_INFO("worker {} WebSocket connected - marking reachable", WorkerId); } else { - ZEN_WARN("worker {} WebSocket disconnected — marking 
unreachable", WorkerId); + ZEN_WARN("worker {} WebSocket disconnected - marking unreachable", WorkerId); } }); @@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction() continue; } + // Check if the provisioner knows this worker is draining - if so, + // unreachability is expected and should not be logged as a warning. + bool IsDraining = false; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining; + } + ReachableState NewState = ReachableState::Unreachable; try @@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction() } catch (const std::exception& Ex) { - ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + if (!IsDraining) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } } ReachableState PrevState = ReachableState::Unknown; @@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction() { ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); } + else if (IsDraining) + { + ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri); + } else if (PrevState == ReachableState::Reachable) { ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); diff --git a/src/zencompute/pathvalidation.h b/src/zencompute/pathvalidation.h new file mode 100644 index 000000000..d50ad4a2a --- /dev/null +++ b/src/zencompute/pathvalidation.h @@ -0,0 +1,118 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/except_fmt.h> +#include <zencore/string.h> + +#include <filesystem> +#include <string_view> + +namespace zen::compute { + +// Validate that a single path component contains only characters that are valid +// file/directory names on all supported platforms. Uses Windows rules as the most +// restrictive superset, since packages may be built on one platform and consumed +// on another. 
+inline void +ValidatePathComponent(std::string_view Component, std::string_view FullPath) +{ + // Reject control characters (0x00-0x1F) and characters forbidden on Windows + for (char Ch : Component) + { + if (static_cast<unsigned char>(Ch) < 0x20 || Ch == '<' || Ch == '>' || Ch == ':' || Ch == '"' || Ch == '|' || Ch == '?' || + Ch == '*') + { + throw zen::invalid_argument("invalid character in path component '{}' of '{}'", Component, FullPath); + } + } + + // Reject empty components and trailing dots or spaces (silently stripped on Windows, leading to confusion) + if (Component.empty() || Component.back() == '.' || Component.back() == ' ') + { + throw zen::invalid_argument("path component '{}' of '{}' has trailing dot or space", Component, FullPath); + } + + // Reject Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9) + // These are reserved with or without an extension (e.g. "CON.txt" is still reserved). + std::string_view Stem = Component.substr(0, Component.find('.')); + + static constexpr std::string_view ReservedNames[] = { + "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", + "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", + }; + + for (std::string_view Reserved : ReservedNames) + { + if (zen::StrCaseCompare(Stem, Reserved) == 0) + { + throw zen::invalid_argument("path component '{}' of '{}' uses reserved device name '{}'", Component, FullPath, Reserved); + } + } +} + +// Validate that a path extracted from a package is a safe relative path. +// Rejects absolute paths, ".." components, and invalid platform filenames. 
+inline void +ValidateSandboxRelativePath(std::string_view Name) +{ + if (Name.empty()) + { + throw zen::invalid_argument("path traversal detected: empty path name"); + } + + std::filesystem::path Parsed(Name); + + if (Parsed.is_absolute()) + { + throw zen::invalid_argument("path traversal detected: '{}' is an absolute path", Name); + } + + for (const auto& Component : Parsed) + { + std::string ComponentStr = Component.string(); + + if (ComponentStr == "..") + { + throw zen::invalid_argument("path traversal detected: '{}' contains '..' component", Name); + } + + // Skip "." (current directory) - harmless in relative paths + if (ComponentStr != ".") + { + ValidatePathComponent(ComponentStr, Name); + } + } +} + +// Validate all path entries in a worker description CbObject. +// Checks path, executables[].name, dirs[], and files[].name fields. +// Throws an exception if any invalid paths are found. +inline void +ValidateWorkerDescriptionPaths(const CbObject& WorkerDescription) +{ + using namespace std::literals; + + if (auto PathField = WorkerDescription["path"sv]; PathField.HasValue()) + { + ValidateSandboxRelativePath(PathField.AsString()); + } + + for (auto& It : WorkerDescription["executables"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + ValidateSandboxRelativePath(It.AsString()); + } + + for (auto& It : WorkerDescription["files"sv]) + { + ValidateSandboxRelativePath(It.AsObjectView()["name"sv].AsString()); + } +} + +} // namespace zen::compute diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 4f116e7d8..34bf065b4 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -6,9 +6,15 @@ # include <zencore/compactbinary.h> # include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/timer.h> 
# include <zencore/trace.h> +# include <zencore/workthreadpool.h> # include <fmt/format.h> +# include <future> # include <vector> namespace zen::compute { @@ -118,23 +124,34 @@ std::vector<SubmitResult> BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); - RwLock::SharedLockScope _(m_RunnersLock); - const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + // Snapshot runners and query capacity under the lock, then release + // before submitting - HTTP submissions to remote runners can take + // hundreds of milliseconds and we must not hold m_RunnersLock during I/O. - if (RunnerCount == 0) - { - return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); - } + std::vector<Ref<FunctionRunner>> Runners; + std::vector<size_t> Capacities; + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions; + size_t TotalCapacity = 0; - // Query capacity per runner and compute total - std::vector<size_t> Capacities(RunnerCount); - size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + Runners.assign(m_Runners.begin(), m_Runners.end()); + Capacities.resize(RunnerCount); + PerRunnerActions.resize(RunnerCount); - for (int i = 0; i < RunnerCount; ++i) + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + }); + + const int RunnerCount = gsl::narrow<int>(Runners.size()); + + if (RunnerCount == 0) { - Capacities[i] = m_Runners[i]->QueryCapacity(); - TotalCapacity += Capacities[i]; + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); } if (TotalCapacity == 0) @@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } // Distribute actions across runners proportionally to their available 
capacity - std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); - std::vector<size_t> ActionRunnerIndex(Actions.size()); - size_t ActionIdx = 0; + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; for (int i = 0; i < RunnerCount; ++i) { @@ -164,8 +180,9 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Assign any remaining actions to runners with capacity (round-robin) - for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + // Assign any remaining actions to runners with capacity (round-robin). + // Cap at TotalCapacity to avoid spinning when there are more actions than runners can accept. + for (int i = 0; ActionIdx < Actions.size() && ActionIdx < TotalCapacity; i = (i + 1) % RunnerCount) { if (Capacities[i] > PerRunnerActions[i].size()) { @@ -175,22 +192,83 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Submit batches per runner + // Submit batches per runner - in parallel when a worker pool is available + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + int ActiveRunnerCount = 0; for (int i = 0; i < RunnerCount; ++i) { if (!PerRunnerActions[i].empty()) { - PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + ++ActiveRunnerCount; + } + } + + static constexpr uint64_t SubmitWarnThresholdMs = 500; + + auto SubmitToRunner = [&](int RunnerIndex) { + auto& Runner = Runners[RunnerIndex]; + Runner->m_LastSubmitStats.Reset(); + + Stopwatch Timer; + + PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]); + + uint64_t ElapsedMs = Timer.GetElapsedTimeMs(); + if (ElapsedMs >= SubmitWarnThresholdMs) + { + size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed); + uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed); + + ZEN_WARN("submit of {} actions ({} 
attachments, {}) to '{}' took {}ms", + PerRunnerActions[RunnerIndex].size(), + Attachments, + NiceBytes(AttachmentBytes), + Runner->GetDisplayName(), + ElapsedMs); + } + }; + + if (m_WorkerPool && ActiveRunnerCount > 1) + { + std::vector<std::future<void>> Futures(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + std::packaged_task<void()> Task([&SubmitToRunner, i]() { SubmitToRunner(i); }); + + Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog); + } + } + + for (int i = 0; i < RunnerCount; ++i) + { + if (Futures[i].valid()) + { + Futures[i].get(); + } + } + } + else + { + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + SubmitToRunner(i); + } } } - // Reassemble results in original action order - std::vector<SubmitResult> Results(Actions.size()); + // Reassemble results in original action order. + // Actions beyond ActionIdx were not assigned to any runner (insufficient capacity). 
+ std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); std::vector<size_t> PerRunnerIdx(RunnerCount, 0); - for (size_t i = 0; i < Actions.size(); ++i) + for (size_t i = 0; i < ActionIdx; ++i) { size_t RunnerIdx = ActionRunnerIndex[i]; size_t Idx = PerRunnerIdx[RunnerIdx]++; @@ -307,10 +385,11 @@ RunnerAction::RetractAction() bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed, Abandoned, or Retracted states + // Only allow reset from Failed, Abandoned, Rejected, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected && + CurrentState != State::Retracted) { return false; } @@ -331,11 +410,12 @@ RunnerAction::ResetActionStateToPending() // Clear execution fields ExecutionLocation.clear(); + FailureReason.clear(); CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count (skip for Retracted — nothing failed) - if (CurrentState != State::Retracted) + // Increment retry count (skip for Retracted/Rejected - nothing failed) + if (CurrentState != State::Retracted && CurrentState != State::Rejected) { RetryCount.fetch_add(1, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index 56c3f3af0..371a60b7a 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -10,6 +10,10 @@ # include <filesystem> # include <vector> +namespace zen { +class WorkerThreadPool; +} + namespace zen::compute { struct SubmitResult @@ -37,6 +41,22 @@ public: [[nodiscard]] virtual bool IsHealthy() = 0; [[nodiscard]] virtual size_t QueryCapacity(); [[nodiscard]] virtual std::vector<SubmitResult> 
SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; } + + // Accumulated stats from the most recent SubmitActions call. + // Reset before each call, populated by the runner implementation. + struct SubmitStats + { + std::atomic<size_t> TotalAttachments{0}; + std::atomic<uint64_t> TotalAttachmentBytes{0}; + + void Reset() + { + TotalAttachments.store(0, std::memory_order_relaxed); + TotalAttachmentBytes.store(0, std::memory_order_relaxed); + } + }; + SubmitStats m_LastSubmitStats; // Best-effort cancellation of a specific in-flight action. Returns true if the // cancellation signal was successfully sent. The action will transition to Cancelled @@ -68,6 +88,8 @@ public: bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); + void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; } + size_t GetRunnerCount() { return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); @@ -79,6 +101,7 @@ protected: RwLock m_RunnersLock; std::vector<Ref<FunctionRunner>> m_Runners; std::atomic<int> m_NextSubmitIndex{0}; + WorkerThreadPool* m_WorkerPool = nullptr; }; /** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. @@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted CbObject ActionObj; int Priority = 0; std::string ExecutionLocation; // "local" or remote hostname + std::string FailureReason; // human-readable reason when action fails (empty on success) // CPU usage and total CPU time of the running process, sampled periodically by the local runner. // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. @@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted Completed, // Finished successfully with results available Failed, // Execution failed (transient error, eligible for retry) Abandoned, // Infrastructure termination (e.g. 
spot eviction, session abandon) + Rejected, // Runner declined (e.g. at capacity) - rescheduled without retry cost Cancelled, // Intentional user cancellation (never retried) Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count @@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted return "Failed"; case State::Abandoned: return "Abandoned"; + case State::Rejected: + return "Rejected"; case State::Cancelled: return "Cancelled"; case State::Retracted: diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp index e79a6c90f..be4274823 100644 --- a/src/zencompute/runners/linuxrunner.cpp +++ b/src/zencompute/runners/linuxrunner.cpp @@ -195,7 +195,7 @@ namespace { WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno); } - // /lib64 (optional — not all distros have it) + // /lib64 (optional - not all distros have it) { struct stat St; if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode)) @@ -208,7 +208,7 @@ namespace { } } - // /etc (required — for resolv.conf, ld.so.cache, etc.) + // /etc (required - for resolv.conf, ld.so.cache, etc.) 
if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0) { WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno); @@ -218,7 +218,7 @@ namespace { WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno); } - // /worker — bind-mount worker directory (contains the executable) + // /worker - bind-mount worker directory (contains the executable) if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0) { WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno); @@ -331,6 +331,8 @@ LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("namespace sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult @@ -428,11 +430,12 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process - lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { - // Close read end of error pipe — child only writes + // Close read end of error pipe - child only writes close(ErrorPipe[0]); SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]); @@ -459,7 +462,7 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (m_Sandboxed) { - // Close write end of error pipe — parent only reads + // Close write end of error pipe - parent only reads close(ErrorPipe[1]); // Read from error pipe. 
If execve succeeded, pipe was closed by O_CLOEXEC @@ -479,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; @@ -675,7 +679,7 @@ ReadProcStatCpuTicks(pid_t Pid) Buf[Len] = '\0'; - // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + // Skip past "pid (name) " - find last ')' to handle names containing spaces or parens const char* P = strrchr(Buf, ')'); if (!P) { @@ -705,7 +709,7 @@ LinuxProcessRunner::SampleProcessCpu(RunningAction& Running) if (CurrentOsTicks == 0) { - // Process gone or /proc entry unreadable — record timestamp without updating usage + // Process gone or /proc entry unreadable - record timestamp without updating usage Running.LastCpuSampleTicks = NowTicks; Running.LastCpuOsTicks = 0; return; diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index b61e0a46f..259965e23 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -4,6 +4,8 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include "pathvalidation.h" + # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> @@ -104,8 +106,6 @@ LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, ZEN_INFO("Cleanup complete"); } - m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; - # if ZEN_PLATFORM_WINDOWS // Suppress any error dialogs caused by missing dependencies UINT OldMode = ::SetErrorMode(0); @@ -337,7 
+337,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) SubmitResult LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) { - // Base class is not directly usable — platform subclasses override this + // Base class is not directly usable - platform subclasses override this ZEN_UNUSED(Action); return SubmitResult{.IsAccepted = false}; } @@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); - if (!std::filesystem::exists(WorkerDir)) + // worker.zcb is written as the last step of ManifestWorker, so its presence + // indicates a complete manifest. If the directory exists but the marker is + // missing, a previous manifest was interrupted and we need to start over. + bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb"); + + if (NeedsManifest) { _.ReleaseNow(); RwLock::ExclusiveLockScope $(m_WorkerLock); - if (!std::filesystem::exists(WorkerDir)) + if (!std::filesystem::exists(WorkerDir / "worker.zcb")) { + std::error_code Ec; + std::filesystem::remove_all(WorkerDir, Ec); ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); } } @@ -382,6 +389,8 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); const uint64_t Size = FileEntry["size"sv].AsUInt64(); + ValidateSandboxRelativePath(Name); + CompressedBuffer Compressed; if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) @@ -457,7 +466,8 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, for (auto& It : WorkerDescription["dirs"sv]) { - std::string_view Name = It.AsString(); + std::string_view Name = It.AsString(); + ValidateSandboxRelativePath(Name); std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; // Validate dir path stays within sandbox @@ -482,6 +492,8 @@ 
LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, } WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); + + ZEN_INFO("manifested worker '{}' in '{}'", WorkerPackage.GetObjectHash(), SandboxPath); } CbPackage @@ -540,6 +552,12 @@ LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) } void +LocalProcessRunner::StartMonitorThread() +{ + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; +} + +void LocalProcessRunner::MonitorThreadFunction() { SetCurrentThreadName("LocalProcessRunner_Monitor"); @@ -602,7 +620,7 @@ LocalProcessRunner::MonitorThreadFunction() void LocalProcessRunner::CancelRunningActions() { - // Base class is not directly usable — platform subclasses override this + // Base class is not directly usable - platform subclasses override this } void @@ -662,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& Com } catch (std::exception& Ex) { - ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what()); + ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); } } + else + { + Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode); + ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); + } // Failed - clean up the sandbox in the background. 
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index b8cff6826..d6589db43 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -67,6 +67,7 @@ protected: { Ref<RunnerAction> Action; void* ProcessHandle = nullptr; + int Pid = 0; int ExitCode = 0; std::filesystem::path SandboxPath; @@ -83,8 +84,6 @@ protected: std::filesystem::path m_SandboxPath; int32_t m_MaxRunningActions = 64; // arbitrary limit for testing - // if used in conjuction with m_ResultsLock, this lock must be taken *after* - // m_ResultsLock to avoid deadlocks RwLock m_RunningLock; std::unordered_map<int, Ref<RunningAction>> m_RunningMap; @@ -95,6 +94,7 @@ protected: std::thread m_MonitorThread; std::atomic<bool> m_MonitorThreadEnabled{true}; Event m_MonitorThreadEvent; + void StartMonitorThread(); void MonitorThreadFunction(); virtual void SweepRunningActions(); virtual void CancelRunningActions(); diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp index 5cec90699..ab24d4672 100644 --- a/src/zencompute/runners/macrunner.cpp +++ b/src/zencompute/runners/macrunner.cpp @@ -130,6 +130,8 @@ MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, { ZEN_INFO("Seatbelt sandboxing enabled for child processes"); } + + StartMonitorThread(); } SubmitResult @@ -209,18 +211,19 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process - lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { - // Close read end of error pipe — child only writes + // Close read end of error pipe - child only writes close(ErrorPipe[0]); // Apply Seatbelt sandbox profile char* ErrorBuf = nullptr; if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) { - // sandbox_init failed — write error to pipe and exit + // sandbox_init failed - write error to pipe and exit if (ErrorBuf) { WriteErrorAndExit(ErrorPipe[1], 
ErrorBuf, 0); @@ -259,7 +262,7 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (m_Sandboxed) { - // Close write end of error pipe — parent only reads + // Close write end of error pipe - parent only reads close(ErrorPipe[1]); // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC @@ -279,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; @@ -467,7 +471,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running) const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; const uint64_t NowTicks = GetHifreqTimerValue(); - // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + // Cumulative CPU seconds (absolute, available from first sample): ns -> seconds Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) @@ -476,7 +480,7 @@ MacProcessRunner::SampleProcessCpu(RunningAction& Running) if (ElapsedMs > 0) { const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; - // ns → ms: divide by 1,000,000; then as percent of elapsed ms + // ns -> ms: divide by 1,000,000; then as percent of elapsed ms const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/managedrunner.cpp 
b/src/zencompute/runners/managedrunner.cpp new file mode 100644 index 000000000..a4f586852 --- /dev/null +++ b/src/zencompute/runners/managedrunner.cpp @@ -0,0 +1,279 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "managedrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio/io_context.hpp> +# include <asio/executor_work_guard.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +ManagedProcessRunner::ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_IoContext(std::make_unique<asio::io_context>()) +, m_SubprocessManager(std::make_unique<SubprocessManager>(*m_IoContext)) +{ + m_ProcessGroup = m_SubprocessManager->CreateGroup("compute-workers"); + + // Run the io_context on a small thread pool so that exit callbacks and + // metrics sampling are dispatched without blocking each other. 
+ for (int i = 0; i < kIoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i] { + SetCurrentThreadName(fmt::format("mrunner_{}", i)); + + // work_guard keeps run() alive even when there is no pending work yet + auto WorkGuard = asio::make_work_guard(*m_IoContext); + + m_IoContext->run(); + }); + } +} + +ManagedProcessRunner::~ManagedProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during managed process runner shutdown: {}", Ex.what()); + } +} + +void +ManagedProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + CancelRunningActions(); + + // Tear down the SubprocessManager before stopping the io_context so that + // any in-flight callbacks are drained cleanly. + if (m_SubprocessManager) + { + m_SubprocessManager->DestroyGroup("compute-workers"); + m_ProcessGroup = nullptr; + m_SubprocessManager.reset(); + } + + if (m_IoContext) + { + m_IoContext->stop(); + } + + for (std::thread& Thread : m_IoThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + m_IoThreads.clear(); +} + +SubmitResult +ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + // Parse environment variables from worker descriptor ("KEY=VALUE" strings) + // into the key-value pairs expected by CreateProcOptions. 
+ std::vector<std::pair<std::string, std::string>> EnvPairs; + for (auto& It : WorkerDescription["environment"sv]) + { + std::string_view Str = It.AsString(); + size_t Eq = Str.find('='); + if (Eq != std::string_view::npos) + { + EnvPairs.emplace_back(std::string(Str.substr(0, Eq)), std::string(Str.substr(Eq + 1))); + } + } + + // Build command line + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + std::string CommandLine = fmt::format("\"{}\" -Build=build.action"sv, ExePath.string()); + + ZEN_DEBUG("Executing (managed): '{}' (sandbox='{}')", CommandLine, Prepared->SandboxPath); + + CreateProcOptions Options; + Options.WorkingDirectory = &Prepared->SandboxPath; + Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority; + Options.Environment = std::move(EnvPairs); + + const int32_t ActionLsn = Prepared->ActionLsn; + + ManagedProcess* Proc = nullptr; + + try + { + Proc = m_ProcessGroup->Spawn(ExePath, CommandLine, Options, [this, ActionLsn](ManagedProcess& /*Process*/, int ExitCode) { + OnProcessExit(ActionLsn, ExitCode); + }); + } + catch (std::exception& Ex) + { + ZEN_ERROR("Failed to spawn process for action LSN {}: {}", ActionLsn, Ex.what()); + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Prepared->SandboxPath)); + return SubmitResult{.IsAccepted = false}; + } + + { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = static_cast<void*>(Proc); + NewAction->Pid = Proc->Pid(); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + ZEN_DEBUG("Managed runner: action LSN {} -> PID {}", ActionLsn, Proc->Pid()); + + return SubmitResult{.IsAccepted = true}; +} + +void 
+ManagedProcessRunner::OnProcessExit(int ActionLsn, int ExitCode) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::OnProcessExit"); + + Ref<RunningAction> Running; + + m_RunningLock.WithExclusiveLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end()) + { + Running = std::move(It->second); + m_RunningMap.erase(It); + } + }); + + if (!Running) + { + return; + } + + ZEN_DEBUG("Managed runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), ActionLsn, Running->Pid, ExitCode); + + Running->ExitCode = ExitCode; + + // Capture final CPU metrics from the managed process before it is removed. + auto* Proc = static_cast<ManagedProcess*>(Running->ProcessHandle); + if (Proc) + { + ProcessMetrics Metrics = Proc->GetLatestMetrics(); + float CpuMs = static_cast<float>(Metrics.UserTimeMs + Metrics.KernelTimeMs); + Running->Action->CpuSeconds.store(CpuMs / 1000.0f, std::memory_order_relaxed); + + float CpuPct = Proc->GetCpuUsagePercent(); + if (CpuPct >= 0.0f) + { + Running->Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running->ProcessHandle = nullptr; + + std::vector<Ref<RunningAction>> CompletedActions; + CompletedActions.push_back(std::move(Running)); + ProcessCompletedActions(CompletedActions); +} + +void +ManagedProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelRunningActions"); + + std::unordered_map<int, Ref<RunningAction>> RunningMap; + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling {} running actions via process group", RunningMap.size()); + + Stopwatch Timer; + + // Kill all processes in the group atomically (TerminateJobObject on Windows, + // SIGTERM+SIGKILL on POSIX). 
+ if (m_ProcessGroup) + { + m_ProcessGroup->KillAll(); + } + + for (auto& [Lsn, Running] : RunningMap) + { + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +ManagedProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("ManagedProcessRunner::CancelAction"); + + ManagedProcess* Proc = nullptr; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It != m_RunningMap.end() && It->second->ProcessHandle != nullptr) + { + Proc = static_cast<ManagedProcess*>(It->second->ProcessHandle); + } + }); + + if (!Proc) + { + return false; + } + + // Terminate the process. The exit callback will handle the rest + // (remove from running map, gather outputs or mark failed). + Proc->Terminate(222); + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + return true; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/managedrunner.h b/src/zencompute/runners/managedrunner.h new file mode 100644 index 000000000..21a44d43c --- /dev/null +++ b/src/zencompute/runners/managedrunner.h @@ -0,0 +1,64 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenutil/process/subprocessmanager.h> + +# include <memory> + +namespace asio { +class io_context; +} + +namespace zen::compute { + +/** Cross-platform process runner backed by SubprocessManager. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and shared action preparation. Replaces the polling-based + monitor thread with async exit callbacks driven by SubprocessManager, and + delegates CPU/memory metrics sampling to the manager's built-in round-robin + sampler. 
+ + A ProcessGroup (backed by a JobObject on Windows, process group on POSIX) is + used for bulk cancellation on shutdown. + + This runner does not perform any platform-specific sandboxing (AppContainer, + namespaces, Seatbelt). It is intended as a simpler, cross-platform alternative + to the platform-specific runners for non-sandboxed workloads. + */ +class ManagedProcessRunner : public LocalProcessRunner +{ +public: + ManagedProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~ManagedProcessRunner(); + + void Shutdown() override; + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + [[nodiscard]] bool IsHealthy() override { return true; } + +private: + static constexpr int kIoThreadCount = 4; + + // Exit callback posted on an io_context thread. 
+ void OnProcessExit(int ActionLsn, int ExitCode); + + std::unique_ptr<asio::io_context> m_IoContext; + std::unique_ptr<SubprocessManager> m_SubprocessManager; + ProcessGroup* m_ProcessGroup = nullptr; + std::vector<std::thread> m_IoThreads; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index ce6a81173..08f381b7f 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -20,6 +20,7 @@ # include <zenstore/cidstore.h> # include <span> +# include <unordered_set> ////////////////////////////////////////////////////////////////////////// @@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, , m_ChunkResolver{InChunkResolver} , m_WorkerPool{InWorkerPool} , m_HostName{HostName} +, m_DisplayName{HostName} , m_BaseUrl{fmt::format("{}/compute", HostName)} , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) @@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } +void +RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname) +{ + if (!Hostname.empty()) + { + m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname); + } +} + RemoteHttpRunner::~RemoteHttpRunner() { Shutdown(); @@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown() for (auto& [RemoteLsn, HttpAction] : Remaining) { ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->FailureReason = "remote runner shutdown"; HttpAction.Action->SetActionState(RunnerAction::State::Failed); } } @@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity() return 0; } - // Estimate how much more work we're ready to accept + // Estimate how much more work we're ready to accept. 
+ // Include actions currently being submitted over HTTP so we don't + // keep queueing new submissions while previous ones are still in flight. RwLock::SharedLockScope _{m_RunningLock}; - size_t RunningCount = m_RemoteRunningMap.size(); + size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed); if (RunningCount >= size_t(m_MaxRunningActions)) { @@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed); + auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); }); + if (Actions.size() <= 1) { std::vector<SubmitResult> Results; @@ -246,7 +263,7 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) // Collect distinct QueueIds and ensure remote queues exist once per queue - std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero) + std::unordered_map<int, Oid> QueueTokens; // QueueId -> remote token (0 stays as Zero) for (const Ref<RunnerAction>& Action : Actions) { @@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) } } - // Enqueue job. If the remote returns FailedDependency (424), it means it - // cannot resolve the worker/function — re-register the worker and retry once. + // Submit the action to the remote. In eager-attach mode we build a + // CbPackage with all referenced attachments upfront to avoid the 404 + // round-trip. In the default mode we POST the bare object first and + // only upload missing attachments if the remote requests them. + // + // In both modes, FailedDependency (424) triggers a worker re-register + // and a single retry. 
CbObject Result; HttpClient::Response WorkResponse; HttpResponseCode WorkResponseCode{}; - for (int Attempt = 0; Attempt < 2; ++Attempt) - { - WorkResponse = m_Http.Post(SubmitUrl, ActionObj); - WorkResponseCode = WorkResponse.StatusCode; - - if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) - { - ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", - m_Http.GetBaseUri(), - ActionId); - - (void)RegisterWorker(Action->Worker.Descriptor); - } - else - { - break; - } - } - - if (WorkResponseCode == HttpResponseCode::OK) - { - Result = WorkResponse.AsObject(); - } - else if (WorkResponseCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Not all attachments are present - - // Build response package including all required attachments - CbPackage Pkg; Pkg.SetObject(ActionObj); - CbObject Response = WorkResponse.AsObject(); + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash AttachHash = Field.AsHash(); - for (auto& Item : Response["need"sv]) - { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + }); + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, Pkg); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} - 
re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + (void)RegisterWorker(Action->Worker.Descriptor); } else { - // No such attachment - - return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + break; } } + } + else + { + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; - // Post resulting package + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} - re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + (void)RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } - if (!PayloadResponse) + if (WorkResponseCode == HttpResponseCode::NotFound) { - ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + // Remote needs attachments - resolve them and retry with a CbPackage - // TODO: include more information about the failure in the response + CbPackage Pkg; + Pkg.SetObject(ActionObj); - return {.IsAccepted = false, .Reason = "HTTP request failed"}; - } - else if (PayloadResponse.StatusCode == HttpResponseCode::OK) - { - Result = PayloadResponse.AsObject(); - } - else - { - // Unexpected response - - const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - - ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", - ActionId, - m_Http.GetBaseUri(), - SubmitUrl, - ResponseStatusCode, - ToString(ResponseStatusCode)); - - return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}{}", - ResponseStatusCode, - ToString(ResponseStatusCode), - m_Http.GetBaseUri(), - SubmitUrl)}; + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : 
Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + + WorkResponse = std::move(PayloadResponse); + WorkResponseCode = WorkResponse.StatusCode; } } + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (!WorkResponse) + { + ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (!IsHttpSuccessCode(WorkResponseCode)) + { + const int Code = static_cast<int>(WorkResponseCode); + ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code)); + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)}; + } + if (Result) { if (const int32_t LsnField = Result["lsn"].AsInt32(0)) @@ -512,83 +562,111 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec CbObjectWriter Body; Body.BeginArray("actions"sv); + 
std::unordered_set<IoHash, IoHash::Hasher> AttachmentsSeen; + for (const Ref<RunnerAction>& Action : Actions) { Action->ExecutionLocation = m_HostName; MaybeDumpAction(Action->ActionLsn, Action->ActionObj); Body.AddObject(Action->ActionObj); + + if (m_EagerAttach) + { + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); }); + } } Body.EndArray(); - // POST the batch - - HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); - - if (Response.StatusCode == HttpResponseCode::OK) - { - return ParseBatchResponse(Response, Actions); - } + // In eager-attach mode, build a CbPackage with all referenced attachments + // so the remote can accept in a single round-trip. Otherwise POST a bare + // CbObject and handle the 404 need-list flow. - if (Response.StatusCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Server needs attachments — resolve them and retry with a CbPackage - - CbObject NeedObj = Response.AsObject(); - CbPackage Pkg; Pkg.SetObject(Body.Save()); - for (auto& Item : NeedObj["need"sv]) + for (const IoHash& AttachHash : AttachmentsSeen) { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); - - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); - } - else - { - ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); - return FallbackToIndividualSubmit(Actions); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); } } - 
HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg); - if (RetryResponse.StatusCode == HttpResponseCode::OK) + if (Response.StatusCode == HttpResponseCode::OK) { - return ParseBatchResponse(RetryResponse, Actions); + return ParseBatchResponse(Response, Actions); } - - ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", - (int)RetryResponse.StatusCode, - ToString(RetryResponse.StatusCode)); - return FallbackToIndividualSubmit(Actions); - } - - // Unexpected status or connection error — fall back to individual submission - - if (Response) - { - ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", - m_Http.GetBaseUri(), - SubmitUrl, - (int)Response.StatusCode, - ToString(Response.StatusCode)); } else { - ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + + if (Response.StatusCode == HttpResponseCode::NotFound) + { + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + ZEN_WARN("batch submit: missing attachment {} - falling back to 
individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} - falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return FallbackToIndividualSubmit(Actions); + } } + // Unexpected status or connection error - fall back to individual submission + + ZEN_WARN("batch submit to {}{} failed - falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + return FallbackToIndividualSubmit(Actions); } @@ -849,7 +927,7 @@ RemoteHttpRunner::MonitorThreadFunction() SweepOnce(); } - // Signal received — may be a WS wakeup or a quit signal + // Signal received - may be a WS wakeup or a quit signal SweepOnce(); } while (m_MonitorThreadEnabled); @@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions() { for (auto& FieldIt : Completed["completed"sv]) { - CbObjectView EntryObj = FieldIt.AsObjectView(); - const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); - std::string_view StateName = EntryObj["state"sv].AsString(); + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); + std::string_view FailureReason = EntryObj["reason"sv].AsString(); RunnerAction::State RemoteState = RunnerAction::FromString(StateName); @@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions() { HttpRunningAction CompletedAction = std::move(CompleteIt->second); CompletedAction.RemoteState = RemoteState; + CompletedAction.FailureReason = std::string(FailureReason); if (RemoteState == RunnerAction::State::Completed && ResponseJob) { @@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions() { const int ActionLsn = HttpAction.Action->ActionLsn; - 
ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", - HttpAction.Action->ActionId, - ActionLsn, - HttpAction.RemoteActionLsn, - RunnerAction::ToString(HttpAction.RemoteState)); - if (HttpAction.RemoteState == RunnerAction::State::Completed) { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + m_HostName); HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); } + else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned) + { + HttpAction.Action->FailureReason = HttpAction.FailureReason; + if (HttpAction.FailureReason.empty()) + { + ZEN_WARN("action {} ({}) {} on remote {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName); + } + else + { + ZEN_WARN("action {} ({}) {} on remote {}: {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName, + HttpAction.FailureReason); + } + } + else + { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); + } HttpAction.Action->SetActionState(HttpAction.RemoteState); } diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index c17d0cf2a..521bf2f82 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -54,8 +54,10 @@ public: [[nodiscard]] virtual size_t QueryCapacity() override; [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; virtual void CancelRemoteQueue(int QueueId) override; + [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; } std::string_view GetHostName() const { return m_HostName; } + void 
SetRemoteHostname(std::string_view Hostname); protected: LoggerRef Log() { return m_Log; } @@ -65,12 +67,15 @@ private: ChunkResolver& m_ChunkResolver; WorkerThreadPool& m_WorkerPool; std::string m_HostName; + std::string m_DisplayName; std::string m_BaseUrl; HttpClient m_Http; - std::atomic<bool> m_AcceptNewActions{true}; - int32_t m_MaxRunningActions = 256; // arbitrary limit for testing - int32_t m_MaxBatchSize = 50; + std::atomic<bool> m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; + bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry + std::atomic<size_t> m_InFlightSubmissions{0}; // actions currently being submitted over HTTP struct HttpRunningAction { @@ -78,6 +83,7 @@ private: int RemoteActionLsn = 0; // Remote LSN RunnerAction::State RemoteState = RunnerAction::State::Failed; CbPackage ActionResults; + std::string FailureReason; }; RwLock m_RunningLock; @@ -90,7 +96,7 @@ private: size_t SweepRunningActions(); RwLock m_QueueTokenLock; - std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token + std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId -> remote queue token // Stable identity for this runner instance, used as part of the idempotency key when // creating remote queues. Generated once at construction and never changes. diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp index e9a1ae8b6..c6b3e82ea 100644 --- a/src/zencompute/runners/windowsrunner.cpp +++ b/src/zencompute/runners/windowsrunner.cpp @@ -21,6 +21,12 @@ ZEN_THIRD_PARTY_INCLUDES_START # include <sddl.h> ZEN_THIRD_PARTY_INCLUDES_END +// JOB_OBJECT_UILIMIT_ERRORMODE is defined in winuser.h which may be +// excluded by WIN32_LEAN_AND_MEAN. 
+# if !defined(JOB_OBJECT_UILIMIT_ERRORMODE) +# define JOB_OBJECT_UILIMIT_ERRORMODE 0x00000400 +# endif + namespace zen::compute { using namespace std::literals; @@ -34,38 +40,67 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, : LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) , m_Sandboxed(Sandboxed) { - if (!m_Sandboxed) + // Create a job object shared by all child processes. Restricting the + // error-mode UI prevents crash dialogs (WER / Dr. Watson) from + // blocking the monitor thread when a worker process terminates + // abnormally. + m_JobObject = CreateJobObjectW(nullptr, nullptr); + if (m_JobObject) { - return; + JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{}; + ExtLimits.BasicLimitInformation.LimitFlags = + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS; + ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS; + SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits)); + + JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; + UiRestrictions.UIRestrictionsClass = JOB_OBJECT_UILIMIT_ERRORMODE; + SetInformationJobObject(m_JobObject, JobObjectBasicUIRestrictions, &UiRestrictions, sizeof(UiRestrictions)); + + // Set error mode on this process so children inherit it. The + // UILIMIT_ERRORMODE restriction above prevents them from clearing + // SEM_NOGPFAULTERRORBOX. 
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); } - // Build a unique profile name per process to avoid collisions - m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + if (m_Sandboxed) + { + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); - // Clean up any stale profile from a previous crash - DeleteAppContainerProfile(m_AppContainerName.c_str()); + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); - PSID Sid = nullptr; + PSID Sid = nullptr; - HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), - m_AppContainerName.c_str(), // display name - m_AppContainerName.c_str(), // description - nullptr, // no capabilities - 0, // capability count - &Sid); + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); - if (FAILED(Hr)) - { - throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); - } + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } - m_AppContainerSid = Sid; + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + } - ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); + StartMonitorThread(); } WindowsProcessRunner::~WindowsProcessRunner() { + if (m_JobObject) + { + CloseHandle(m_JobObject); + m_JobObject = nullptr; + } + if (m_AppContainerSid) { FreeSid(m_AppContainerSid); @@ -172,9 +207,9 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) LPSECURITY_ATTRIBUTES lpProcessAttributes = 
nullptr; LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = 0; + DWORD dwCreationFlags = CREATE_SUSPENDED | DETACHED_PROCESS; - ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + ZEN_DEBUG("{}: '{}' (sandbox='{}')", m_Sandboxed ? "Sandboxing" : "Executing", WideToUtf8(CommandLine.c_str()), Prepared->SandboxPath); CommandLine.EnsureNulTerminated(); @@ -260,14 +295,21 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) } } - CloseHandle(ProcessInformation.hThread); + if (m_JobObject) + { + AssignProcessToJobObject(m_JobObject, ProcessInformation.hProcess); + } - Ref<RunningAction> NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(Prepared->SandboxPath); + ResumeThread(ProcessInformation.hThread); + CloseHandle(ProcessInformation.hThread); { + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->Pid = ProcessInformation.dwProcessId; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + RwLock::ExclusiveLockScope _(m_RunningLock); m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); @@ -275,6 +317,8 @@ WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) Action->SetActionState(RunnerAction::State::Running); + ZEN_DEBUG("Local runner: action LSN {} -> PID {}", Action->ActionLsn, ProcessInformation.dwProcessId); + return SubmitResult{.IsAccepted = true}; } @@ -294,6 +338,11 @@ WindowsProcessRunner::SweepRunningActions() if (IsSuccess && ExitCode != STILL_ACTIVE) { + ZEN_DEBUG("Local runner: action LSN {} + PID {} exited with code " ZEN_BRIGHT_WHITE("{}"), + Running->Action->ActionLsn, + Running->Pid, + ExitCode); + CloseHandle(Running->ProcessHandle); Running->ProcessHandle = INVALID_HANDLE_VALUE; Running->ExitCode = ExitCode; @@ -436,7 
+485,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime); const uint64_t NowTicks = GetHifreqTimerValue(); - // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds + // Cumulative CPU seconds (absolute, available from first sample): 100ns -> seconds Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed); if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) @@ -445,7 +494,7 @@ WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) if (ElapsedMs > 0) { const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; - // 100ns → ms: divide by 10000; then as percent of elapsed ms + // 100ns -> ms: divide by 10000; then as percent of elapsed ms const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0); Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h index 9f2385cc4..adeaf02fc 100644 --- a/src/zencompute/runners/windowsrunner.h +++ b/src/zencompute/runners/windowsrunner.h @@ -46,6 +46,7 @@ private: bool m_Sandboxed = false; PSID m_AppContainerSid = nullptr; std::wstring m_AppContainerName; + HANDLE m_JobObject = nullptr; }; } // namespace zen::compute diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp index 506bec73b..29ab93663 100644 --- a/src/zencompute/runners/winerunner.cpp +++ b/src/zencompute/runners/winerunner.cpp @@ -36,6 +36,8 @@ WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, sigemptyset(&Action.sa_mask); Action.sa_handler = SIG_DFL; sigaction(SIGCHLD, &Action, nullptr); + + StartMonitorThread(); } SubmitResult @@ -94,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process 
+ // Child process: lower our scheduling priority (nice +5) so worker processes don't starve the main server
+ nice(5);
+
 if (chdir(SandboxPathStr.c_str()) != 0)
 {
 _exit(127);
|