aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-18 11:19:10 +0100
committerGitHub Enterprise <[email protected]>2026-03-18 11:19:10 +0100
commiteba410c4168e23d7908827eb34b7cf0c58a5dc48 (patch)
tree3cda8e8f3f81941d3bb5b84a8155350c5bb2068c /src/zencompute
parentbugfix release - v5.7.23 (#851) (diff)
downloadzen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.tar.xz
zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.zip
Compute batching (#849)
### Compute Batch Submission - Consolidate duplicated action submission logic in `httpcomputeservice` into a single `HandleSubmitAction` supporting both single-action and batch (actions array) payloads - Group actions by queue in `RemoteHttpRunner` and submit as batches with configurable chunk size, falling back to individual submission on failure - Extract shared helpers: `MakeErrorResult`, `ValidateQueueForEnqueue`, `ActivateActionInQueue`, `RemoveActionFromActiveMaps` ### Retracted Action State - Add `Retracted` state to `RunnerAction` for retry-free rescheduling — an explicit request to pull an action back and reschedule it on a different runner without incrementing `RetryCount` - Implement idempotent `RetractAction()` on `RunnerAction` and `ComputeServiceSession` - Add `POST jobs/{lsn}/retract` and `queues/{queueref}/jobs/{lsn}/retract` HTTP endpoints - Add state machine documentation and per-state comments to `RunnerAction` ### Compute Race Fixes - Fix race in `HandleActionUpdates` where actions enqueued between session abandon and scheduler tick were never abandoned, causing `GetActionResult` to return 202 indefinitely - Fix queue `ActiveCount` race where `NotifyQueueActionComplete` was called after releasing `m_ResultsLock`, allowing callers to observe stale counters immediately after `GetActionResult` returned OK ### Logging Optimization and ANSI improvements - Improve `AnsiColorStdoutSink` write efficiency — single write call, dirty-flag flush, `RwLock` instead of `std::mutex` - Move ANSI color emission from sink into formatters via `Formatter::SetColorEnabled()`; remove `ColorRangeStart`/`End` from `LogMessage` - Extract color helpers (`AnsiColorForLevel`, `StripAnsiSgrSequences`) into `helpers.h` - Strip upstream ANSI SGR escapes in non-color output mode. 
This enables color in log messages without polluting log files with ANSI control sequences - Move `RotatingFileSink`, `JsonFormatter`, and `FullFormatter` from header-only to pimpl with `.cpp` files ### CLI / Exec Refactoring - Extract `ExecSessionRunner` class from ~920-line `ExecUsingSession` into focused methods and an `ExecSessionConfig` struct - Replace monolithic `ExecCommand` with subcommand-based architecture (`http`, `inproc`, `beacon`, `dump`, `buildlog`) - Allow parent options to appear after subcommand name by parsing subcommand args permissively and forwarding unmatched tokens to the parent parser ### Testing Improvements - Fix `--test-suite` filter being ignored due to accumulation with default wildcard filter - Add test suite banners to test listener output - Make `function.session.abandon_pending` test more robust ### Startup / Reliability Fixes - Fix silent exit when a second zenserver instance detects a port conflict — use `ZEN_CONSOLE_*` for log calls that precede `InitializeLogging()` - Fix two potential SIGSEGV paths during early startup: guard `sentry_options_new()` returning nullptr, and throw on `ZenServerState::Register()` returning nullptr instead of dereferencing - Fail on unrecognized zenserver `--mode` instead of silently defaulting to store ### Other - Show host details (hostname, platform, CPU count, memory) when discovering new compute workers - Move frontend `html.zip` from source tree into build directory - Add format specifications for Compact Binary and Compressed Buffer wire formats - Add `WriteCompactBinaryObject` to zencore - Extend `ConsoleTui` with additional functionality - Add `--vscode` option to `xmake sln` for clangd / `compile_commands.json` support - Disable compute/horde/nomad in release builds (not yet production-ready) - Disable unintended `ASIO_HAS_IO_URING` enablement - Fix crashpad patch missing leading whitespace - Clean up code triggering gcc false positives
Diffstat (limited to 'src/zencompute')
-rw-r--r--src/zencompute/CLAUDE.md39
-rw-r--r--src/zencompute/computeservice.cpp384
-rw-r--r--src/zencompute/httpcomputeservice.cpp760
-rw-r--r--src/zencompute/include/zencompute/computeservice.h69
-rw-r--r--src/zencompute/include/zencompute/httpcomputeservice.h13
-rw-r--r--src/zencompute/runners/functionrunner.cpp44
-rw-r--r--src/zencompute/runners/functionrunner.h53
-rw-r--r--src/zencompute/runners/localrunner.cpp10
-rw-r--r--src/zencompute/runners/localrunner.h2
-rw-r--r--src/zencompute/runners/remotehttprunner.cpp399
-rw-r--r--src/zencompute/runners/remotehttprunner.h23
11 files changed, 1282 insertions, 514 deletions
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
index f5188123f..a1a39fc3c 100644
--- a/src/zencompute/CLAUDE.md
+++ b/src/zencompute/CLAUDE.md
@@ -46,9 +46,12 @@ Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns:
- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap`
- Queue map: `m_Queues` (QueueEntry objects)
- Action history ring: `m_ActionHistory` (bounded deque, default 1000)
+- WebSocket client (`m_OrchestratorWsClient`) subscribed to the orchestrator's `/orch/ws` push for instant worker discovery
**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset.
+**Convenience helpers:** `Ready()`, `Abandon()`, `SetOrchestrator(Endpoint, BasePath)` are inline wrappers for common state transitions and orchestrator configuration.
+
### `RunnerAction` (runners/functionrunner.h)
Shared ref-counted struct representing one action through its lifecycle.
@@ -67,8 +70,11 @@ New → Pending → Submitting → Running → Completed
→ Failed
→ Abandoned
→ Cancelled
+ → Retracted
```
-`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline.
+`SetActionState()` rejects non-forward transitions (Retracted has the highest ordinal so runner-side transitions cannot override it). `ResetActionStateToPending()` uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling — it clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline.
+
+**Retracted state:** An explicit, instigator-initiated request to pull an action back and reschedule it on a different runner (e.g. capacity opened up elsewhere). Unlike Failed/Abandoned auto-retry, rescheduling from Retracted does not increment `RetryCount` since nothing went wrong. Retraction is idempotent and can target Pending, Submitting, or Running actions.
### `LocalProcessRunner` (runners/localrunner.h)
Base for all local execution. Platform runners subclass this and override:
@@ -90,10 +96,29 @@ Base for all local execution. Platform runners subclass this and override:
- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000
### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h)
-Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity.
+Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. `SubmitActions()` supports batch submission — actions are grouped and forwarded in chunks.
+
+### `RemoteHttpRunner` (runners/remotehttprunner.h)
+Submits actions to remote zenserver instances over HTTP. Key features:
+- **WebSocket completion notifications**: connects a WS client to `/compute/ws` on the remote. When a message arrives (action completed), the monitor thread wakes immediately instead of polling. Falls back to adaptive polling (200ms→50ms) when WS is unavailable.
+- **Batch submission**: groups actions by queue and submits in configurable chunks (`m_MaxBatchSize`, default 50), falling back to individual submission on failure.
+- **Queue cancellation**: `CancelRemoteQueue()` sends cancel requests to the remote.
+- **Graceful shutdown**: `Shutdown()` closes the WS client, cancels all remote queues, stops the monitor thread, and marks remaining actions as Failed.
### `HttpComputeService` (include/zencompute/httpcomputeservice.h)
-Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service.
+Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. Supports both single-action and batch (actions array) payloads via a shared `HandleSubmitAction` helper.
+
+## Orchestrator Discovery
+
+`ComputeServiceSession` discovers remote workers via the orchestrator endpoint (`SetOrchestratorEndpoint()`). Two complementary mechanisms:
+
+1. **Polling** (`UpdateCoordinatorState`): `GET /orch/agents` on the scheduler thread, throttled to every 5s (500ms when no workers are known yet). Discovers new workers and removes stale/unreachable ones.
+
+2. **WebSocket push** (`OrchestratorWsHandler`): connects to `/orch/ws` on the orchestrator at setup time. When the orchestrator broadcasts a state change, the handler sets `m_OrchestratorQueryForced` and signals the scheduler event, bypassing the polling throttle. Falls back silently to polling if the WS connection fails.
+
+`NotifyOrchestratorChanged()` is the public API to trigger an immediate re-query — useful in tests and for external notification sources.
+
+Use `HttpToWsUrl(Endpoint, Path)` from `zenhttp/httpwsclient.h` to convert HTTP(S) endpoints to WebSocket URLs. This helper is shared across all WS client setup sites in the codebase.
## Action Lifecycle (End to End)
@@ -118,6 +143,8 @@ Actions that fail or are abandoned can be automatically retried or manually resc
**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit.
+**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent.
+
**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure.
## Queue System
@@ -161,8 +188,9 @@ All routes registered in `HttpComputeService` constructor. Prefix is configured
| GET | `jobs/running` | In-flight actions with CPU metrics |
| GET | `jobs/completed` | Actions with results available |
| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire |
+| POST | `jobs/{lsn}/retract` | Retract a pending/running action for rescheduling (idempotent) |
| POST | `jobs/{worker}` | Submit action for specific worker |
-| POST | `jobs` | Submit action (worker resolved from descriptor) |
+| POST | `jobs` | Submit action (or batch via `actions` array) |
| GET | `workers` | List worker IDs |
| GET | `workers/all` | All workers with full descriptors |
| GET/POST | `workers/{worker}` | Get/register worker |
@@ -179,8 +207,9 @@ Queue ref is capture(1) in all `queues/{queueref}/...` routes.
| GET | `queues/{queueref}/completed` | Queue's completed results |
| GET | `queues/{queueref}/history` | Queue's action history |
| GET | `queues/{queueref}/running` | Queue's running actions |
-| POST | `queues/{queueref}/jobs` | Submit to queue |
+| POST | `queues/{queueref}/jobs` | Submit to queue (or batch via `actions` array) |
| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule |
+| POST | `queues/{queueref}/jobs/{lsn}/retract` | Retract action for rescheduling |
| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) |
Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes.
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
index 838d741b6..92901de64 100644
--- a/src/zencompute/computeservice.cpp
+++ b/src/zencompute/computeservice.cpp
@@ -33,6 +33,7 @@
# include <zenutil/workerpools.h>
# include <zentelemetry/stats.h>
# include <zenhttp/httpclient.h>
+# include <zenhttp/httpwsclient.h>
# include <set>
# include <deque>
@@ -42,6 +43,7 @@
# include <unordered_set>
ZEN_THIRD_PARTY_INCLUDES_START
+# include <EASTL/fixed_vector.h>
# include <EASTL/hash_set.h>
ZEN_THIRD_PARTY_INCLUDES_END
@@ -95,6 +97,14 @@ using SessionState = ComputeServiceSession::SessionState;
static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count));
+static ComputeServiceSession::EnqueueResult
+MakeErrorResult(std::string_view Error)
+{
+ CbObjectWriter Writer;
+ Writer << "error"sv << Error;
+ return {0, Writer.Save()};
+}
+
//////////////////////////////////////////////////////////////////////////
struct ComputeServiceSession::Impl
@@ -130,14 +140,40 @@ struct ComputeServiceSession::Impl
void SetOrchestratorEndpoint(std::string_view Endpoint);
void SetOrchestratorBasePath(std::filesystem::path BasePath);
+ void NotifyOrchestratorChanged();
std::string m_OrchestratorEndpoint;
std::filesystem::path m_OrchestratorBasePath;
Stopwatch m_OrchestratorQueryTimer;
+ std::atomic<bool> m_OrchestratorQueryForced{false};
std::unordered_set<std::string> m_KnownWorkerUris;
void UpdateCoordinatorState();
+ // WebSocket subscription to orchestrator push notifications
+ struct OrchestratorWsHandler : public IWsClientHandler
+ {
+ Impl& Owner;
+
+ explicit OrchestratorWsHandler(Impl& InOwner) : Owner(InOwner) {}
+
+ void OnWsOpen() override
+ {
+ ZEN_LOG_INFO(Owner.m_Log, "orchestrator WebSocket connected");
+ Owner.NotifyOrchestratorChanged();
+ }
+
+ void OnWsMessage(const WebSocketMessage&) override { Owner.NotifyOrchestratorChanged(); }
+
+ void OnWsClose(uint16_t Code, std::string_view Reason) override
+ {
+ ZEN_LOG_WARN(Owner.m_Log, "orchestrator WebSocket closed (code {}: {})", Code, Reason);
+ }
+ };
+
+ std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler;
+ std::unique_ptr<HttpWsClient> m_OrchestratorWsClient;
+
// Worker registration and discovery
struct FunctionDefinition
@@ -157,6 +193,8 @@ struct ComputeServiceSession::Impl
std::atomic<int32_t> m_ActionsCounter = 0; // sequence number
metrics::Meter m_ArrivalRate;
+ std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr};
+
RwLock m_PendingLock;
std::map<int, Ref<RunnerAction>> m_PendingActions;
@@ -267,6 +305,8 @@ struct ComputeServiceSession::Impl
void DrainQueue(int QueueId);
ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority);
ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority);
+ ComputeServiceSession::EnqueueResult ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue);
+ void ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn);
void GetQueueCompleted(int QueueId, CbWriter& Cbo);
void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState);
void ExpireCompletedQueues();
@@ -292,11 +332,13 @@ struct ComputeServiceSession::Impl
void HandleActionUpdates();
void PostUpdate(RunnerAction* Action);
+ void RemoveActionFromActiveMaps(int ActionLsn);
static constexpr int kDefaultMaxRetries = 3;
int GetMaxRetriesForQueue(int QueueId);
ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn);
+ ComputeServiceSession::RescheduleResult RetractAction(int ActionLsn);
ActionCounts GetActionCounts()
{
@@ -449,6 +491,28 @@ void
ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint)
{
m_OrchestratorEndpoint = Endpoint;
+
+ // Subscribe to orchestrator WebSocket push so we discover worker changes
+ // immediately instead of waiting for the next polling cycle.
+ try
+ {
+ std::string WsUrl = HttpToWsUrl(Endpoint, "/orch/ws");
+
+ m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this);
+
+ HttpWsClientSettings WsSettings;
+ WsSettings.LogCategory = "orch_disc_ws";
+ WsSettings.ConnectTimeout = std::chrono::milliseconds{3000};
+
+ m_OrchestratorWsClient = std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, WsSettings);
+ m_OrchestratorWsClient->Connect();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("failed to connect orchestrator WebSocket, falling back to polling: {}", Ex.what());
+ m_OrchestratorWsClient.reset();
+ m_OrchestratorWsHandler.reset();
+ }
}
void
@@ -458,6 +522,13 @@ ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BaseP
}
void
+ComputeServiceSession::Impl::NotifyOrchestratorChanged()
+{
+ m_OrchestratorQueryForced.store(true, std::memory_order_relaxed);
+ m_SchedulingThreadEvent.Set();
+}
+
+void
ComputeServiceSession::Impl::UpdateCoordinatorState()
{
ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState");
@@ -467,10 +538,14 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
}
// Poll faster when we have no discovered workers yet so remote runners come online quickly
- const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000;
- if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs)
+ const bool Forced = m_OrchestratorQueryForced.exchange(false, std::memory_order_relaxed);
+ if (!Forced)
{
- return;
+ const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000;
+ if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs)
+ {
+ return;
+ }
}
m_OrchestratorQueryTimer.Reset();
@@ -520,7 +595,24 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
continue;
}
- ZEN_INFO("discovered new worker at {}", UriStr);
+ std::string_view Hostname = Worker["hostname"sv].AsString();
+ std::string_view Platform = Worker["platform"sv].AsString();
+ int Cpus = Worker["cpus"sv].AsInt32();
+ uint64_t MemTotal = Worker["memory_total"sv].AsUInt64();
+
+ if (!Hostname.empty())
+ {
+ ZEN_INFO("discovered new worker at {} ({}, {}, {} cpus, {:.1f} GB)",
+ UriStr,
+ Hostname,
+ Platform,
+ Cpus,
+ static_cast<double>(MemTotal) / (1024.0 * 1024.0 * 1024.0));
+ }
+ else
+ {
+ ZEN_INFO("discovered new worker at {}", UriStr);
+ }
m_KnownWorkerUris.insert(UriStr);
@@ -598,6 +690,15 @@ ComputeServiceSession::Impl::Shutdown()
{
RequestStateTransition(SessionState::Sunset);
+ // Close orchestrator WebSocket before stopping the scheduler thread
+ // to prevent callbacks into a shutting-down scheduler.
+ if (m_OrchestratorWsClient)
+ {
+ m_OrchestratorWsClient->Close();
+ m_OrchestratorWsClient.reset();
+ }
+ m_OrchestratorWsHandler.reset();
+
m_SchedulingThreadEnabled = false;
m_SchedulingThreadEvent.Set();
if (m_SchedulingThread.joinable())
@@ -720,8 +821,14 @@ ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
// different descriptor. Thus we only need to call this the first time, when the
// worker is added
- m_LocalRunnerGroup.RegisterWorker(Worker);
- m_RemoteRunnerGroup.RegisterWorker(Worker);
+ if (!m_LocalRunnerGroup.RegisterWorker(Worker))
+ {
+ ZEN_WARN("failed to register worker {} on one or more local runners", WorkerId);
+ }
+ if (!m_RemoteRunnerGroup.RegisterWorker(Worker))
+ {
+ ZEN_WARN("failed to register worker {} on one or more remote runners", WorkerId);
+ }
if (m_Recorder)
{
@@ -767,7 +874,10 @@ ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner)
for (const CbPackage& Worker : Workers)
{
- Runner.RegisterWorker(Worker);
+ if (!Runner.RegisterWorker(Worker))
+ {
+ ZEN_WARN("failed to sync worker {} to runner", Worker.GetObjectHash());
+ }
}
}
@@ -868,9 +978,7 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke
if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready)
{
- CbObjectWriter Writer;
- Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()));
- return {0, Writer.Save()};
+ return MakeErrorResult(fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())));
}
const int ActionLsn = ++m_ActionsCounter;
@@ -1258,42 +1366,51 @@ ComputeServiceSession::Impl::DrainQueue(int QueueId)
}
ComputeServiceSession::EnqueueResult
-ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
+ComputeServiceSession::Impl::ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue)
{
- Ref<QueueEntry> Queue = FindQueue(QueueId);
+ OutQueue = FindQueue(QueueId);
- if (!Queue)
+ if (!OutQueue)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue not found"sv;
- return {0, Writer.Save()};
+ return MakeErrorResult("queue not found"sv);
}
- QueueState QState = Queue->State.load();
+ QueueState QState = OutQueue->State.load();
if (QState == QueueState::Cancelled)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is cancelled"sv;
- return {0, Writer.Save()};
+ return MakeErrorResult("queue is cancelled"sv);
}
if (QState == QueueState::Draining)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is draining"sv;
- return {0, Writer.Save()};
+ return MakeErrorResult("queue is draining"sv);
+ }
+
+ return {};
+}
+
+void
+ComputeServiceSession::Impl::ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn)
+{
+ Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Lsn); });
+ Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
+ Queue->IdleSince.store(0, std::memory_order_relaxed);
+}
+
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
+{
+ Ref<QueueEntry> Queue;
+ if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage)
+ {
+ return Error;
}
EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority);
if (Result.Lsn != 0)
{
- Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
- Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
- Queue->IdleSince.store(0, std::memory_order_relaxed);
+ ActivateActionInQueue(Queue, Result.Lsn);
}
return Result;
@@ -1302,40 +1419,17 @@ ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionOb
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority)
{
- Ref<QueueEntry> Queue = FindQueue(QueueId);
-
- if (!Queue)
- {
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue not found"sv;
- return {0, Writer.Save()};
- }
-
- QueueState QState = Queue->State.load();
- if (QState == QueueState::Cancelled)
+ Ref<QueueEntry> Queue;
+ if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is cancelled"sv;
- return {0, Writer.Save()};
- }
-
- if (QState == QueueState::Draining)
- {
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is draining"sv;
- return {0, Writer.Save()};
+ return Error;
}
EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority);
if (Result.Lsn != 0)
{
- Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
- Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
- Queue->IdleSince.store(0, std::memory_order_relaxed);
+ ActivateActionInQueue(Queue, Result.Lsn);
}
return Result;
@@ -1770,6 +1864,68 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
return {.Success = true, .RetryCount = NewRetryCount};
}
+ComputeServiceSession::RescheduleResult
+ComputeServiceSession::Impl::RetractAction(int ActionLsn)
+{
+ Ref<RunnerAction> Action;
+ bool WasRunning = false;
+
+ // Look for the action in pending or running maps
+ m_RunningLock.WithSharedLock([&] {
+ if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end())
+ {
+ Action = It->second;
+ WasRunning = true;
+ }
+ });
+
+ if (!Action)
+ {
+ m_PendingLock.WithSharedLock([&] {
+ if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end())
+ {
+ Action = It->second;
+ }
+ });
+ }
+
+ if (!Action)
+ {
+ return {.Success = false, .Error = "Action not found in pending or running maps"};
+ }
+
+ if (!Action->RetractAction())
+ {
+ return {.Success = false, .Error = "Action cannot be retracted from its current state"};
+ }
+
+ // If the action was running, send a cancellation signal to the runner
+ if (WasRunning)
+ {
+ m_LocalRunnerGroup.CancelAction(ActionLsn);
+ }
+
+ ZEN_INFO("action {} ({}) retract requested", Action->ActionId, ActionLsn);
+ return {.Success = true, .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)};
+}
+
+void
+ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn)
+{
+ m_RunningLock.WithExclusiveLock([&] {
+ m_PendingLock.WithExclusiveLock([&] {
+ if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+ {
+ m_PendingActions.erase(ActionLsn);
+ }
+ else
+ {
+ m_RunningMap.erase(FindIt);
+ }
+ });
+ });
+}
+
void
ComputeServiceSession::Impl::HandleActionUpdates()
{
@@ -1781,6 +1937,10 @@ ComputeServiceSession::Impl::HandleActionUpdates()
std::unordered_set<int> SeenLsn;
+ // Collect terminal action notifications for the completion observer.
+ // Inline capacity of 64 avoids heap allocation in the common case.
+ eastl::fixed_vector<IComputeCompletionObserver::CompletedActionNotification, 64> TerminalBatch;
+
// Process each action's latest state, deduplicating by LSN.
//
// This is safe because state transitions are monotonically increasing by enum
@@ -1798,7 +1958,23 @@ ComputeServiceSession::Impl::HandleActionUpdates()
{
// Newly enqueued — add to pending map for scheduling
case RunnerAction::State::Pending:
- m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ // Guard against a race where the session is abandoned between
+ // EnqueueAction (which calls PostUpdate) and this scheduler
+ // tick. AbandonAllActions() only scans m_PendingActions, so it
+ // misses actions still in m_UpdatedActions at the time the
+ // session transitions. Detect that here and immediately abandon
+ // rather than inserting into the pending map, where they would
+ // otherwise be stuck indefinitely.
+ if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Abandoned)
+ {
+ Action->SetActionState(RunnerAction::State::Abandoned);
+ // SetActionState calls PostUpdate; the Abandoned action
+ // will be processed as a terminal on the next scheduler pass.
+ }
+ else
+ {
+ m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ }
break;
// Async submission in progress — remains in pending map
@@ -1816,6 +1992,15 @@ ComputeServiceSession::Impl::HandleActionUpdates()
ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
break;
+ // Retracted — pull back for rescheduling without counting against retry limit
+ case RunnerAction::State::Retracted:
+ {
+ RemoveActionFromActiveMaps(ActionLsn);
+ Action->ResetActionStateToPending();
+ ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn);
+ break;
+ }
+
// Terminal states — move to results, record history, notify queue
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
@@ -1834,19 +2019,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
{
- // Remove from whichever active map the action is in before resetting
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ RemoveActionFromActiveMaps(ActionLsn);
// Reset triggers PostUpdate() which re-enters the action as Pending
Action->ResetActionStateToPending();
@@ -1861,19 +2034,14 @@ ComputeServiceSession::Impl::HandleActionUpdates()
}
}
- // Remove from whichever active map the action is in
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ RemoveActionFromActiveMaps(ActionLsn);
+
+ // Update queue counters BEFORE publishing the result into
+ // m_ResultsMap. GetActionResult erases from m_ResultsMap
+ // under m_ResultsLock, so if we updated counters after
+ // releasing that lock, a caller could observe ActiveCount
+ // still at 1 immediately after GetActionResult returned OK.
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
m_ResultsLock.WithExclusiveLock([&] {
m_ResultsMap[ActionLsn] = Action;
@@ -1902,16 +2070,46 @@ ComputeServiceSession::Impl::HandleActionUpdates()
});
m_RetiredCount.fetch_add(1);
m_ResultRate.Mark(1);
+ {
+ using ObserverState = IComputeCompletionObserver::ActionState;
+ ObserverState NotifyState{};
+ switch (TerminalState)
+ {
+ case RunnerAction::State::Completed:
+ NotifyState = ObserverState::Completed;
+ break;
+ case RunnerAction::State::Failed:
+ NotifyState = ObserverState::Failed;
+ break;
+ case RunnerAction::State::Abandoned:
+ NotifyState = ObserverState::Abandoned;
+ break;
+ case RunnerAction::State::Cancelled:
+ NotifyState = ObserverState::Cancelled;
+ break;
+ default:
+ break;
+ }
+ TerminalBatch.push_back({ActionLsn, NotifyState});
+ }
ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}",
Action->ActionId,
ActionLsn,
TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE");
- NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
break;
}
}
}
}
+
+ // Notify the completion observer, if any, about all terminal actions in this batch.
+ if (!TerminalBatch.empty())
+ {
+ if (IComputeCompletionObserver* Observer = m_CompletionObserver.load(std::memory_order_acquire))
+ {
+ Observer->OnActionsCompleted({TerminalBatch.data(), TerminalBatch.size()});
+ }
+ }
}
size_t
@@ -2014,6 +2212,12 @@ ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath)
}
void
+ComputeServiceSession::NotifyOrchestratorChanged()
+{
+ m_Impl->NotifyOrchestratorChanged();
+}
+
+void
ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
{
m_Impl->StartRecording(InResolver, RecordingPath);
@@ -2194,6 +2398,12 @@ ComputeServiceSession::RescheduleAction(int ActionLsn)
return m_Impl->RescheduleAction(ActionLsn);
}
+ComputeServiceSession::RescheduleResult
+ComputeServiceSession::RetractAction(int ActionLsn)
+{
+ return m_Impl->RetractAction(ActionLsn);
+}
+
std::vector<ComputeServiceSession::RunningActionInfo>
ComputeServiceSession::GetRunningActions()
{
@@ -2219,6 +2429,12 @@ ComputeServiceSession::GetCompleted(CbWriter& Cbo)
}
void
+ComputeServiceSession::SetCompletionObserver(IComputeCompletionObserver* Observer)
+{
+ m_Impl->m_CompletionObserver.store(Observer, std::memory_order_release);
+}
+
+void
ComputeServiceSession::PostUpdate(RunnerAction* Action)
{
m_Impl->PostUpdate(Action);
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
index e82a40781..bdfd9d197 100644
--- a/src/zencompute/httpcomputeservice.cpp
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -16,6 +16,7 @@
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
# include <zencore/logging.h>
+# include <zencore/string.h>
# include <zencore/system.h>
# include <zencore/thread.h>
# include <zencore/trace.h>
@@ -23,8 +24,10 @@
# include <zenstore/cidstore.h>
# include <zentelemetry/stats.h>
+# include <algorithm>
# include <span>
# include <unordered_map>
+# include <vector>
using namespace std::literals;
@@ -50,6 +53,11 @@ struct HttpComputeService::Impl
ComputeServiceSession m_ComputeService;
SystemMetricsTracker m_MetricsTracker;
+ // WebSocket connections (completion push)
+
+ RwLock m_WsConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_WsConnections;
+
// Metrics
metrics::OperationTiming m_HttpRequests;
@@ -91,6 +99,12 @@ struct HttpComputeService::Impl
void HandleWorkersAllGet(HttpServerRequest& HttpReq);
void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
+ void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
+
+ // WebSocket / observer
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection);
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code);
+ void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions);
void RegisterRoutes();
@@ -110,6 +124,7 @@ struct HttpComputeService::Impl
m_ComputeService.WaitUntilReady();
m_StatsService.RegisterHandler("compute", *m_Self);
RegisterRoutes();
+ m_ComputeService.SetCompletionObserver(m_Self);
}
};
@@ -149,7 +164,7 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned);
+ bool Success = m_ComputeService.Abandon();
if (Success)
{
@@ -325,6 +340,29 @@ HttpComputeService::Impl::RegisterRoutes()
HttpVerb::kGet | HttpVerb::kPost);
m_Router.RegisterRoute(
+ "jobs/{lsn}/retract",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0);
+
+ auto Result = m_ComputeService.RetractAction(ActionLsn);
+
+ CbObjectWriter Cbo;
+ if (Result.Success)
+ {
+ Cbo << "success"sv << true;
+ Cbo << "lsn"sv << ActionLsn;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ else
+ {
+ Cbo << "error"sv << Result.Error;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
"jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the
// one which uses the scheduled action lsn for lookups
[this](HttpRouterRequest& Req) {
@@ -373,127 +411,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- switch (HttpReq.RequestVerb())
- {
- case HttpVerb::kGet:
- // TODO: return status of all pending or executing jobs
- break;
-
- case HttpVerb::kPost:
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- // This operation takes the proposed job spec and identifies which
- // chunks are not present on this server. This list is then returned in
- // the "need" list in the response
-
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- ActionObj.IterateAttachments([&](CbFieldView Field) {
- const IoHash FileHash = Field.AsHash();
-
- if (!m_CidStore.ContainsChunk(FileHash))
- {
- NeedList.push_back(FileHash);
- }
- });
-
- if (NeedList.empty())
- {
- // We already have everything, enqueue the action for execution
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn);
-
- HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
-
- return;
- }
-
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
- CbObject Response = Cbo.Save();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
- }
- break;
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- std::span<const CbAttachment> Attachments = Action.GetAttachments();
-
- int AttachmentCount = 0;
- int NewAttachmentCount = 0;
- uint64_t TotalAttachmentBytes = 0;
- uint64_t TotalNewBytes = 0;
-
- for (const CbAttachment& Attachment : Attachments)
- {
- ZEN_ASSERT(Attachment.IsCompressedBinary());
-
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
-
- const uint64_t CompressedSize = DataView.GetCompressedSize();
-
- TotalAttachmentBytes += CompressedSize;
- ++AttachmentCount;
-
- const CidStore::InsertResult InsertResult =
- m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
-
- if (InsertResult.New)
- {
- TotalNewBytes += CompressedSize;
- ++NewAttachmentCount;
- }
- }
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
- ActionObj.GetHash(),
- Result.Lsn,
- zen::NiceBytes(TotalAttachmentBytes),
- AttachmentCount,
- zen::NiceBytes(TotalNewBytes),
- NewAttachmentCount);
-
- HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
-
- return;
- }
- break;
-
- default:
- break;
- }
- break;
-
- default:
- break;
- }
+ HandleSubmitAction(HttpReq, 0, RequestPriority, &Worker);
},
HttpVerb::kPost);
@@ -511,118 +429,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- // Resolve worker
-
- //
-
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- // This operation takes the proposed job spec and identifies which
- // chunks are not present on this server. This list is then returned in
- // the "need" list in the response
-
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- ActionObj.IterateAttachments([&](CbFieldView Field) {
- const IoHash FileHash = Field.AsHash();
-
- if (!m_CidStore.ContainsChunk(FileHash))
- {
- NeedList.push_back(FileHash);
- }
- });
-
- if (NeedList.empty())
- {
- // We already have everything, enqueue the action for execution
-
- if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority))
- {
- ZEN_DEBUG("action accepted (lsn {})", Result.Lsn);
-
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- // Could not resolve?
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
- CbObject Response = Cbo.Save();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
- }
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- std::span<const CbAttachment> Attachments = Action.GetAttachments();
-
- int AttachmentCount = 0;
- int NewAttachmentCount = 0;
- uint64_t TotalAttachmentBytes = 0;
- uint64_t TotalNewBytes = 0;
-
- for (const CbAttachment& Attachment : Attachments)
- {
- ZEN_ASSERT(Attachment.IsCompressedBinary());
-
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
-
- const uint64_t CompressedSize = DataView.GetCompressedSize();
-
- TotalAttachmentBytes += CompressedSize;
- ++AttachmentCount;
-
- const CidStore::InsertResult InsertResult =
- m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
-
- if (InsertResult.New)
- {
- TotalNewBytes += CompressedSize;
- ++NewAttachmentCount;
- }
- }
-
- if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority))
- {
- ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)",
- Result.Lsn,
- zen::NiceBytes(TotalAttachmentBytes),
- AttachmentCount,
- zen::NiceBytes(TotalNewBytes),
- NewAttachmentCount);
-
- HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- // Could not resolve?
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
- return;
- }
+ HandleSubmitAction(HttpReq, 0, RequestPriority, nullptr);
},
HttpVerb::kPost);
@@ -1090,72 +897,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- if (!CheckAttachments(ActionObj, NeedList))
- {
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
- }
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn);
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- IngestStats Stats = IngestPackageAttachments(Action);
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
- QueueId,
- ActionObj.GetHash(),
- Result.Lsn,
- zen::NiceBytes(Stats.Bytes),
- Stats.Count,
- zen::NiceBytes(Stats.NewBytes),
- Stats.NewCount);
-
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- default:
- break;
- }
+ HandleSubmitAction(HttpReq, QueueId, RequestPriority, &Worker);
},
HttpVerb::kPost);
@@ -1178,71 +920,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- if (!CheckAttachments(ActionObj, NeedList))
- {
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
- }
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn);
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- IngestStats Stats = IngestPackageAttachments(Action);
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)",
- QueueId,
- Result.Lsn,
- zen::NiceBytes(Stats.Bytes),
- Stats.Count,
- zen::NiceBytes(Stats.NewBytes),
- Stats.NewCount);
-
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- default:
- break;
- }
+ HandleSubmitAction(HttpReq, QueueId, RequestPriority, nullptr);
},
HttpVerb::kPost);
@@ -1306,6 +984,45 @@ HttpComputeService::Impl::RegisterRoutes()
}
},
HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/jobs/{lsn}/retract",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+ const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0);
+
+ if (QueueId == 0)
+ {
+ return;
+ }
+
+ ZEN_UNUSED(QueueId);
+
+ auto Result = m_ComputeService.RetractAction(ActionLsn);
+
+ CbObjectWriter Cbo;
+ if (Result.Success)
+ {
+ Cbo << "success"sv << true;
+ Cbo << "lsn"sv << ActionLsn;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ else
+ {
+ Cbo << "error"sv << Result.Error;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+ },
+ HttpVerb::kPost);
+
+ // WebSocket upgrade endpoint — the handler logic lives in
+ // HttpComputeService::OnWebSocket* methods; this route merely
+ // satisfies the router so the upgrade request isn't rejected.
+ m_Router.RegisterRoute(
+ "ws",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); },
+ HttpVerb::kGet);
}
//////////////////////////////////////////////////////////////////////////
@@ -1320,12 +1037,17 @@ HttpComputeService::HttpComputeService(CidStore& InCidStore,
HttpComputeService::~HttpComputeService()
{
+ m_Impl->m_ComputeService.SetCompletionObserver(nullptr);
m_Impl->m_StatsService.UnregisterHandler("compute", *this);
}
void
HttpComputeService::Shutdown()
{
+ // Null out observer before shutting down the compute session to prevent
+ // callbacks into a partially-torn-down service.
+ m_Impl->m_ComputeService.SetCompletionObserver(nullptr);
+ m_Impl->m_WsConnectionsLock.WithExclusiveLock([&] { m_Impl->m_WsConnections.clear(); });
m_Impl->m_ComputeService.Shutdown();
}
@@ -1492,6 +1214,184 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto
}
void
+HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker)
+{
+    // QueueId > 0: queue-scoped enqueue; QueueId == 0: implicit queue (global routes). A non-null Worker selects the pre-resolved EnqueueResolved* variants.
+ auto Enqueue = [&](CbObject ActionObj) -> ComputeServiceSession::EnqueueResult {
+ if (QueueId > 0)
+ {
+ if (Worker)
+ {
+ return m_ComputeService.EnqueueResolvedActionToQueue(QueueId, *Worker, ActionObj, Priority);
+ }
+ return m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, Priority);
+ }
+ else
+ {
+ if (Worker)
+ {
+ return m_ComputeService.EnqueueResolvedAction(*Worker, ActionObj, Priority);
+ }
+ return m_ComputeService.EnqueueAction(ActionObj, Priority);
+ }
+ };
+
+ // Read payload upfront and handle attachments based on content type
+ CbObject Body;
+ IngestStats Stats = {};
+
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ IoBuffer Payload = HttpReq.ReadPayload();
+ Body = LoadCompactBinaryObject(Payload);
+ break;
+ }
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Package = HttpReq.ReadPayloadPackage();
+ Body = Package.GetObject();
+ Stats = IngestPackageAttachments(Package);
+ break;
+ }
+
+ default:
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ // Check for "actions" array to determine batch vs single-action path
+ CbArray Actions = Body.Find("actions"sv).AsArray();
+
+ if (Actions.Num() > 0)
+ {
+ // --- Batch path ---
+
+ // For CbObject payloads, check all attachments upfront before enqueuing anything
+ if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ {
+ std::vector<IoHash> NeedList;
+
+ for (CbField ActionField : Actions)
+ {
+ CbObject ActionObj = ActionField.AsObject();
+ CheckAttachments(ActionObj, NeedList);
+ }
+
+ if (!NeedList.empty())
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
+ }
+ }
+
+ // Enqueue all actions and collect results
+ CbObjectWriter Cbo;
+ int Accepted = 0;
+
+ Cbo.BeginArray("results");
+
+ for (CbField ActionField : Actions)
+ {
+ CbObject ActionObj = ActionField.AsObject();
+
+ ComputeServiceSession::EnqueueResult Result = Enqueue(ActionObj);
+
+ Cbo.BeginObject();
+
+ if (Result)
+ {
+ Cbo << "lsn"sv << Result.Lsn;
+ ++Accepted;
+ }
+ else
+ {
+ Cbo << "error"sv << Result.ResponseMessage;
+ }
+
+ Cbo.EndObject();
+ }
+
+ Cbo.EndArray();
+
+ if (Stats.Count > 0)
+ {
+ ZEN_DEBUG("queue {}: batch accepted {}/{} actions: {} in {} attachments. {} new ({} attachments)",
+ QueueId,
+ Accepted,
+ Actions.Num(),
+ zen::NiceBytes(Stats.Bytes),
+ Stats.Count,
+ zen::NiceBytes(Stats.NewBytes),
+ Stats.NewCount);
+ }
+ else
+ {
+ ZEN_DEBUG("queue {}: batch accepted {}/{} actions", QueueId, Accepted, Actions.Num());
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ // --- Single-action path: Body is the action itself ---
+
+ if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ {
+ std::vector<IoHash> NeedList;
+
+ if (!CheckAttachments(Body, NeedList))
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
+ }
+ }
+
+ if (ComputeServiceSession::EnqueueResult Result = Enqueue(Body))
+ {
+ if (Stats.Count > 0)
+ {
+ ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
+ QueueId,
+ Body.GetHash(),
+ Result.Lsn,
+ zen::NiceBytes(Stats.Bytes),
+ Stats.Count,
+ zen::NiceBytes(Stats.NewBytes),
+ Stats.NewCount);
+ }
+ else
+ {
+ ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, Body.GetHash(), Result.Lsn);
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
+ }
+}
+
+void
HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq)
{
CbObjectWriter Cbo;
@@ -1632,6 +1532,136 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
}
//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+void
+HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ m_Impl->OnWebSocketOpen(std::move(Connection));
+}
+
+void
+HttpComputeService::OnWebSocketMessage([[maybe_unused]] WebSocketConnection& Conn, [[maybe_unused]] const WebSocketMessage& Msg)
+{
+ // Clients are receive-only; ignore any inbound messages.
+}
+
+void
+HttpComputeService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ m_Impl->OnWebSocketClose(Conn, Code);
+}
+
+void
+HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotification> Actions)
+{
+ m_Impl->OnActionsCompleted(Actions);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Impl — WebSocket / observer
+//
+
+void
+HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ ZEN_INFO("compute WebSocket client connected");
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
+}
+
+void
+HttpComputeService::Impl::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code)
+{
+ ZEN_INFO("compute WebSocket client disconnected (code {})", Code);
+
+ m_WsConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_WsConnections.erase(It, m_WsConnections.end());
+ });
+}
+
+void
+HttpComputeService::Impl::OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions)
+{
+ using ActionState = IComputeCompletionObserver::ActionState;
+ using CompletedActionNotification = IComputeCompletionObserver::CompletedActionNotification;
+
+ // Snapshot connections under shared lock
+ eastl::fixed_vector<Ref<WebSocketConnection>, 16> Connections;
+ m_WsConnectionsLock.WithSharedLock([&] { Connections = {begin(m_WsConnections), end(m_WsConnections)}; });
+
+ if (Connections.empty())
+ {
+ return;
+ }
+
+ // Build CompactBinary notification grouped by state:
+ // {"Completed": [lsn, ...], "Failed": [lsn, ...], ...}
+ // Each state name becomes an array key containing the LSNs in that state.
+ CbObjectWriter Cbo;
+
+ // Sort by state so we can emit one array per state in a single pass.
+ // Copy into a local vector since the span is const.
+ eastl::fixed_vector<CompletedActionNotification, 16> Sorted(Actions.begin(), Actions.end());
+ std::sort(Sorted.begin(), Sorted.end(), [](const auto& A, const auto& B) { return A.State < B.State; });
+
+ ActionState CurrentState{};
+ bool ArrayOpen = false;
+
+ for (const CompletedActionNotification& Action : Sorted)
+ {
+ if (!ArrayOpen || Action.State != CurrentState)
+ {
+ if (ArrayOpen)
+ {
+ Cbo.EndArray();
+ }
+ CurrentState = Action.State;
+ Cbo.BeginArray(IComputeCompletionObserver::ActionStateToString(CurrentState));
+ ArrayOpen = true;
+ }
+ Cbo.AddInteger(Action.Lsn);
+ }
+
+ if (ArrayOpen)
+ {
+ Cbo.EndArray();
+ }
+
+ CbObject Msg = Cbo.Save();
+ MemoryView MsgView = Msg.GetView();
+
+ // Broadcast to all connected clients, prune closed ones
+ bool HadClosedConnections = false;
+ for (auto& Conn : Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendBinary(MsgView);
+ }
+ else
+ {
+ HadClosedConnections = true;
+ }
+ }
+
+ if (HadClosedConnections)
+ {
+ m_WsConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) {
+ return !C->IsOpen();
+ });
+ m_WsConnections.erase(It, m_WsConnections.end());
+ });
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
void
httpcomputeservice_forcelink()
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
index 65ec5f9ee..1ca78738a 100644
--- a/src/zencompute/include/zencompute/computeservice.h
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -13,6 +13,7 @@
# include <zenhttp/httpcommon.h>
# include <filesystem>
+# include <span>
namespace zen {
class ChunkResolver;
@@ -29,6 +30,53 @@ class RemoteHttpRunner;
struct RunnerAction;
struct SubmitResult;
+/**
+ * Observer interface for action completion notifications.
+ *
+ * Implementors receive a batch of notifications whenever actions reach a
+ * terminal state (Completed, Failed, Abandoned, Cancelled). The callback
+ * fires on the scheduler thread *after* the action result has been placed
+ * in m_ResultsMap, so GET /jobs/{lsn} will succeed by the time the client
+ * reacts to the notification.
+ */
+class IComputeCompletionObserver
+{
+public:
+ virtual ~IComputeCompletionObserver() = default;
+
+ enum class ActionState
+ {
+ Completed,
+ Failed,
+ Abandoned,
+ Cancelled,
+ };
+
+ struct CompletedActionNotification
+ {
+ int Lsn;
+ ActionState State;
+ };
+
+ static constexpr std::string_view ActionStateToString(ActionState State)
+ {
+ switch (State)
+ {
+ case ActionState::Completed:
+ return "Completed";
+ case ActionState::Failed:
+ return "Failed";
+ case ActionState::Abandoned:
+ return "Abandoned";
+ case ActionState::Cancelled:
+ return "Cancelled";
+ }
+ return "Unknown";
+ }
+
+ virtual void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) = 0;
+};
+
struct WorkerDesc
{
CbPackage Descriptor;
@@ -91,11 +139,25 @@ public:
// Sunset can be reached from any non-Sunset state.
bool RequestStateTransition(SessionState NewState);
+ // Convenience helpers for common state transitions.
+ bool Ready() { return RequestStateTransition(SessionState::Ready); }
+ bool Abandon() { return RequestStateTransition(SessionState::Abandoned); }
+
// Orchestration
void SetOrchestratorEndpoint(std::string_view Endpoint);
void SetOrchestratorBasePath(std::filesystem::path BasePath);
+ void SetOrchestrator(std::string_view Endpoint, std::filesystem::path BasePath)
+ {
+ SetOrchestratorEndpoint(Endpoint);
+ SetOrchestratorBasePath(std::move(BasePath));
+ }
+
+ /// Immediately wake the scheduler to re-poll the orchestrator for worker changes.
+ /// Resets the polling throttle so the next scheduler tick calls UpdateCoordinatorState().
+ void NotifyOrchestratorChanged();
+
// Worker registration and discovery
void RegisterWorker(CbPackage Worker);
@@ -182,6 +244,7 @@ public:
};
[[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn);
+ [[nodiscard]] RescheduleResult RetractAction(int ActionLsn);
void GetCompleted(CbWriter&);
@@ -215,7 +278,7 @@ public:
// sized to match RunnerAction::State::_Count but we can't use the enum here
// for dependency reasons, so just use a fixed size array and static assert in
// the implementation file
- uint64_t Timestamps[8] = {};
+ uint64_t Timestamps[9] = {};
};
[[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
@@ -235,6 +298,10 @@ public:
void EmitStats(CbObjectWriter& Cbo);
+ // Completion observer (used by HttpComputeService for WebSocket push)
+
+ void SetCompletionObserver(IComputeCompletionObserver* Observer);
+
// Recording
void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
index ee1cd2614..b58e73a0d 100644
--- a/src/zencompute/include/zencompute/httpcomputeservice.h
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -9,6 +9,7 @@
# include "zencompute/computeservice.h"
# include <zenhttp/httpserver.h>
+# include <zenhttp/websocket.h>
# include <filesystem>
# include <memory>
@@ -22,7 +23,7 @@ namespace zen::compute {
/**
* HTTP interface for compute service
*/
-class HttpComputeService : public HttpService, public IHttpStatsProvider
+class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver
{
public:
HttpComputeService(CidStore& InCidStore,
@@ -42,6 +43,16 @@ public:
void HandleStatsRequest(HttpServerRequest& Request) override;
+ // IWebSocketHandler
+
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
+ // IComputeCompletionObserver
+
+ void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) override;
+
private:
struct Impl;
std::unique_ptr<Impl> m_Impl;
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
index 768cdf1e1..4f116e7d8 100644
--- a/src/zencompute/runners/functionrunner.cpp
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -215,15 +215,22 @@ BaseRunnerGroup::GetSubmittedActionCount()
return TotalCount;
}
-void
+bool
BaseRunnerGroup::RegisterWorker(CbPackage Worker)
{
RwLock::SharedLockScope _(m_RunnersLock);
+ bool AllSucceeded = true;
+
for (auto& Runner : m_Runners)
{
- Runner->RegisterWorker(Worker);
+ if (!Runner->RegisterWorker(Worker))
+ {
+ AllSucceeded = false;
+ }
}
+
+ return AllSucceeded;
}
void
@@ -276,12 +283,34 @@ RunnerAction::~RunnerAction()
}
bool
+RunnerAction::RetractAction()
+{
+ State CurrentState = m_ActionState.load();
+
+ do
+ {
+        // Only states up to and including Running may be retracted; past that, succeed only if already Retracted (idempotent), fail for any other terminal state.
+ if (CurrentState > State::Running)
+ {
+ return CurrentState == State::Retracted;
+ }
+
+ if (m_ActionState.compare_exchange_strong(CurrentState, State::Retracted))
+ {
+ this->Timestamps[static_cast<int>(State::Retracted)] = DateTime::Now().GetTicks();
+ m_OwnerSession->PostUpdate(this);
+ return true;
+ }
+ } while (true);
+}
+
+bool
RunnerAction::ResetActionStateToPending()
{
- // Only allow reset from Failed or Abandoned states
+ // Only allow reset from Failed, Abandoned, or Retracted states
State CurrentState = m_ActionState.load();
- if (CurrentState != State::Failed && CurrentState != State::Abandoned)
+ if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted)
{
return false;
}
@@ -305,8 +334,11 @@ RunnerAction::ResetActionStateToPending()
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed);
CpuSeconds.store(0.0f, std::memory_order_relaxed);
- // Increment retry count
- RetryCount.fetch_add(1, std::memory_order_relaxed);
+ // Increment retry count (skip for Retracted — nothing failed)
+ if (CurrentState != State::Retracted)
+ {
+ RetryCount.fetch_add(1, std::memory_order_relaxed);
+ }
// Re-enter the scheduler pipeline
m_OwnerSession->PostUpdate(this);
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
index f67414dbb..56c3f3af0 100644
--- a/src/zencompute/runners/functionrunner.h
+++ b/src/zencompute/runners/functionrunner.h
@@ -29,8 +29,8 @@ public:
FunctionRunner(std::filesystem::path BasePath);
virtual ~FunctionRunner() = 0;
- virtual void Shutdown() = 0;
- virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0;
+ virtual void Shutdown() = 0;
+ [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) = 0;
[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0;
[[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;
@@ -63,7 +63,7 @@ public:
SubmitResult SubmitAction(Ref<RunnerAction> Action);
std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
size_t GetSubmittedActionCount();
- void RegisterWorker(CbPackage Worker);
+ [[nodiscard]] bool RegisterWorker(CbPackage Worker);
void Shutdown();
bool CancelAction(int ActionLsn);
void CancelRemoteQueue(int QueueId);
@@ -114,6 +114,30 @@ struct RunnerGroup : public BaseRunnerGroup
/**
* This represents an action going through different stages of scheduling and execution.
+ *
+ * State machine
+ * =============
+ *
+ * Normal forward flow (enforced by SetActionState rejecting backward transitions):
+ *
+ * New -> Pending -> Submitting -> Running -> Completed
+ * -> Failed
+ * -> Abandoned
+ * -> Cancelled
+ *
+ * Rescheduling (via ResetActionStateToPending):
+ *
+ * Failed ---> Pending (increments RetryCount, subject to retry limit)
+ * Abandoned ---> Pending (increments RetryCount, subject to retry limit)
+ * Retracted ---> Pending (does NOT increment RetryCount)
+ *
+ * Retraction (via RetractAction, idempotent):
+ *
+ * Pending/Submitting/Running -> Retracted -> Pending (rescheduled)
+ *
+ * Retracted is placed after Cancelled in enum order so that once set,
+ * no runner-side transition (Completed/Failed) can override it via
+ * SetActionState's forward-only rule.
*/
struct RunnerAction : public RefCounted
{
@@ -137,16 +161,20 @@ struct RunnerAction : public RefCounted
enum class State
{
- New,
- Pending,
- Submitting,
- Running,
- Completed,
- Failed,
- Abandoned,
- Cancelled,
+ New, // Initial state at construction, before entering the scheduler
+ Pending, // Queued and waiting for a runner slot
+ Submitting, // Being handed off to a runner (async submission in progress)
+ Running, // Executing on a runner process
+ Completed, // Finished successfully with results available
+ Failed, // Execution failed (transient error, eligible for retry)
+ Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon)
+ Cancelled, // Intentional user cancellation (never retried)
+ Retracted, // Pulled back for rescheduling on a different runner (no retry cost)
_Count
};
+ static_assert(State::Retracted > State::Completed && State::Retracted > State::Failed && State::Retracted > State::Abandoned &&
+ State::Retracted > State::Cancelled,
+ "Retracted must be the highest terminal ordinal so runner-side transitions cannot override it");
static const char* ToString(State _)
{
@@ -168,6 +196,8 @@ struct RunnerAction : public RefCounted
return "Abandoned";
case State::Cancelled:
return "Cancelled";
+ case State::Retracted:
+ return "Retracted";
default:
return "Unknown";
}
@@ -191,6 +221,7 @@ struct RunnerAction : public RefCounted
void SetActionState(State NewState);
bool IsSuccess() const { return ActionState() == State::Completed; }
+ bool RetractAction();
bool ResetActionStateToPending();
bool IsCompleted() const
{
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
index 7aaefb06e..b61e0a46f 100644
--- a/src/zencompute/runners/localrunner.cpp
+++ b/src/zencompute/runners/localrunner.cpp
@@ -7,14 +7,16 @@
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
+# include <zencore/compactbinaryfile.h>
# include <zencore/compress.h>
# include <zencore/except_fmt.h>
# include <zencore/filesystem.h>
# include <zencore/fmtutils.h>
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
-# include <zencore/system.h>
# include <zencore/scopeguard.h>
+# include <zencore/stream.h>
+# include <zencore/system.h>
# include <zencore/timer.h>
# include <zencore/trace.h>
# include <zenstore/cidstore.h>
@@ -152,7 +154,7 @@ LocalProcessRunner::CreateNewSandbox()
return Path;
}
-void
+bool
LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage)
{
ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker");
@@ -173,6 +175,8 @@ LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage)
ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path);
}
+
+ return true;
}
size_t
@@ -301,7 +305,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action)
// Write out action
- zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer());
+ WriteCompactBinaryObject(SandboxPath / "build.action", ActionObj);
// Manifest inputs in sandbox
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h
index 7493e980b..b8cff6826 100644
--- a/src/zencompute/runners/localrunner.h
+++ b/src/zencompute/runners/localrunner.h
@@ -51,7 +51,7 @@ public:
~LocalProcessRunner();
virtual void Shutdown() override;
- virtual void RegisterWorker(const CbPackage& WorkerPackage) override;
+ [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override;
[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
[[nodiscard]] virtual bool IsHealthy() override { return true; }
[[nodiscard]] virtual size_t GetSubmittedActionCount() override;
diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp
index 672636d06..ce6a81173 100644
--- a/src/zencompute/runners/remotehttprunner.cpp
+++ b/src/zencompute/runners/remotehttprunner.cpp
@@ -42,6 +42,20 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
, m_Http(m_BaseUrl)
, m_InstanceId(Oid::NewOid())
{
+ // Attempt to connect a WebSocket for push-based completion notifications.
+ // If the remote doesn't support WS, OnWsClose fires and we fall back to polling.
+ {
+ std::string WsUrl = HttpToWsUrl(HostName, "/compute/ws");
+
+ HttpWsClientSettings WsSettings;
+ WsSettings.LogCategory = "http_exec_ws";
+ WsSettings.ConnectTimeout = std::chrono::milliseconds{3000};
+
+ IWsClientHandler& Handler = *this;
+ m_WsClient = std::make_unique<HttpWsClient>(WsUrl, Handler, WsSettings);
+ m_WsClient->Connect();
+ }
+
m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this};
}
@@ -53,7 +67,29 @@ RemoteHttpRunner::~RemoteHttpRunner()
void
RemoteHttpRunner::Shutdown()
{
- // TODO: should cleanly drain/cancel pending work
+ m_AcceptNewActions = false;
+
+ // Close the WebSocket client first, so no more wakeup signals arrive.
+ if (m_WsClient)
+ {
+ m_WsClient->Close();
+ }
+
+ // Cancel all known remote queues so the remote side stops scheduling new
+ // work and cancels in-flight actions belonging to those queues.
+
+ {
+ std::vector<std::pair<int, Oid>> Queues;
+
+ m_QueueTokenLock.WithSharedLock([&] { Queues.assign(m_RemoteQueueTokens.begin(), m_RemoteQueueTokens.end()); });
+
+ for (const auto& [QueueId, Token] : Queues)
+ {
+ CancelRemoteQueue(QueueId);
+ }
+ }
+
+ // Stop the monitor thread so it no longer polls the remote.
m_MonitorThreadEnabled = false;
m_MonitorThreadEvent.Set();
@@ -61,9 +97,22 @@ RemoteHttpRunner::Shutdown()
{
m_MonitorThread.join();
}
+
+ // Drain the running map and mark all remaining actions as Failed so the
+ // scheduler can reschedule or finalize them.
+
+ std::unordered_map<int, HttpRunningAction> Remaining;
+
+ m_RunningLock.WithExclusiveLock([&] { Remaining.swap(m_RemoteRunningMap); });
+
+ for (auto& [RemoteLsn, HttpAction] : Remaining)
+ {
+ ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn);
+ HttpAction.Action->SetActionState(RunnerAction::State::Failed);
+ }
}
-void
+bool
RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
{
ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker");
@@ -125,15 +174,13 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
if (!IsHttpSuccessCode(PayloadResponse.StatusCode))
{
ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
-
- // TODO: propagate error
+ return false;
}
}
else if (!IsHttpSuccessCode(DescResponse.StatusCode))
{
ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
-
- // TODO: propagate error
+ return false;
}
else
{
@@ -152,14 +199,20 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
WorkerUrl,
(int)WorkerResponse.StatusCode,
ToString(WorkerResponse.StatusCode));
-
- // TODO: propagate error
+ return false;
}
+
+ return true;
}
size_t
RemoteHttpRunner::QueryCapacity()
{
+ if (!m_AcceptNewActions)
+ {
+ return 0;
+ }
+
// Estimate how much more work we're ready to accept
RwLock::SharedLockScope _{m_RunningLock};
@@ -191,24 +244,68 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
return Results;
}
- // For larger batches, submit HTTP requests in parallel via the shared worker pool
+ // Collect distinct QueueIds and ensure remote queues exist once per queue
- std::vector<std::future<SubmitResult>> Futures;
- Futures.reserve(Actions.size());
+    std::unordered_map<int, Oid> QueueTokens; // QueueId -> remote queue token; QueueId 0 is never inserted, and Oid::Zero marks a queue that could not be ensured remotely
for (const Ref<RunnerAction>& Action : Actions)
{
- std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); });
+ const int QueueId = Action->QueueId;
+ if (QueueId != 0 && QueueTokens.find(QueueId) == QueueTokens.end())
+ {
+ CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId);
+ CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId);
+ QueueTokens[QueueId] = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig);
+ }
+ }
- Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog));
+ // Group actions by QueueId
+
+ struct QueueGroup
+ {
+ std::vector<Ref<RunnerAction>> Actions;
+ std::vector<size_t> OriginalIndices;
+ };
+
+ std::unordered_map<int, QueueGroup> Groups;
+
+ for (size_t i = 0; i < Actions.size(); ++i)
+ {
+ auto& Group = Groups[Actions[i]->QueueId];
+ Group.Actions.push_back(Actions[i]);
+ Group.OriginalIndices.push_back(i);
}
- std::vector<SubmitResult> Results;
- Results.reserve(Futures.size());
+ // Submit each group as a batch and map results back to original indices
- for (auto& Future : Futures)
+ std::vector<SubmitResult> Results(Actions.size());
+
+ for (auto& [QueueId, Group] : Groups)
{
- Results.push_back(Future.get());
+ std::string SubmitUrl = "/jobs";
+ if (QueueId != 0)
+ {
+ if (Oid Token = QueueTokens[QueueId]; Token != Oid::Zero)
+ {
+ SubmitUrl = fmt::format("/queues/{}/jobs", Token);
+ }
+ }
+
+ const size_t BatchLimit = size_t(m_MaxBatchSize);
+
+ for (size_t Offset = 0; Offset < Group.Actions.size(); Offset += BatchLimit)
+ {
+ size_t End = zen::Min(Offset + BatchLimit, Group.Actions.size());
+
+ std::vector<Ref<RunnerAction>> Chunk(Group.Actions.begin() + Offset, Group.Actions.begin() + End);
+
+ std::vector<SubmitResult> ChunkResults = SubmitActionBatch(SubmitUrl, Chunk);
+
+ for (size_t j = 0; j < ChunkResults.size(); ++j)
+ {
+ Results[Group.OriginalIndices[Offset + j]] = std::move(ChunkResults[j]);
+ }
+ }
}
return Results;
@@ -221,6 +318,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
// Verify whether we can accept more work
+ if (!m_AcceptNewActions)
+ {
+ return SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"};
+ }
+
{
RwLock::SharedLockScope _{m_RunningLock};
if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions))
@@ -275,7 +377,7 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
m_Http.GetBaseUri(),
ActionId);
- RegisterWorker(Action->Worker.Descriptor);
+ (void)RegisterWorker(Action->Worker.Descriptor);
}
else
{
@@ -384,6 +486,194 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
return {};
}
+std::vector<SubmitResult>
+RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions)
+{
+ ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActionBatch");
+
+ if (!m_AcceptNewActions)
+ {
+ return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"});
+ }
+
+    // Capacity check: reject the whole batch up front if the remote running map is already at its limit
+
+ {
+ RwLock::SharedLockScope _{m_RunningLock};
+ if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions))
+ {
+ std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false});
+ return Results;
+ }
+ }
+
+    // Per-action setup (record execution location, optional debug dump) while building the batch body
+
+ CbObjectWriter Body;
+ Body.BeginArray("actions"sv);
+
+ for (const Ref<RunnerAction>& Action : Actions)
+ {
+ Action->ExecutionLocation = m_HostName;
+ MaybeDumpAction(Action->ActionLsn, Action->ActionObj);
+ Body.AddObject(Action->ActionObj);
+ }
+
+ Body.EndArray();
+
+ // POST the batch
+
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
+
+ if (Response.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(Response, Actions);
+ }
+
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ // Server needs attachments — resolve them and retry with a CbPackage
+
+ CbObject NeedObj = Response.AsObject();
+
+ CbPackage Pkg;
+ Pkg.SetObject(Body.Save());
+
+ for (auto& Item : NeedObj["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ }
+ else
+ {
+ ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash);
+ return FallbackToIndividualSubmit(Actions);
+ }
+ }
+
+ HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(RetryResponse, Actions);
+ }
+
+ ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit",
+ (int)RetryResponse.StatusCode,
+ ToString(RetryResponse.StatusCode));
+ return FallbackToIndividualSubmit(Actions);
+ }
+
+ // Unexpected status or connection error — fall back to individual submission
+
+ if (Response)
+ {
+ ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit",
+ m_Http.GetBaseUri(),
+ SubmitUrl,
+ (int)Response.StatusCode,
+ ToString(Response.StatusCode));
+ }
+ else
+ {
+ ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
+ }
+
+ return FallbackToIndividualSubmit(Actions);
+}
+
+std::vector<SubmitResult>
+RemoteHttpRunner::ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions)
+{
+ std::vector<SubmitResult> Results;
+ Results.reserve(Actions.size());
+
+ CbObject ResponseObj = Response.AsObject();
+ CbArrayView ResultArray = ResponseObj["results"sv].AsArrayView();
+
+ size_t Index = 0;
+ for (CbFieldView Field : ResultArray)
+ {
+ if (Index >= Actions.size())
+ {
+ break;
+ }
+
+ CbObjectView Entry = Field.AsObjectView();
+ const int32_t LsnField = Entry["lsn"sv].AsInt32(0);
+
+ if (LsnField > 0)
+ {
+ HttpRunningAction NewAction;
+ NewAction.Action = Actions[Index];
+ NewAction.RemoteActionLsn = LsnField;
+
+ {
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+ m_RemoteRunningMap[LsnField] = std::move(NewAction);
+ }
+
+ ZEN_DEBUG("batch: scheduled action {} with remote LSN {} (local LSN {})",
+ Actions[Index]->ActionObj.GetHash(),
+ LsnField,
+ Actions[Index]->ActionLsn);
+
+ Actions[Index]->SetActionState(RunnerAction::State::Running);
+
+ Results.push_back(SubmitResult{.IsAccepted = true});
+ }
+ else
+ {
+ std::string_view ErrorMsg = Entry["error"sv].AsString();
+ Results.push_back(SubmitResult{.IsAccepted = false, .Reason = std::string(ErrorMsg)});
+ }
+
+ ++Index;
+ }
+
+ // If the server returned fewer results than actions, mark the rest as not accepted
+ while (Results.size() < Actions.size())
+ {
+ Results.push_back(SubmitResult{.IsAccepted = false, .Reason = "no result from server"});
+ }
+
+ return Results;
+}
+
+std::vector<SubmitResult>
+RemoteHttpRunner::FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions)
+{
+ std::vector<std::future<SubmitResult>> Futures;
+ Futures.reserve(Actions.size());
+
+ for (const Ref<RunnerAction>& Action : Actions)
+ {
+ std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); });
+
+ Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog));
+ }
+
+ std::vector<SubmitResult> Results;
+ Results.reserve(Futures.size());
+
+ for (auto& Future : Futures)
+ {
+ Results.push_back(Future.get());
+ }
+
+ return Results;
+}
+
Oid
RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config)
{
@@ -481,6 +771,35 @@ RemoteHttpRunner::GetSubmittedActionCount()
return m_RemoteRunningMap.size();
}
+//////////////////////////////////////////////////////////////////////////
+//
+// IWsClientHandler
+//
+
+void
+RemoteHttpRunner::OnWsOpen()
+{
+ ZEN_INFO("WebSocket connected to {}", m_HostName);
+ m_WsConnected.store(true, std::memory_order_release);
+}
+
+void
+RemoteHttpRunner::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg)
+{
+ // The message content is a wakeup signal; no parsing needed.
+ // Signal the monitor thread to sweep completed actions immediately.
+ m_MonitorThreadEvent.Set();
+}
+
+void
+RemoteHttpRunner::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ ZEN_WARN("WebSocket disconnected from {} (code {})", m_HostName, Code);
+ m_WsConnected.store(false, std::memory_order_release);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
void
RemoteHttpRunner::MonitorThreadFunction()
{
@@ -489,28 +808,40 @@ RemoteHttpRunner::MonitorThreadFunction()
do
{
const int NormalWaitingTime = 200;
- int WaitTimeMs = NormalWaitingTime;
- auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); };
- auto SweepOnce = [&] {
+ const int WsWaitingTime = 2000; // Safety-net interval when WS is connected
+
+ int WaitTimeMs = m_WsConnected.load(std::memory_order_relaxed) ? WsWaitingTime : NormalWaitingTime;
+ auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); };
+ auto SweepOnce = [&] {
const size_t RetiredCount = SweepRunningActions();
- m_RunningLock.WithSharedLock([&] {
- if (m_RemoteRunningMap.size() > 16)
- {
- WaitTimeMs = NormalWaitingTime / 4;
- }
- else
- {
- if (RetiredCount)
+ if (m_WsConnected.load(std::memory_order_relaxed))
+ {
+ // WS connected: use long safety-net interval; the WS message
+ // will wake us immediately for the real work.
+ WaitTimeMs = WsWaitingTime;
+ }
+ else
+ {
+ // No WS: adaptive polling as before
+ m_RunningLock.WithSharedLock([&] {
+ if (m_RemoteRunningMap.size() > 16)
{
- WaitTimeMs = NormalWaitingTime / 2;
+ WaitTimeMs = NormalWaitingTime / 4;
}
else
{
- WaitTimeMs = NormalWaitingTime;
+ if (RetiredCount)
+ {
+ WaitTimeMs = NormalWaitingTime / 2;
+ }
+ else
+ {
+ WaitTimeMs = NormalWaitingTime;
+ }
}
- }
- });
+ });
+ }
};
while (!WaitOnce())
@@ -518,7 +849,7 @@ RemoteHttpRunner::MonitorThreadFunction()
SweepOnce();
}
- // Signal received - this may mean we should quit
+ // Signal received — may be a WS wakeup or a quit signal
SweepOnce();
} while (m_MonitorThreadEnabled);
diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h
index 9119992a9..c17d0cf2a 100644
--- a/src/zencompute/runners/remotehttprunner.h
+++ b/src/zencompute/runners/remotehttprunner.h
@@ -14,9 +14,11 @@
# include <zencore/workthreadpool.h>
# include <zencore/zencore.h>
# include <zenhttp/httpclient.h>
+# include <zenhttp/httpwsclient.h>
# include <atomic>
# include <filesystem>
+# include <memory>
# include <thread>
# include <unordered_map>
@@ -32,7 +34,7 @@ namespace zen::compute {
*/
-class RemoteHttpRunner : public FunctionRunner
+class RemoteHttpRunner : public FunctionRunner, private IWsClientHandler
{
RemoteHttpRunner(RemoteHttpRunner&&) = delete;
RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete;
@@ -45,7 +47,7 @@ public:
~RemoteHttpRunner();
virtual void Shutdown() override;
- virtual void RegisterWorker(const CbPackage& WorkerPackage) override;
+ [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override;
[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
[[nodiscard]] virtual bool IsHealthy() override;
[[nodiscard]] virtual size_t GetSubmittedActionCount() override;
@@ -66,7 +68,9 @@ private:
std::string m_BaseUrl;
HttpClient m_Http;
- int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ std::atomic<bool> m_AcceptNewActions{true};
+ int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ int32_t m_MaxBatchSize = 50;
struct HttpRunningAction
{
@@ -92,7 +96,20 @@ private:
// creating remote queues. Generated once at construction and never changes.
Oid m_InstanceId;
+ // WebSocket completion notification client
+ std::unique_ptr<HttpWsClient> m_WsClient;
+ std::atomic<bool> m_WsConnected{false};
+
+ // IWsClientHandler
+ void OnWsOpen() override;
+ void OnWsMessage(const WebSocketMessage& Msg) override;
+ void OnWsClose(uint16_t Code, std::string_view Reason) override;
+
Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config);
+
+ std::vector<SubmitResult> SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions);
+ std::vector<SubmitResult> ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions);
+ std::vector<SubmitResult> FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions);
};
} // namespace zen::compute