diff options
| author | Stefan Boberg <[email protected]> | 2026-03-18 11:19:10 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-18 11:19:10 +0100 |
| commit | eba410c4168e23d7908827eb34b7cf0c58a5dc48 (patch) | |
| tree | 3cda8e8f3f81941d3bb5b84a8155350c5bb2068c /src/zencompute | |
| parent | bugfix release - v5.7.23 (#851) (diff) | |
| download | zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.tar.xz zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.zip | |
Compute batching (#849)
### Compute Batch Submission
- Consolidate duplicated action submission logic in `httpcomputeservice` into a single `HandleSubmitAction` supporting both single-action and batch (actions array) payloads
- Group actions by queue in `RemoteHttpRunner` and submit as batches with configurable chunk size, falling back to individual submission on failure
- Extract shared helpers: `MakeErrorResult`, `ValidateQueueForEnqueue`, `ActivateActionInQueue`, `RemoveActionFromActiveMaps`
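The group-by-queue, chunk, and fall-back strategy above can be sketched as follows. This is a minimal illustration with hypothetical `Action` and callback types — the real `RemoteHttpRunner` works on `RunnerAction` refs and HTTP submission, and its types differ:

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <vector>

// Hypothetical stand-in for illustration; the real RunnerAction in
// zencompute is a ref-counted struct with many more fields.
struct Action { int QueueId; int Lsn; };

// Group actions by queue, submit each group in chunks of MaxBatchSize,
// and fall back to per-action submission when a batch submit fails.
// Returns the number of actions that ultimately failed to submit.
int SubmitGrouped(const std::vector<Action>& Actions,
                  size_t MaxBatchSize,
                  const std::function<bool(const std::vector<Action>&)>& SubmitBatch,
                  const std::function<bool(const Action&)>& SubmitOne)
{
    std::map<int, std::vector<Action>> ByQueue;
    for (const Action& A : Actions)
        ByQueue[A.QueueId].push_back(A);

    int Failed = 0;
    for (auto& [QueueId, Group] : ByQueue)
    {
        for (size_t Offset = 0; Offset < Group.size(); Offset += MaxBatchSize)
        {
            const size_t End = std::min(Offset + MaxBatchSize, Group.size());
            std::vector<Action> Chunk(Group.begin() + Offset, Group.begin() + End);
            if (SubmitBatch(Chunk))
                continue;
            // Batch failed: retry each action individually so one bad
            // action does not sink the whole chunk.
            for (const Action& A : Chunk)
                if (!SubmitOne(A))
                    ++Failed;
        }
    }
    return Failed;
}
```

The per-action fallback keeps a single malformed action from failing an entire chunk, at the cost of extra round-trips only on the error path.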
### Retracted Action State
- Add `Retracted` state to `RunnerAction` for retry-free rescheduling — an explicit request to pull an action back and reschedule it on a different runner without incrementing `RetryCount`
- Implement idempotent `RetractAction()` on `RunnerAction` and `ComputeServiceSession`
- Add `POST jobs/{lsn}/retract` and `queues/{queueref}/jobs/{lsn}/retract` HTTP endpoints
- Add state machine documentation and per-state comments to `RunnerAction`
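The idempotent-retract semantics can be sketched with a CAS loop over an atomic state. The enum below is a reduced, hypothetical version of the `RunnerAction` state machine (the real one has more states); what matters is that retract succeeds from active states, succeeds again harmlessly when already retracted, and refuses once the action is terminal:

```cpp
#include <atomic>
#include <cassert>

// Illustrative reduced state ordering; Retracted carries the highest
// ordinal so forward-only transition checks cannot override it.
enum class State : int { Pending, Submitting, Running, Completed, Failed, Retracted };

struct ActionSketch
{
    std::atomic<State> Current{State::Pending};

    // Idempotent retract: succeeds from Pending/Submitting/Running,
    // reports success again if already Retracted, and refuses once the
    // action has reached a terminal result state.
    bool Retract()
    {
        State Observed = Current.load(std::memory_order_acquire);
        for (;;)
        {
            if (Observed == State::Retracted)
                return true; // already retracted: idempotent success
            if (Observed == State::Completed || Observed == State::Failed)
                return false; // terminal, nothing to pull back
            if (Current.compare_exchange_weak(Observed, State::Retracted,
                                              std::memory_order_acq_rel))
                return true;
            // CAS lost a race; Observed now holds the fresh state, loop.
        }
    }
};
```

Because the transition is a single CAS, a retract racing with a runner-side completion resolves deterministically: whichever state lands first wins, and the loser's transition is rejected.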
### Compute Race Fixes
- Fix race in `HandleActionUpdates` where actions enqueued between session abandon and scheduler tick were never abandoned, causing `GetActionResult` to return 202 indefinitely
- Fix queue `ActiveCount` race where `NotifyQueueActionComplete` was called after releasing `m_ResultsLock`, allowing callers to observe stale counters immediately after `GetActionResult` returned OK
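The `ActiveCount` ordering fix reduces to one invariant: update the queue counter *before* publishing the result under the lock that readers take, so "result visible" implies "counter updated". A minimal sketch with hypothetical names (the real code uses `m_ResultsLock`, `RwLock`, and richer result types):

```cpp
#include <atomic>
#include <cassert>
#include <map>
#include <mutex>

struct ResultBoard
{
    std::atomic<int> ActiveCount{0};
    std::mutex ResultsLock;
    std::map<int, int> Results; // lsn -> exit code

    void Publish(int Lsn, int ExitCode)
    {
        // Fixed order: decrement the counter first, THEN publish the
        // result under ResultsLock. A reader that finds the result can
        // no longer observe the stale pre-completion counter.
        ActiveCount.fetch_sub(1, std::memory_order_release);
        std::lock_guard<std::mutex> Guard(ResultsLock);
        Results[Lsn] = ExitCode;
    }

    bool TryGetResult(int Lsn, int& Out)
    {
        std::lock_guard<std::mutex> Guard(ResultsLock);
        auto It = Results.find(Lsn);
        if (It == Results.end())
            return false;
        Out = It->second;
        return true;
    }
};
```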
### Logging Optimization and ANSI Improvements
- Improve `AnsiColorStdoutSink` write efficiency — single write call, dirty-flag flush, `RwLock` instead of `std::mutex`
- Move ANSI color emission from sink into formatters via `Formatter::SetColorEnabled()`; remove `ColorRangeStart`/`End` from `LogMessage`
- Extract color helpers (`AnsiColorForLevel`, `StripAnsiSgrSequences`) into `helpers.h`
- Strip upstream ANSI SGR escapes in non-color output mode, enabling color in log messages without polluting log files with ANSI control sequences
- Move `RotatingFileSink`, `JsonFormatter`, and `FullFormatter` from header-only to pimpl with `.cpp` files
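SGR stripping follows directly from the sequence grammar: ESC `[`, a run of parameter bytes in the range 0x30–0x3F, terminated by `m`. A standalone sketch of the idea (the actual `StripAnsiSgrSequences` signature in `helpers.h` may differ):

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Remove ANSI SGR sequences (ESC '[' <params> 'm') from a string while
// preserving all other bytes, including non-SGR CSI sequences.
std::string StripAnsiSgr(std::string_view In)
{
    std::string Out;
    Out.reserve(In.size());
    for (size_t i = 0; i < In.size(); ++i)
    {
        if (In[i] == '\x1b' && i + 1 < In.size() && In[i + 1] == '[')
        {
            size_t j = i + 2;
            while (j < In.size() && In[j] >= 0x30 && In[j] <= 0x3F)
                ++j; // consume parameter bytes (digits, ';', etc.)
            if (j < In.size() && In[j] == 'm')
            {
                i = j; // skip the whole SGR sequence
                continue;
            }
            // Not an SGR terminator: fall through and keep the bytes.
        }
        Out.push_back(In[i]);
    }
    return Out;
}
```

Only `m`-terminated (color/attribute) sequences are removed; other CSI sequences such as cursor movement pass through untouched, which keeps the helper safe to run over arbitrary upstream output.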
### CLI / Exec Refactoring
- Extract `ExecSessionRunner` class from ~920-line `ExecUsingSession` into focused methods and an `ExecSessionConfig` struct
- Replace monolithic `ExecCommand` with subcommand-based architecture (`http`, `inproc`, `beacon`, `dump`, `buildlog`)
- Allow parent options to appear after subcommand name by parsing subcommand args permissively and forwarding unmatched tokens to the parent parser
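The permissive-parse idea above amounts to partitioning tokens: the subcommand consumes only the options it recognizes and forwards everything else to the parent parser. A minimal sketch with hypothetical names (the real CLI plumbing is more involved):

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Split argv-style tokens between a subcommand and its parent parser.
// Options the subcommand does not recognize are forwarded unmatched,
// so parent options may legally appear after the subcommand name.
void SplitArgs(const std::vector<std::string>& Args,
               const std::set<std::string>& SubcommandOptions,
               std::vector<std::string>& OutSub,
               std::vector<std::string>& OutParent)
{
    for (const std::string& Arg : Args)
    {
        // Match on the option name only, ignoring any "=value" suffix.
        const std::string Name = Arg.substr(0, Arg.find('='));
        if (SubcommandOptions.count(Name))
            OutSub.push_back(Arg);
        else
            OutParent.push_back(Arg); // unmatched -> parent parser
    }
}
```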
### Testing Improvements
- Fix `--test-suite` filter being ignored due to accumulation with default wildcard filter
- Add test suite banners to test listener output
- Make the `function.session.abandon_pending` test more robust
### Startup / Reliability Fixes
- Fix silent exit when a second zenserver instance detects a port conflict — use `ZEN_CONSOLE_*` for log calls that precede `InitializeLogging()`
- Fix two potential SIGSEGV paths during early startup: guard `sentry_options_new()` returning nullptr, and throw on `ZenServerState::Register()` returning nullptr instead of dereferencing
- Fail on unrecognized zenserver `--mode` instead of silently defaulting to store
### Other
- Show host details (hostname, platform, CPU count, memory) when discovering new compute workers
- Move frontend `html.zip` from source tree into build directory
- Add format specifications for Compact Binary and Compressed Buffer wire formats
- Add `WriteCompactBinaryObject` to zencore
- Extend `ConsoleTui` with additional functionality
- Add `--vscode` option to `xmake sln` for clangd / `compile_commands.json` support
- Disable compute/horde/nomad in release builds (not yet production-ready)
- Disable unintended `ASIO_HAS_IO_URING` enablement
- Fix crashpad patch missing leading whitespace
- Clean up code triggering gcc false positives
Diffstat (limited to 'src/zencompute')
| -rw-r--r-- | src/zencompute/CLAUDE.md | 39 | ||||
| -rw-r--r-- | src/zencompute/computeservice.cpp | 384 | ||||
| -rw-r--r-- | src/zencompute/httpcomputeservice.cpp | 760 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/computeservice.h | 69 | ||||
| -rw-r--r-- | src/zencompute/include/zencompute/httpcomputeservice.h | 13 | ||||
| -rw-r--r-- | src/zencompute/runners/functionrunner.cpp | 44 | ||||
| -rw-r--r-- | src/zencompute/runners/functionrunner.h | 53 | ||||
| -rw-r--r-- | src/zencompute/runners/localrunner.cpp | 10 | ||||
| -rw-r--r-- | src/zencompute/runners/localrunner.h | 2 | ||||
| -rw-r--r-- | src/zencompute/runners/remotehttprunner.cpp | 399 | ||||
| -rw-r--r-- | src/zencompute/runners/remotehttprunner.h | 23 |
11 files changed, 1282 insertions, 514 deletions
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index f5188123f..a1a39fc3c 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -46,9 +46,12 @@ Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns: - Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` - Queue map: `m_Queues` (QueueEntry objects) - Action history ring: `m_ActionHistory` (bounded deque, default 1000) +- WebSocket client (`m_OrchestratorWsClient`) subscribed to the orchestrator's `/orch/ws` push for instant worker discovery **Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. +**Convenience helpers:** `Ready()`, `Abandon()`, `SetOrchestrator(Endpoint, BasePath)` are inline wrappers for common state transitions and orchestrator configuration. + ### `RunnerAction` (runners/functionrunner.h) Shared ref-counted struct representing one action through its lifecycle. @@ -67,8 +70,11 @@ New → Pending → Submitting → Running → Completed → Failed → Abandoned → Cancelled + → Retracted ``` -`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. +`SetActionState()` rejects non-forward transitions (Retracted has the highest ordinal so runner-side transitions cannot override it). 
`ResetActionStateToPending()` uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling — it clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +**Retracted state:** An explicit, instigator-initiated request to pull an action back and reschedule it on a different runner (e.g. capacity opened up elsewhere). Unlike Failed/Abandoned auto-retry, rescheduling from Retracted does not increment `RetryCount` since nothing went wrong. Retraction is idempotent and can target Pending, Submitting, or Running actions. ### `LocalProcessRunner` (runners/localrunner.h) Base for all local execution. Platform runners subclass this and override: @@ -90,10 +96,29 @@ Base for all local execution. Platform runners subclass this and override: - macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 ### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) -Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. `SubmitActions()` supports batch submission — actions are grouped and forwarded in chunks. + +### `RemoteHttpRunner` (runners/remotehttprunner.h) +Submits actions to remote zenserver instances over HTTP. Key features: +- **WebSocket completion notifications**: connects a WS client to `/compute/ws` on the remote. When a message arrives (action completed), the monitor thread wakes immediately instead of polling. Falls back to adaptive polling (200ms→50ms) when WS is unavailable. 
+- **Batch submission**: groups actions by queue and submits in configurable chunks (`m_MaxBatchSize`, default 50), falling back to individual submission on failure. +- **Queue cancellation**: `CancelRemoteQueue()` sends cancel requests to the remote. +- **Graceful shutdown**: `Shutdown()` closes the WS client, cancels all remote queues, stops the monitor thread, and marks remaining actions as Failed. ### `HttpComputeService` (include/zencompute/httpcomputeservice.h) -Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. Supports both single-action and batch (actions array) payloads via a shared `HandleSubmitAction` helper. + +## Orchestrator Discovery + +`ComputeServiceSession` discovers remote workers via the orchestrator endpoint (`SetOrchestratorEndpoint()`). Two complementary mechanisms: + +1. **Polling** (`UpdateCoordinatorState`): `GET /orch/agents` on the scheduler thread, throttled to every 5s (500ms when no workers are known yet). Discovers new workers and removes stale/unreachable ones. + +2. **WebSocket push** (`OrchestratorWsHandler`): connects to `/orch/ws` on the orchestrator at setup time. When the orchestrator broadcasts a state change, the handler sets `m_OrchestratorQueryForced` and signals the scheduler event, bypassing the polling throttle. Falls back silently to polling if the WS connection fails. + +`NotifyOrchestratorChanged()` is the public API to trigger an immediate re-query — useful in tests and for external notification sources. + +Use `HttpToWsUrl(Endpoint, Path)` from `zenhttp/httpwsclient.h` to convert HTTP(S) endpoints to WebSocket URLs. This helper is shared across all WS client setup sites in the codebase. 
## Action Lifecycle (End to End) @@ -118,6 +143,8 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. + **Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. ## Queue System @@ -161,8 +188,9 @@ All routes registered in `HttpComputeService` constructor. Prefix is configured | GET | `jobs/running` | In-flight actions with CPU metrics | | GET | `jobs/completed` | Actions with results available | | GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{lsn}/retract` | Retract a pending/running action for rescheduling (idempotent) | | POST | `jobs/{worker}` | Submit action for specific worker | -| POST | `jobs` | Submit action (worker resolved from descriptor) | +| POST | `jobs` | Submit action (or batch via `actions` array) | | GET | `workers` | List worker IDs | | GET | `workers/all` | All workers with full descriptors | | GET/POST | `workers/{worker}` | Get/register worker | @@ -179,8 +207,9 @@ Queue ref is capture(1) in all `queues/{queueref}/...` routes. 
| GET | `queues/{queueref}/completed` | Queue's completed results | | GET | `queues/{queueref}/history` | Queue's action history | | GET | `queues/{queueref}/running` | Queue's running actions | -| POST | `queues/{queueref}/jobs` | Submit to queue | +| POST | `queues/{queueref}/jobs` | Submit to queue (or batch via `actions` array) | | GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| POST | `queues/{queueref}/jobs/{lsn}/retract` | Retract action for rescheduling | | GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 838d741b6..92901de64 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -33,6 +33,7 @@ # include <zenutil/workerpools.h> # include <zentelemetry/stats.h> # include <zenhttp/httpclient.h> +# include <zenhttp/httpwsclient.h> # include <set> # include <deque> @@ -42,6 +43,7 @@ # include <unordered_set> ZEN_THIRD_PARTY_INCLUDES_START +# include <EASTL/fixed_vector.h> # include <EASTL/hash_set.h> ZEN_THIRD_PARTY_INCLUDES_END @@ -95,6 +97,14 @@ using SessionState = ComputeServiceSession::SessionState; static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count)); +static ComputeServiceSession::EnqueueResult +MakeErrorResult(std::string_view Error) +{ + CbObjectWriter Writer; + Writer << "error"sv << Error; + return {0, Writer.Save()}; +} + ////////////////////////////////////////////////////////////////////////// struct ComputeServiceSession::Impl @@ -130,14 +140,40 @@ struct ComputeServiceSession::Impl void SetOrchestratorEndpoint(std::string_view Endpoint); void SetOrchestratorBasePath(std::filesystem::path BasePath); + void 
NotifyOrchestratorChanged(); std::string m_OrchestratorEndpoint; std::filesystem::path m_OrchestratorBasePath; Stopwatch m_OrchestratorQueryTimer; + std::atomic<bool> m_OrchestratorQueryForced{false}; std::unordered_set<std::string> m_KnownWorkerUris; void UpdateCoordinatorState(); + // WebSocket subscription to orchestrator push notifications + struct OrchestratorWsHandler : public IWsClientHandler + { + Impl& Owner; + + explicit OrchestratorWsHandler(Impl& InOwner) : Owner(InOwner) {} + + void OnWsOpen() override + { + ZEN_LOG_INFO(Owner.m_Log, "orchestrator WebSocket connected"); + Owner.NotifyOrchestratorChanged(); + } + + void OnWsMessage(const WebSocketMessage&) override { Owner.NotifyOrchestratorChanged(); } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + ZEN_LOG_WARN(Owner.m_Log, "orchestrator WebSocket closed (code {}: {})", Code, Reason); + } + }; + + std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler; + std::unique_ptr<HttpWsClient> m_OrchestratorWsClient; + // Worker registration and discovery struct FunctionDefinition @@ -157,6 +193,8 @@ struct ComputeServiceSession::Impl std::atomic<int32_t> m_ActionsCounter = 0; // sequence number metrics::Meter m_ArrivalRate; + std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; + RwLock m_PendingLock; std::map<int, Ref<RunnerAction>> m_PendingActions; @@ -267,6 +305,8 @@ struct ComputeServiceSession::Impl void DrainQueue(int QueueId); ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + ComputeServiceSession::EnqueueResult ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue); + void ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn); void GetQueueCompleted(int QueueId, CbWriter& Cbo); void NotifyQueueActionComplete(int QueueId, int Lsn, 
RunnerAction::State ActionState); void ExpireCompletedQueues(); @@ -292,11 +332,13 @@ struct ComputeServiceSession::Impl void HandleActionUpdates(); void PostUpdate(RunnerAction* Action); + void RemoveActionFromActiveMaps(int ActionLsn); static constexpr int kDefaultMaxRetries = 3; int GetMaxRetriesForQueue(int QueueId); ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + ComputeServiceSession::RescheduleResult RetractAction(int ActionLsn); ActionCounts GetActionCounts() { @@ -449,6 +491,28 @@ void ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) { m_OrchestratorEndpoint = Endpoint; + + // Subscribe to orchestrator WebSocket push so we discover worker changes + // immediately instead of waiting for the next polling cycle. + try + { + std::string WsUrl = HttpToWsUrl(Endpoint, "/orch/ws"); + + m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this); + + HttpWsClientSettings WsSettings; + WsSettings.LogCategory = "orch_disc_ws"; + WsSettings.ConnectTimeout = std::chrono::milliseconds{3000}; + + m_OrchestratorWsClient = std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, WsSettings); + m_OrchestratorWsClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to connect orchestrator WebSocket, falling back to polling: {}", Ex.what()); + m_OrchestratorWsClient.reset(); + m_OrchestratorWsHandler.reset(); + } } void @@ -458,6 +522,13 @@ ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BaseP } void +ComputeServiceSession::Impl::NotifyOrchestratorChanged() +{ + m_OrchestratorQueryForced.store(true, std::memory_order_relaxed); + m_SchedulingThreadEvent.Set(); +} + +void ComputeServiceSession::Impl::UpdateCoordinatorState() { ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); @@ -467,10 +538,14 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() } // Poll faster when we have no discovered workers yet so remote runners come online 
quickly - const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; - if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + const bool Forced = m_OrchestratorQueryForced.exchange(false, std::memory_order_relaxed); + if (!Forced) { - return; + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } } m_OrchestratorQueryTimer.Reset(); @@ -520,7 +595,24 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() continue; } - ZEN_INFO("discovered new worker at {}", UriStr); + std::string_view Hostname = Worker["hostname"sv].AsString(); + std::string_view Platform = Worker["platform"sv].AsString(); + int Cpus = Worker["cpus"sv].AsInt32(); + uint64_t MemTotal = Worker["memory_total"sv].AsUInt64(); + + if (!Hostname.empty()) + { + ZEN_INFO("discovered new worker at {} ({}, {}, {} cpus, {:.1f} GB)", + UriStr, + Hostname, + Platform, + Cpus, + static_cast<double>(MemTotal) / (1024.0 * 1024.0 * 1024.0)); + } + else + { + ZEN_INFO("discovered new worker at {}", UriStr); + } m_KnownWorkerUris.insert(UriStr); @@ -598,6 +690,15 @@ ComputeServiceSession::Impl::Shutdown() { RequestStateTransition(SessionState::Sunset); + // Close orchestrator WebSocket before stopping the scheduler thread + // to prevent callbacks into a shutting-down scheduler. + if (m_OrchestratorWsClient) + { + m_OrchestratorWsClient->Close(); + m_OrchestratorWsClient.reset(); + } + m_OrchestratorWsHandler.reset(); + m_SchedulingThreadEnabled = false; m_SchedulingThreadEvent.Set(); if (m_SchedulingThread.joinable()) @@ -720,8 +821,14 @@ ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) // different descriptor. 
Thus we only need to call this the first time, when the // worker is added - m_LocalRunnerGroup.RegisterWorker(Worker); - m_RemoteRunnerGroup.RegisterWorker(Worker); + if (!m_LocalRunnerGroup.RegisterWorker(Worker)) + { + ZEN_WARN("failed to register worker {} on one or more local runners", WorkerId); + } + if (!m_RemoteRunnerGroup.RegisterWorker(Worker)) + { + ZEN_WARN("failed to register worker {} on one or more remote runners", WorkerId); + } if (m_Recorder) { @@ -767,7 +874,10 @@ ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) for (const CbPackage& Worker : Workers) { - Runner.RegisterWorker(Worker); + if (!Runner.RegisterWorker(Worker)) + { + ZEN_WARN("failed to sync worker {} to runner", Worker.GetObjectHash()); + } } } @@ -868,9 +978,7 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) { - CbObjectWriter Writer; - Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); - return {0, Writer.Save()}; + return MakeErrorResult(fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()))); } const int ActionLsn = ++m_ActionsCounter; @@ -1258,42 +1366,51 @@ ComputeServiceSession::Impl::DrainQueue(int QueueId) } ComputeServiceSession::EnqueueResult -ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +ComputeServiceSession::Impl::ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue) { - Ref<QueueEntry> Queue = FindQueue(QueueId); + OutQueue = FindQueue(QueueId); - if (!Queue) + if (!OutQueue) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue not found"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue not found"sv); } - QueueState QState = Queue->State.load(); + QueueState QState = OutQueue->State.load(); if (QState == QueueState::Cancelled) { - CbObjectWriter 
Writer; - Writer << "error"sv - << "queue is cancelled"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue is cancelled"sv); } if (QState == QueueState::Draining) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is draining"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue is draining"sv); + } + + return {}; +} + +void +ComputeServiceSession::Impl::ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn) +{ + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref<QueueEntry> Queue; + if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage) + { + return Error; } EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority); if (Result.Lsn != 0) { - Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); - Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); - Queue->IdleSince.store(0, std::memory_order_relaxed); + ActivateActionInQueue(Queue, Result.Lsn); } return Result; @@ -1302,40 +1419,17 @@ ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionOb ComputeServiceSession::EnqueueResult ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority) { - Ref<QueueEntry> Queue = FindQueue(QueueId); - - if (!Queue) - { - CbObjectWriter Writer; - Writer << "error"sv - << "queue not found"sv; - return {0, Writer.Save()}; - } - - QueueState QState = Queue->State.load(); - if (QState == QueueState::Cancelled) + Ref<QueueEntry> Queue; + if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is 
cancelled"sv; - return {0, Writer.Save()}; - } - - if (QState == QueueState::Draining) - { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is draining"sv; - return {0, Writer.Save()}; + return Error; } EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority); if (Result.Lsn != 0) { - Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); - Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); - Queue->IdleSince.store(0, std::memory_order_relaxed); + ActivateActionInQueue(Queue, Result.Lsn); } return Result; @@ -1770,6 +1864,68 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) return {.Success = true, .RetryCount = NewRetryCount}; } +ComputeServiceSession::RescheduleResult +ComputeServiceSession::Impl::RetractAction(int ActionLsn) +{ + Ref<RunnerAction> Action; + bool WasRunning = false; + + // Look for the action in pending or running maps + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) + { + Action = It->second; + WasRunning = true; + } + }); + + if (!Action) + { + m_PendingLock.WithSharedLock([&] { + if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) + { + Action = It->second; + } + }); + } + + if (!Action) + { + return {.Success = false, .Error = "Action not found in pending or running maps"}; + } + + if (!Action->RetractAction()) + { + return {.Success = false, .Error = "Action cannot be retracted from its current state"}; + } + + // If the action was running, send a cancellation signal to the runner + if (WasRunning) + { + m_LocalRunnerGroup.CancelAction(ActionLsn); + } + + ZEN_INFO("action {} ({}) retract requested", Action->ActionId, ActionLsn); + return {.Success = true, .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; +} + +void +ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) +{ + m_RunningLock.WithExclusiveLock([&] { + 
m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); +} + void ComputeServiceSession::Impl::HandleActionUpdates() { @@ -1781,6 +1937,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() std::unordered_set<int> SeenLsn; + // Collect terminal action notifications for the completion observer. + // Inline capacity of 64 avoids heap allocation in the common case. + eastl::fixed_vector<IComputeCompletionObserver::CompletedActionNotification, 64> TerminalBatch; + // Process each action's latest state, deduplicating by LSN. // // This is safe because state transitions are monotonically increasing by enum @@ -1798,7 +1958,23 @@ ComputeServiceSession::Impl::HandleActionUpdates() { // Newly enqueued — add to pending map for scheduling case RunnerAction::State::Pending: - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + // Guard against a race where the session is abandoned between + // EnqueueAction (which calls PostUpdate) and this scheduler + // tick. AbandonAllActions() only scans m_PendingActions, so it + // misses actions still in m_UpdatedActions at the time the + // session transitions. Detect that here and immediately abandon + // rather than inserting into the pending map, where they would + // otherwise be stuck indefinitely. + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Abandoned) + { + Action->SetActionState(RunnerAction::State::Abandoned); + // SetActionState calls PostUpdate; the Abandoned action + // will be processed as a terminal on the next scheduler pass. 
+ } + else + { + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + } break; // Async submission in progress — remains in pending map @@ -1816,6 +1992,15 @@ ComputeServiceSession::Impl::HandleActionUpdates() ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; + // Retracted — pull back for rescheduling without counting against retry limit + case RunnerAction::State::Retracted: + { + RemoveActionFromActiveMaps(ActionLsn); + Action->ResetActionStateToPending(); + ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); + break; + } + // Terminal states — move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: @@ -1834,19 +2019,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - // Remove from whichever active map the action is in before resetting - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + RemoveActionFromActiveMaps(ActionLsn); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -1861,19 +2034,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() } } - // Remove from whichever active map the action is in - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + RemoveActionFromActiveMaps(ActionLsn); + + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. 
GetActionResult erases from m_ResultsMap + // under m_ResultsLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; @@ -1902,16 +2070,46 @@ ComputeServiceSession::Impl::HandleActionUpdates() }); m_RetiredCount.fetch_add(1); m_ResultRate.Mark(1); + { + using ObserverState = IComputeCompletionObserver::ActionState; + ObserverState NotifyState{}; + switch (TerminalState) + { + case RunnerAction::State::Completed: + NotifyState = ObserverState::Completed; + break; + case RunnerAction::State::Failed: + NotifyState = ObserverState::Failed; + break; + case RunnerAction::State::Abandoned: + NotifyState = ObserverState::Abandoned; + break; + case RunnerAction::State::Cancelled: + NotifyState = ObserverState::Cancelled; + break; + default: + break; + } + TerminalBatch.push_back({ActionLsn, NotifyState}); + } ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", Action->ActionId, ActionLsn, TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); break; } } } } + + // Notify the completion observer, if any, about all terminal actions in this batch. 
+ if (!TerminalBatch.empty()) + { + if (IComputeCompletionObserver* Observer = m_CompletionObserver.load(std::memory_order_acquire)) + { + Observer->OnActionsCompleted({TerminalBatch.data(), TerminalBatch.size()}); + } + } } size_t @@ -2014,6 +2212,12 @@ ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) } void +ComputeServiceSession::NotifyOrchestratorChanged() +{ + m_Impl->NotifyOrchestratorChanged(); +} + +void ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) { m_Impl->StartRecording(InResolver, RecordingPath); @@ -2194,6 +2398,12 @@ ComputeServiceSession::RescheduleAction(int ActionLsn) return m_Impl->RescheduleAction(ActionLsn); } +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RetractAction(int ActionLsn) +{ + return m_Impl->RetractAction(ActionLsn); +} + std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::GetRunningActions() { @@ -2219,6 +2429,12 @@ ComputeServiceSession::GetCompleted(CbWriter& Cbo) } void +ComputeServiceSession::SetCompletionObserver(IComputeCompletionObserver* Observer) +{ + m_Impl->m_CompletionObserver.store(Observer, std::memory_order_release); +} + +void ComputeServiceSession::PostUpdate(RunnerAction* Action) { m_Impl->PostUpdate(Action); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index e82a40781..bdfd9d197 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -16,6 +16,7 @@ # include <zencore/iobuffer.h> # include <zencore/iohash.h> # include <zencore/logging.h> +# include <zencore/string.h> # include <zencore/system.h> # include <zencore/thread.h> # include <zencore/trace.h> @@ -23,8 +24,10 @@ # include <zenstore/cidstore.h> # include <zentelemetry/stats.h> +# include <algorithm> # include <span> # include <unordered_map> +# include <vector> using namespace std::literals; @@ -50,6 +53,11 @@ struct 
HttpComputeService::Impl ComputeServiceSession m_ComputeService; SystemMetricsTracker m_MetricsTracker; + // WebSocket connections (completion push) + + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -91,6 +99,12 @@ struct HttpComputeService::Impl void HandleWorkersAllGet(HttpServerRequest& HttpReq); void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + + // WebSocket / observer + void OnWebSocketOpen(Ref<WebSocketConnection> Connection); + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code); + void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions); void RegisterRoutes(); @@ -110,6 +124,7 @@ struct HttpComputeService::Impl m_ComputeService.WaitUntilReady(); m_StatsService.RegisterHandler("compute", *m_Self); RegisterRoutes(); + m_ComputeService.SetCompletionObserver(m_Self); } }; @@ -149,7 +164,7 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + bool Success = m_ComputeService.Abandon(); if (Success) { @@ -325,6 +340,29 @@ HttpComputeService::Impl::RegisterRoutes() HttpVerb::kGet | HttpVerb::kPost); m_Router.RegisterRoute( + "jobs/{lsn}/retract", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + auto Result = m_ComputeService.RetractAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "success"sv << true; + Cbo << "lsn"sv << ActionLsn; + 
HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the // one which uses the scheduled action lsn for lookups [this](HttpRouterRequest& Req) { @@ -373,127 +411,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - // TODO: return status of all pending or executing jobs - break; - - case HttpVerb::kPost: - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - break; - - case 
HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - break; - - default: - break; - } - break; - - default: - break; - } + HandleSubmitAction(HttpReq, 0, RequestPriority, &Worker); }, HttpVerb::kPost); @@ -511,118 +429,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - // Resolve worker - - // - - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. 
This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - return; - } + HandleSubmitAction(HttpReq, 0, RequestPriority, nullptr); }, HttpVerb::kPost); @@ -1090,72 +897,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - if (!CheckAttachments(ActionObj, NeedList)) - { - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - IngestStats Stats = IngestPackageAttachments(Action); - - if (ComputeServiceSession::EnqueueResult Result 
= - m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - QueueId, - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(Stats.Bytes), - Stats.Count, - zen::NiceBytes(Stats.NewBytes), - Stats.NewCount); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - default: - break; - } + HandleSubmitAction(HttpReq, QueueId, RequestPriority, &Worker); }, HttpVerb::kPost); @@ -1178,71 +920,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - if (!CheckAttachments(ActionObj, NeedList)) - { - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - IngestStats Stats = IngestPackageAttachments(Action); - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) - { - 
ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - QueueId, - Result.Lsn, - zen::NiceBytes(Stats.Bytes), - Stats.Count, - zen::NiceBytes(Stats.NewBytes), - Stats.NewCount); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - default: - break; - } + HandleSubmitAction(HttpReq, QueueId, RequestPriority, nullptr); }, HttpVerb::kPost); @@ -1306,6 +984,45 @@ HttpComputeService::Impl::RegisterRoutes() } }, HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}/retract", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RetractAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "success"sv << true; + Cbo << "lsn"sv << ActionLsn; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + }, + HttpVerb::kPost); + + // WebSocket upgrade endpoint — the handler logic lives in + // HttpComputeService::OnWebSocket* methods; this route merely + // satisfies the router so the upgrade request isn't rejected. 
+ m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); } ////////////////////////////////////////////////////////////////////////// @@ -1320,12 +1037,17 @@ HttpComputeService::HttpComputeService(CidStore& InCidStore, HttpComputeService::~HttpComputeService() { + m_Impl->m_ComputeService.SetCompletionObserver(nullptr); m_Impl->m_StatsService.UnregisterHandler("compute", *this); } void HttpComputeService::Shutdown() { + // Null out observer before shutting down the compute session to prevent + // callbacks into a partially-torn-down service. + m_Impl->m_ComputeService.SetCompletionObserver(nullptr); + m_Impl->m_WsConnectionsLock.WithExclusiveLock([&] { m_Impl->m_WsConnections.clear(); }); m_Impl->m_ComputeService.Shutdown(); } @@ -1492,6 +1214,184 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto } void +HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker) +{ + // QueueId > 0: queue-scoped enqueue; QueueId == 0: implicit queue (global routes) + auto Enqueue = [&](CbObject ActionObj) -> ComputeServiceSession::EnqueueResult { + if (QueueId > 0) + { + if (Worker) + { + return m_ComputeService.EnqueueResolvedActionToQueue(QueueId, *Worker, ActionObj, Priority); + } + return m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, Priority); + } + else + { + if (Worker) + { + return m_ComputeService.EnqueueResolvedAction(*Worker, ActionObj, Priority); + } + return m_ComputeService.EnqueueAction(ActionObj, Priority); + } + }; + + // Read payload upfront and handle attachments based on content type + CbObject Body; + IngestStats Stats = {}; + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + Body = LoadCompactBinaryObject(Payload); + break; + } + + case HttpContentType::kCbPackage: + { + CbPackage 
Package = HttpReq.ReadPayloadPackage(); + Body = Package.GetObject(); + Stats = IngestPackageAttachments(Package); + break; + } + + default: + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + // Check for "actions" array to determine batch vs single-action path + CbArray Actions = Body.Find("actions"sv).AsArray(); + + if (Actions.Num() > 0) + { + // --- Batch path --- + + // For CbObject payloads, check all attachments upfront before enqueuing anything + if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + { + std::vector<IoHash> NeedList; + + for (CbField ActionField : Actions) + { + CbObject ActionObj = ActionField.AsObject(); + CheckAttachments(ActionObj, NeedList); + } + + if (!NeedList.empty()) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + } + + // Enqueue all actions and collect results + CbObjectWriter Cbo; + int Accepted = 0; + + Cbo.BeginArray("results"); + + for (CbField ActionField : Actions) + { + CbObject ActionObj = ActionField.AsObject(); + + ComputeServiceSession::EnqueueResult Result = Enqueue(ActionObj); + + Cbo.BeginObject(); + + if (Result) + { + Cbo << "lsn"sv << Result.Lsn; + ++Accepted; + } + else + { + Cbo << "error"sv << Result.ResponseMessage; + } + + Cbo.EndObject(); + } + + Cbo.EndArray(); + + if (Stats.Count > 0) + { + ZEN_DEBUG("queue {}: batch accepted {}/{} actions: {} in {} attachments. 
{} new ({} attachments)", + QueueId, + Accepted, + Actions.Num(), + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + } + else + { + ZEN_DEBUG("queue {}: batch accepted {}/{} actions", QueueId, Accepted, Actions.Num()); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // --- Single-action path: Body is the action itself --- + + if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + { + std::vector<IoHash> NeedList; + + if (!CheckAttachments(Body, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + } + + if (ComputeServiceSession::EnqueueResult Result = Enqueue(Body)) + { + if (Stats.Count > 0) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Body.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + } + else + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, Body.GetHash(), Result.Lsn); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } +} + +void HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) { CbObjectWriter Cbo; @@ -1632,6 +1532,136 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const } ////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + m_Impl->OnWebSocketOpen(std::move(Connection)); +} + +void +HttpComputeService::OnWebSocketMessage([[maybe_unused]] WebSocketConnection& Conn, [[maybe_unused]] const WebSocketMessage& 
Msg) +{ + // Clients are receive-only; ignore any inbound messages. +} + +void +HttpComputeService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + m_Impl->OnWebSocketClose(Conn, Code); +} + +void +HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotification> Actions) +{ + m_Impl->OnActionsCompleted(Actions); +} + +////////////////////////////////////////////////////////////////////////// +// +// Impl — WebSocket / observer +// + +void +HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + ZEN_INFO("compute WebSocket client connected"); + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); +} + +void +HttpComputeService::Impl::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code) +{ + ZEN_INFO("compute WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +void +HttpComputeService::Impl::OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions) +{ + using ActionState = IComputeCompletionObserver::ActionState; + using CompletedActionNotification = IComputeCompletionObserver::CompletedActionNotification; + + // Snapshot connections under shared lock + eastl::fixed_vector<Ref<WebSocketConnection>, 16> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = {begin(m_WsConnections), end(m_WsConnections)}; }); + + if (Connections.empty()) + { + return; + } + + // Build CompactBinary notification grouped by state: + // {"Completed": [lsn, ...], "Failed": [lsn, ...], ...} + // Each state name becomes an array key containing the LSNs in that state. 
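The sort-then-single-pass grouping described above can be sketched in isolation. This is a hypothetical standalone version (plain `std::vector` stands in for the CompactBinary writer; the open/close-group logic mirrors the `ArrayOpen` flag):

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Illustrative types, not from the source.
enum class ActionState { Completed, Failed, Abandoned, Cancelled };

struct Notification
{
    int         Lsn;
    ActionState State;
};

// Sorting by state first means equal states are adjacent, so one forward
// pass can emit one group per state with no auxiliary lookup structure —
// the same shape as the BeginArray/EndArray pairing in the writer.
std::vector<std::pair<ActionState, std::vector<int>>> GroupByState(std::vector<Notification> Batch)
{
    std::sort(Batch.begin(), Batch.end(), [](const Notification& A, const Notification& B) { return A.State < B.State; });

    std::vector<std::pair<ActionState, std::vector<int>>> Groups;

    for (const Notification& N : Batch)
    {
        if (Groups.empty() || Groups.back().first != N.State)
        {
            // A state change closes the previous group and opens a new one.
            Groups.push_back({N.State, {}});
        }
        Groups.back().second.push_back(N.Lsn);
    }
    return Groups;
}
```

Note `std::sort` is not stable, so LSN order within one state's group is unspecified; that is fine here because the consumer only cares which LSNs landed in which state.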
+ CbObjectWriter Cbo; + + // Sort by state so we can emit one array per state in a single pass. + // Copy into a local vector since the span is const. + eastl::fixed_vector<CompletedActionNotification, 16> Sorted(Actions.begin(), Actions.end()); + std::sort(Sorted.begin(), Sorted.end(), [](const auto& A, const auto& B) { return A.State < B.State; }); + + ActionState CurrentState{}; + bool ArrayOpen = false; + + for (const CompletedActionNotification& Action : Sorted) + { + if (!ArrayOpen || Action.State != CurrentState) + { + if (ArrayOpen) + { + Cbo.EndArray(); + } + CurrentState = Action.State; + Cbo.BeginArray(IComputeCompletionObserver::ActionStateToString(CurrentState)); + ArrayOpen = true; + } + Cbo.AddInteger(Action.Lsn); + } + + if (ArrayOpen) + { + Cbo.EndArray(); + } + + CbObject Msg = Cbo.Save(); + MemoryView MsgView = Msg.GetView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendBinary(MsgView); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } +} + +////////////////////////////////////////////////////////////////////////// void httpcomputeservice_forcelink() diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 65ec5f9ee..1ca78738a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -13,6 +13,7 @@ # include <zenhttp/httpcommon.h> # include <filesystem> +# include <span> namespace zen { class ChunkResolver; @@ -29,6 +30,53 @@ class RemoteHttpRunner; struct RunnerAction; struct SubmitResult; +/** + * 
Observer interface for action completion notifications. + * + * Implementors receive a batch of notifications whenever actions reach a + * terminal state (Completed, Failed, Abandoned, Cancelled). The callback + * fires on the scheduler thread *after* the action result has been placed + * in m_ResultsMap, so GET /jobs/{lsn} will succeed by the time the client + * reacts to the notification. + */ +class IComputeCompletionObserver +{ +public: + virtual ~IComputeCompletionObserver() = default; + + enum class ActionState + { + Completed, + Failed, + Abandoned, + Cancelled, + }; + + struct CompletedActionNotification + { + int Lsn; + ActionState State; + }; + + static constexpr std::string_view ActionStateToString(ActionState State) + { + switch (State) + { + case ActionState::Completed: + return "Completed"; + case ActionState::Failed: + return "Failed"; + case ActionState::Abandoned: + return "Abandoned"; + case ActionState::Cancelled: + return "Cancelled"; + } + return "Unknown"; + } + + virtual void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) = 0; +}; + struct WorkerDesc { CbPackage Descriptor; @@ -91,11 +139,25 @@ public: // Sunset can be reached from any non-Sunset state. bool RequestStateTransition(SessionState NewState); + // Convenience helpers for common state transitions. + bool Ready() { return RequestStateTransition(SessionState::Ready); } + bool Abandon() { return RequestStateTransition(SessionState::Abandoned); } + // Orchestration void SetOrchestratorEndpoint(std::string_view Endpoint); void SetOrchestratorBasePath(std::filesystem::path BasePath); + void SetOrchestrator(std::string_view Endpoint, std::filesystem::path BasePath) + { + SetOrchestratorEndpoint(Endpoint); + SetOrchestratorBasePath(std::move(BasePath)); + } + + /// Immediately wake the scheduler to re-poll the orchestrator for worker changes. + /// Resets the polling throttle so the next scheduler tick calls UpdateCoordinatorState(). 
+ void NotifyOrchestratorChanged(); + // Worker registration and discovery void RegisterWorker(CbPackage Worker); @@ -182,6 +244,7 @@ public: }; [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + [[nodiscard]] RescheduleResult RetractAction(int ActionLsn); void GetCompleted(CbWriter&); @@ -215,7 +278,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[8] = {}; + uint64_t Timestamps[9] = {}; }; [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); @@ -235,6 +298,10 @@ public: void EmitStats(CbObjectWriter& Cbo); + // Completion observer (used by HttpComputeService for WebSocket push) + + void SetCompletionObserver(IComputeCompletionObserver* Observer); + // Recording void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index ee1cd2614..b58e73a0d 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -9,6 +9,7 @@ # include "zencompute/computeservice.h" # include <zenhttp/httpserver.h> +# include <zenhttp/websocket.h> # include <filesystem> # include <memory> @@ -22,7 +23,7 @@ namespace zen::compute { /** * HTTP interface for compute service */ -class HttpComputeService : public HttpService, public IHttpStatsProvider +class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver { public: HttpComputeService(CidStore& InCidStore, @@ -42,6 +43,16 @@ public: void HandleStatsRequest(HttpServerRequest& Request) override; + // IWebSocketHandler + + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& 
Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + + // IComputeCompletionObserver + + void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) override; + private: struct Impl; std::unique_ptr<Impl> m_Impl; diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 768cdf1e1..4f116e7d8 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -215,15 +215,22 @@ BaseRunnerGroup::GetSubmittedActionCount() return TotalCount; } -void +bool BaseRunnerGroup::RegisterWorker(CbPackage Worker) { RwLock::SharedLockScope _(m_RunnersLock); + bool AllSucceeded = true; + for (auto& Runner : m_Runners) { - Runner->RegisterWorker(Worker); + if (!Runner->RegisterWorker(Worker)) + { + AllSucceeded = false; + } } + + return AllSucceeded; } void @@ -276,12 +283,34 @@ RunnerAction::~RunnerAction() } bool +RunnerAction::RetractAction() +{ + State CurrentState = m_ActionState.load(); + + do + { + // Only allow retraction from pre-terminal states (idempotent if already retracted) + if (CurrentState > State::Running) + { + return CurrentState == State::Retracted; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, State::Retracted)) + { + this->Timestamps[static_cast<int>(State::Retracted)] = DateTime::Now().GetTicks(); + m_OwnerSession->PostUpdate(this); + return true; + } + } while (true); +} + +bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed or Abandoned states + // Only allow reset from Failed, Abandoned, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) { return false; } @@ -305,8 +334,11 @@ RunnerAction::ResetActionStateToPending() 
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count - RetryCount.fetch_add(1, std::memory_order_relaxed); + // Increment retry count (skip for Retracted — nothing failed) + if (CurrentState != State::Retracted) + { + RetryCount.fetch_add(1, std::memory_order_relaxed); + } // Re-enter the scheduler pipeline m_OwnerSession->PostUpdate(this); diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index f67414dbb..56c3f3af0 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -29,8 +29,8 @@ public: FunctionRunner(std::filesystem::path BasePath); virtual ~FunctionRunner() = 0; - virtual void Shutdown() = 0; - virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + virtual void Shutdown() = 0; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) = 0; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; @@ -63,7 +63,7 @@ public: SubmitResult SubmitAction(Ref<RunnerAction> Action); std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); size_t GetSubmittedActionCount(); - void RegisterWorker(CbPackage Worker); + [[nodiscard]] bool RegisterWorker(CbPackage Worker); void Shutdown(); bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); @@ -114,6 +114,30 @@ struct RunnerGroup : public BaseRunnerGroup /** * This represents an action going through different stages of scheduling and execution. 
+ * + * State machine + * ============= + * + * Normal forward flow (enforced by SetActionState rejecting backward transitions): + * + * New -> Pending -> Submitting -> Running -> Completed + * -> Failed + * -> Abandoned + * -> Cancelled + * + * Rescheduling (via ResetActionStateToPending): + * + * Failed ---> Pending (increments RetryCount, subject to retry limit) + * Abandoned ---> Pending (increments RetryCount, subject to retry limit) + * Retracted ---> Pending (does NOT increment RetryCount) + * + * Retraction (via RetractAction, idempotent): + * + * Pending/Submitting/Running -> Retracted -> Pending (rescheduled) + * + * Retracted is placed after Cancelled in enum order so that once set, + * no runner-side transition (Completed/Failed) can override it via + * SetActionState's forward-only rule. */ struct RunnerAction : public RefCounted { @@ -137,16 +161,20 @@ struct RunnerAction : public RefCounted enum class State { - New, - Pending, - Submitting, - Running, - Completed, - Failed, - Abandoned, - Cancelled, + New, // Initial state at construction, before entering the scheduler + Pending, // Queued and waiting for a runner slot + Submitting, // Being handed off to a runner (async submission in progress) + Running, // Executing on a runner process + Completed, // Finished successfully with results available + Failed, // Execution failed (transient error, eligible for retry) + Abandoned, // Infrastructure termination (e.g. 
spot eviction, session abandon) + Cancelled, // Intentional user cancellation (never retried) + Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count }; + static_assert(State::Retracted > State::Completed && State::Retracted > State::Failed && State::Retracted > State::Abandoned && + State::Retracted > State::Cancelled, + "Retracted must be the highest terminal ordinal so runner-side transitions cannot override it"); static const char* ToString(State _) { @@ -168,6 +196,8 @@ struct RunnerAction : public RefCounted return "Abandoned"; case State::Cancelled: return "Cancelled"; + case State::Retracted: + return "Retracted"; default: return "Unknown"; } @@ -191,6 +221,7 @@ struct RunnerAction : public RefCounted void SetActionState(State NewState); bool IsSuccess() const { return ActionState() == State::Completed; } + bool RetractAction(); bool ResetActionStateToPending(); bool IsCompleted() const { diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 7aaefb06e..b61e0a46f 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -7,14 +7,16 @@ # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> +# include <zencore/compactbinaryfile.h> # include <zencore/compress.h> # include <zencore/except_fmt.h> # include <zencore/filesystem.h> # include <zencore/fmtutils.h> # include <zencore/iobuffer.h> # include <zencore/iohash.h> -# include <zencore/system.h> # include <zencore/scopeguard.h> +# include <zencore/stream.h> +# include <zencore/system.h> # include <zencore/timer.h> # include <zencore/trace.h> # include <zenstore/cidstore.h> @@ -152,7 +154,7 @@ LocalProcessRunner::CreateNewSandbox() return Path; } -void +bool LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) { ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); @@ -173,6 +175,8 @@ 
LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); } + + return true; } size_t @@ -301,7 +305,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) // Write out action - zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + WriteCompactBinaryObject(SandboxPath / "build.action", ActionObj); // Manifest inputs in sandbox diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index 7493e980b..b8cff6826 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -51,7 +51,7 @@ public: ~LocalProcessRunner(); virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; [[nodiscard]] virtual bool IsHealthy() override { return true; } [[nodiscard]] virtual size_t GetSubmittedActionCount() override; diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index 672636d06..ce6a81173 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -42,6 +42,20 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) { + // Attempt to connect a WebSocket for push-based completion notifications. + // If the remote doesn't support WS, OnWsClose fires and we fall back to polling. 
+ { + std::string WsUrl = HttpToWsUrl(HostName, "/compute/ws"); + + HttpWsClientSettings WsSettings; + WsSettings.LogCategory = "http_exec_ws"; + WsSettings.ConnectTimeout = std::chrono::milliseconds{3000}; + + IWsClientHandler& Handler = *this; + m_WsClient = std::make_unique<HttpWsClient>(WsUrl, Handler, WsSettings); + m_WsClient->Connect(); + } + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } @@ -53,7 +67,29 @@ RemoteHttpRunner::~RemoteHttpRunner() void RemoteHttpRunner::Shutdown() { - // TODO: should cleanly drain/cancel pending work + m_AcceptNewActions = false; + + // Close the WebSocket client first, so no more wakeup signals arrive. + if (m_WsClient) + { + m_WsClient->Close(); + } + + // Cancel all known remote queues so the remote side stops scheduling new + // work and cancels in-flight actions belonging to those queues. + + { + std::vector<std::pair<int, Oid>> Queues; + + m_QueueTokenLock.WithSharedLock([&] { Queues.assign(m_RemoteQueueTokens.begin(), m_RemoteQueueTokens.end()); }); + + for (const auto& [QueueId, Token] : Queues) + { + CancelRemoteQueue(QueueId); + } + } + + // Stop the monitor thread so it no longer polls the remote. m_MonitorThreadEnabled = false; m_MonitorThreadEvent.Set(); @@ -61,9 +97,22 @@ RemoteHttpRunner::Shutdown() { m_MonitorThread.join(); } + + // Drain the running map and mark all remaining actions as Failed so the + // scheduler can reschedule or finalize them. 
+ + std::unordered_map<int, HttpRunningAction> Remaining; + + m_RunningLock.WithExclusiveLock([&] { Remaining.swap(m_RemoteRunningMap); }); + + for (auto& [RemoteLsn, HttpAction] : Remaining) + { + ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->SetActionState(RunnerAction::State::Failed); + } } -void +bool RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) { ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); @@ -125,15 +174,13 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) { ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error + return false; } } else if (!IsHttpSuccessCode(DescResponse.StatusCode)) { ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error + return false; } else { @@ -152,14 +199,20 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) WorkerUrl, (int)WorkerResponse.StatusCode, ToString(WorkerResponse.StatusCode)); - - // TODO: propagate error + return false; } + + return true; } size_t RemoteHttpRunner::QueryCapacity() { + if (!m_AcceptNewActions) + { + return 0; + } + // Estimate how much more work we're ready to accept RwLock::SharedLockScope _{m_RunningLock}; @@ -191,24 +244,68 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) return Results; } - // For larger batches, submit HTTP requests in parallel via the shared worker pool + // Collect distinct QueueIds and ensure remote queues exist once per queue - std::vector<std::future<SubmitResult>> Futures; - Futures.reserve(Actions.size()); + std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero) for (const Ref<RunnerAction>& Action : Actions) { - std::packaged_task<SubmitResult()> 
Task([this, Action]() { return SubmitAction(Action); }); + const int QueueId = Action->QueueId; + if (QueueId != 0 && QueueTokens.find(QueueId) == QueueTokens.end()) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + QueueTokens[QueueId] = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); + } + } - Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + // Group actions by QueueId + + struct QueueGroup + { + std::vector<Ref<RunnerAction>> Actions; + std::vector<size_t> OriginalIndices; + }; + + std::unordered_map<int, QueueGroup> Groups; + + for (size_t i = 0; i < Actions.size(); ++i) + { + auto& Group = Groups[Actions[i]->QueueId]; + Group.Actions.push_back(Actions[i]); + Group.OriginalIndices.push_back(i); } - std::vector<SubmitResult> Results; - Results.reserve(Futures.size()); + // Submit each group as a batch and map results back to original indices - for (auto& Future : Futures) + std::vector<SubmitResult> Results(Actions.size()); + + for (auto& [QueueId, Group] : Groups) { - Results.push_back(Future.get()); + std::string SubmitUrl = "/jobs"; + if (QueueId != 0) + { + if (Oid Token = QueueTokens[QueueId]; Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } + + const size_t BatchLimit = size_t(m_MaxBatchSize); + + for (size_t Offset = 0; Offset < Group.Actions.size(); Offset += BatchLimit) + { + size_t End = zen::Min(Offset + BatchLimit, Group.Actions.size()); + + std::vector<Ref<RunnerAction>> Chunk(Group.Actions.begin() + Offset, Group.Actions.begin() + End); + + std::vector<SubmitResult> ChunkResults = SubmitActionBatch(SubmitUrl, Chunk); + + for (size_t j = 0; j < ChunkResults.size(); ++j) + { + Results[Group.OriginalIndices[Offset + j]] = std::move(ChunkResults[j]); + } + } } return Results; @@ -221,6 +318,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> 
Action) // Verify whether we can accept more work + if (!m_AcceptNewActions) + { + return SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"}; + } + { RwLock::SharedLockScope _{m_RunningLock}; if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) @@ -275,7 +377,7 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) m_Http.GetBaseUri(), ActionId); - RegisterWorker(Action->Worker.Descriptor); + (void)RegisterWorker(Action->Worker.Descriptor); } else { @@ -384,6 +486,194 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) return {}; } +std::vector<SubmitResult> +RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActionBatch"); + + if (!m_AcceptNewActions) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"}); + } + + // Capacity check + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false}); + return Results; + } + } + + // Per-action setup and build batch body + + CbObjectWriter Body; + Body.BeginArray("actions"sv); + + for (const Ref<RunnerAction>& Action : Actions) + { + Action->ExecutionLocation = m_HostName; + MaybeDumpAction(Action->ActionLsn, Action->ActionObj); + Body.AddObject(Action->ActionObj); + } + + Body.EndArray(); + + // POST the batch + + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + + if (Response.StatusCode == HttpResponseCode::NotFound) + { + // Server needs attachments — resolve them and retry with a CbPackage + + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const 
IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return FallbackToIndividualSubmit(Actions); + } + + // Unexpected status or connection error — fall back to individual submission + + if (Response) + { + ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", + m_Http.GetBaseUri(), + SubmitUrl, + (int)Response.StatusCode, + ToString(Response.StatusCode)); + } + else + { + ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + } + + return FallbackToIndividualSubmit(Actions); +} + +std::vector<SubmitResult> +RemoteHttpRunner::ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + CbObject ResponseObj = Response.AsObject(); + CbArrayView ResultArray = ResponseObj["results"sv].AsArrayView(); + + size_t Index = 0; + for (CbFieldView Field : ResultArray) + { + if (Index >= Actions.size()) + { + break; + } + + CbObjectView Entry = Field.AsObjectView(); + const int32_t LsnField = Entry["lsn"sv].AsInt32(0); + + if (LsnField 
> 0) + { + HttpRunningAction NewAction; + NewAction.Action = Actions[Index]; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("batch: scheduled action {} with remote LSN {} (local LSN {})", + Actions[Index]->ActionObj.GetHash(), + LsnField, + Actions[Index]->ActionLsn); + + Actions[Index]->SetActionState(RunnerAction::State::Running); + + Results.push_back(SubmitResult{.IsAccepted = true}); + } + else + { + std::string_view ErrorMsg = Entry["error"sv].AsString(); + Results.push_back(SubmitResult{.IsAccepted = false, .Reason = std::string(ErrorMsg)}); + } + + ++Index; + } + + // If the server returned fewer results than actions, mark the rest as not accepted + while (Results.size() < Actions.size()) + { + Results.push_back(SubmitResult{.IsAccepted = false, .Reason = "no result from server"}); + } + + return Results; +} + +std::vector<SubmitResult> +RemoteHttpRunner::FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); + } + + return Results; +} + Oid RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) { @@ -481,6 +771,35 @@ RemoteHttpRunner::GetSubmittedActionCount() return m_RemoteRunningMap.size(); } +////////////////////////////////////////////////////////////////////////// +// +// IWsClientHandler +// + +void +RemoteHttpRunner::OnWsOpen() +{ + ZEN_INFO("WebSocket connected to {}", m_HostName); + 
m_WsConnected.store(true, std::memory_order_release); +} + +void +RemoteHttpRunner::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg) +{ + // The message content is a wakeup signal; no parsing needed. + // Signal the monitor thread to sweep completed actions immediately. + m_MonitorThreadEvent.Set(); +} + +void +RemoteHttpRunner::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_WARN("WebSocket disconnected from {} (code {})", m_HostName, Code); + m_WsConnected.store(false, std::memory_order_release); +} + +////////////////////////////////////////////////////////////////////////// + void RemoteHttpRunner::MonitorThreadFunction() { @@ -489,28 +808,40 @@ RemoteHttpRunner::MonitorThreadFunction() do { const int NormalWaitingTime = 200; - int WaitTimeMs = NormalWaitingTime; - auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; - auto SweepOnce = [&] { + const int WsWaitingTime = 2000; // Safety-net interval when WS is connected + + int WaitTimeMs = m_WsConnected.load(std::memory_order_relaxed) ? WsWaitingTime : NormalWaitingTime; + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { const size_t RetiredCount = SweepRunningActions(); - m_RunningLock.WithSharedLock([&] { - if (m_RemoteRunningMap.size() > 16) - { - WaitTimeMs = NormalWaitingTime / 4; - } - else - { - if (RetiredCount) + if (m_WsConnected.load(std::memory_order_relaxed)) + { + // WS connected: use long safety-net interval; the WS message + // will wake us immediately for the real work. 
+ WaitTimeMs = WsWaitingTime; + } + else + { + // No WS: adaptive polling as before + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) { - WaitTimeMs = NormalWaitingTime / 2; + WaitTimeMs = NormalWaitingTime / 4; } else { - WaitTimeMs = NormalWaitingTime; + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } } - } - }); + }); + } }; while (!WaitOnce()) @@ -518,7 +849,7 @@ RemoteHttpRunner::MonitorThreadFunction() SweepOnce(); } - // Signal received - this may mean we should quit + // Signal received — may be a WS wakeup or a quit signal SweepOnce(); } while (m_MonitorThreadEnabled); diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index 9119992a9..c17d0cf2a 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -14,9 +14,11 @@ # include <zencore/workthreadpool.h> # include <zencore/zencore.h> # include <zenhttp/httpclient.h> +# include <zenhttp/httpwsclient.h> # include <atomic> # include <filesystem> +# include <memory> # include <thread> # include <unordered_map> @@ -32,7 +34,7 @@ namespace zen::compute { */ -class RemoteHttpRunner : public FunctionRunner +class RemoteHttpRunner : public FunctionRunner, private IWsClientHandler { RemoteHttpRunner(RemoteHttpRunner&&) = delete; RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; @@ -45,7 +47,7 @@ public: ~RemoteHttpRunner(); virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; [[nodiscard]] virtual bool IsHealthy() override; [[nodiscard]] virtual size_t GetSubmittedActionCount() override; @@ -66,7 +68,9 @@ private: std::string m_BaseUrl; HttpClient m_Http; - int32_t m_MaxRunningActions = 256; // arbitrary 
limit for testing + std::atomic<bool> m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; struct HttpRunningAction { @@ -92,7 +96,20 @@ private: // creating remote queues. Generated once at construction and never changes. Oid m_InstanceId; + // WebSocket completion notification client + std::unique_ptr<HttpWsClient> m_WsClient; + std::atomic<bool> m_WsConnected{false}; + + // IWsClientHandler + void OnWsOpen() override; + void OnWsMessage(const WebSocketMessage& Msg) override; + void OnWsClose(uint16_t Code, std::string_view Reason) override; + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); + + std::vector<SubmitResult> SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions); + std::vector<SubmitResult> ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions); + std::vector<SubmitResult> FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions); }; } // namespace zen::compute |