aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Stefan Boberg <[email protected]> 2026-03-18 11:19:10 +0100
committer GitHub Enterprise <[email protected]> 2026-03-18 11:19:10 +0100
commit eba410c4168e23d7908827eb34b7cf0c58a5dc48 (patch)
tree 3cda8e8f3f81941d3bb5b84a8155350c5bb2068c /src
parent bugfix release - v5.7.23 (#851) (diff)
download zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.tar.xz
zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.zip
Compute batching (#849)
### Compute Batch Submission
- Consolidate duplicated action submission logic in `httpcomputeservice` into a single `HandleSubmitAction` supporting both single-action and batch (actions array) payloads
- Group actions by queue in `RemoteHttpRunner` and submit as batches with configurable chunk size, falling back to individual submission on failure
- Extract shared helpers: `MakeErrorResult`, `ValidateQueueForEnqueue`, `ActivateActionInQueue`, `RemoveActionFromActiveMaps`

### Retracted Action State
- Add `Retracted` state to `RunnerAction` for retry-free rescheduling — an explicit request to pull an action back and reschedule it on a different runner without incrementing `RetryCount`
- Implement idempotent `RetractAction()` on `RunnerAction` and `ComputeServiceSession`
- Add `POST jobs/{lsn}/retract` and `queues/{queueref}/jobs/{lsn}/retract` HTTP endpoints
- Add state machine documentation and per-state comments to `RunnerAction`

### Compute Race Fixes
- Fix race in `HandleActionUpdates` where actions enqueued between session abandon and scheduler tick were never abandoned, causing `GetActionResult` to return 202 indefinitely
- Fix queue `ActiveCount` race where `NotifyQueueActionComplete` was called after releasing `m_ResultsLock`, allowing callers to observe stale counters immediately after `GetActionResult` returned OK

### Logging Optimization and ANSI improvements
- Improve `AnsiColorStdoutSink` write efficiency — single write call, dirty-flag flush, `RwLock` instead of `std::mutex`
- Move ANSI color emission from sink into formatters via `Formatter::SetColorEnabled()`; remove `ColorRangeStart`/`End` from `LogMessage`
- Extract color helpers (`AnsiColorForLevel`, `StripAnsiSgrSequences`) into `helpers.h`
- Strip upstream ANSI SGR escapes in non-color output mode.
This enables color in log messages without polluting log files with ANSI control sequences
- Move `RotatingFileSink`, `JsonFormatter`, and `FullFormatter` from header-only to pimpl with `.cpp` files

### CLI / Exec Refactoring
- Extract `ExecSessionRunner` class from ~920-line `ExecUsingSession` into focused methods and an `ExecSessionConfig` struct
- Replace monolithic `ExecCommand` with subcommand-based architecture (`http`, `inproc`, `beacon`, `dump`, `buildlog`)
- Allow parent options to appear after subcommand name by parsing subcommand args permissively and forwarding unmatched tokens to the parent parser

### Testing Improvements
- Fix `--test-suite` filter being ignored due to accumulation with default wildcard filter
- Add test suite banners to test listener output
- Make the `function.session.abandon_pending` test more robust

### Startup / Reliability Fixes
- Fix silent exit when a second zenserver instance detects a port conflict — use `ZEN_CONSOLE_*` for log calls that precede `InitializeLogging()`
- Fix two potential SIGSEGV paths during early startup: guard `sentry_options_new()` returning nullptr, and throw on `ZenServerState::Register()` returning nullptr instead of dereferencing
- Fail on unrecognized zenserver `--mode` instead of silently defaulting to store

### Other
- Show host details (hostname, platform, CPU count, memory) when discovering new compute workers
- Move frontend `html.zip` from source tree into build directory
- Add format specifications for Compact Binary and Compressed Buffer wire formats
- Add `WriteCompactBinaryObject` to zencore
- Extended `ConsoleTui` with additional functionality
- Add `--vscode` option to `xmake sln` for clangd / `compile_commands.json` support
- Disable compute/horde/nomad in release builds (not yet production-ready)
- Disable unintended `ASIO_HAS_IO_URING` enablement
- Fix crashpad patch missing leading whitespace
- Clean up code triggering gcc false positives
Diffstat (limited to 'src')
-rw-r--r-- src/zen/cmds/exec_cmd.cpp 1858
-rw-r--r-- src/zen/cmds/exec_cmd.h 141
-rw-r--r-- src/zen/cmds/service_cmd.cpp 25
-rw-r--r-- src/zen/cmds/workspaces_cmd.cpp 66
-rw-r--r-- src/zen/zen.cpp 131
-rw-r--r-- src/zen/zen.h 6
-rw-r--r-- src/zencompute/CLAUDE.md 39
-rw-r--r-- src/zencompute/computeservice.cpp 384
-rw-r--r-- src/zencompute/httpcomputeservice.cpp 760
-rw-r--r-- src/zencompute/include/zencompute/computeservice.h 69
-rw-r--r-- src/zencompute/include/zencompute/httpcomputeservice.h 13
-rw-r--r-- src/zencompute/runners/functionrunner.cpp 44
-rw-r--r-- src/zencompute/runners/functionrunner.h 53
-rw-r--r-- src/zencompute/runners/localrunner.cpp 10
-rw-r--r-- src/zencompute/runners/localrunner.h 2
-rw-r--r-- src/zencompute/runners/remotehttprunner.cpp 399
-rw-r--r-- src/zencompute/runners/remotehttprunner.h 23
-rw-r--r-- src/zencore/compactbinaryfile.cpp 11
-rw-r--r-- src/zencore/include/zencore/compactbinaryfile.h 1
-rw-r--r-- src/zencore/include/zencore/logging/ansicolorsink.h 3
-rw-r--r-- src/zencore/include/zencore/logging/formatter.h 6
-rw-r--r-- src/zencore/include/zencore/logging/helpers.h 77
-rw-r--r-- src/zencore/include/zencore/logging/logmsg.h 3
-rw-r--r-- src/zencore/include/zencore/testing.h 5
-rw-r--r-- src/zencore/logging/ansicolorsink.cpp 287
-rw-r--r-- src/zencore/testing.cpp 50
-rw-r--r-- src/zenhttp/clients/httpwsclient.cpp 30
-rw-r--r-- src/zenhttp/include/zenhttp/httpwsclient.h 10
-rw-r--r-- src/zenhttp/include/zenhttp/websocket.h 2
-rw-r--r-- src/zenserver-test/compute-tests.cpp 397
-rw-r--r-- src/zenserver-test/zenserver-test.cpp 4
-rw-r--r-- src/zenserver/compute/computeserver.cpp 16
-rw-r--r-- src/zenserver/frontend/html/compute/compute.html 6
-rw-r--r-- src/zenserver/frontend/html/compute/hub.html 7
-rw-r--r-- src/zenserver/frontend/html/compute/orchestrator.html 7
-rw-r--r-- src/zenserver/frontend/html/util/sanitize.js 9
-rw-r--r-- src/zenserver/main.cpp 5
-rw-r--r-- src/zenserver/xmake.lua 10
-rw-r--r-- src/zenutil/consoletui.cpp 65
-rw-r--r-- src/zenutil/include/zenutil/consoletui.h 28
-rw-r--r-- src/zenutil/include/zenutil/logging/fullformatter.h 197
-rw-r--r-- src/zenutil/include/zenutil/logging/jsonformatter.h 148
-rw-r--r-- src/zenutil/include/zenutil/logging/rotatingfilesink.h 230
-rw-r--r-- src/zenutil/logging/fullformatter.cpp 235
-rw-r--r-- src/zenutil/logging/jsonformatter.cpp 198
-rw-r--r-- src/zenutil/logging/logging.cpp (renamed from src/zenutil/logging.cpp) 6
-rw-r--r-- src/zenutil/logging/rotatingfilesink.cpp 249
47 files changed, 4064 insertions, 2261 deletions
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp
index cbc153e07..30e860a3f 100644
--- a/src/zen/cmds/exec_cmd.cpp
+++ b/src/zen/cmds/exec_cmd.cpp
@@ -18,6 +18,7 @@
#include <zencore/stream.h>
#include <zencore/string.h>
#include <zencore/system.h>
+#include <zencore/thread.h>
#include <zencore/timer.h>
#include <zenhttp/httpclient.h>
#include <zenhttp/packageformat.h>
@@ -41,255 +42,122 @@ struct hash<zen::IoHash> : public zen::IoHash::Hasher
namespace zen {
-ExecCommand::ExecCommand()
-{
- m_Options.add_options()("h,help", "Print help");
- m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>");
- m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>");
- m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>");
- m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>");
- m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>");
- m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>");
- m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>");
- m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>");
- m_Options.add_option("",
- "",
- "mode",
- "Select execution mode (http,inproc,dump,direct,beacon,buildlog)",
- cxxopts::value(m_Mode)->default_value("http"),
- "<string>");
- m_Options
- .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), "<bool>");
- m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>");
- m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>");
- m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>");
- m_Options.parse_positional("mode");
-}
-
-ExecCommand::~ExecCommand()
-{
-}
-
-void
-ExecCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
-{
- // Configure
-
- if (!ParseOptions(argc, argv))
- {
- return;
- }
-
- m_HostName = ResolveTargetHostSpec(m_HostName);
-
- if (m_RecordingPath.empty())
- {
- throw OptionParseException("replay path is required!", m_Options.help());
- }
-
- m_VerboseLogging = GlobalOptions.IsVerbose;
- m_QuietLogging = m_Quiet && !m_VerboseLogging;
-
- enum ExecMode
- {
- kHttp,
- kDirect,
- kInproc,
- kDump,
- kBeacon,
- kBuildLog
- } Mode;
-
- if (m_Mode == "http"sv)
- {
- Mode = kHttp;
- }
- else if (m_Mode == "direct"sv)
- {
- Mode = kDirect;
- }
- else if (m_Mode == "inproc"sv)
- {
- Mode = kInproc;
- }
- else if (m_Mode == "dump"sv)
- {
- Mode = kDump;
- }
- else if (m_Mode == "beacon"sv)
- {
- Mode = kBeacon;
- }
- else if (m_Mode == "buildlog"sv)
- {
- Mode = kBuildLog;
- }
- else
- {
- throw OptionParseException("invalid mode specified!", m_Options.help());
- }
+namespace {
- // Gather information from recording path
-
- std::unique_ptr<zen::compute::RecordingReader> Reader;
- std::unique_ptr<zen::compute::UeRecordingReader> UeReader;
-
- std::filesystem::path RecordingPath{m_RecordingPath};
-
- if (!std::filesystem::is_directory(RecordingPath))
- {
- throw OptionParseException("replay path should be a directory path!", m_Options.help());
- }
- else
+ static std::string EscapeHtml(std::string_view Input)
{
- if (std::filesystem::is_directory(RecordingPath / "cid"))
+ std::string Out;
+ Out.reserve(Input.size());
+ for (char C : Input)
{
- Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath);
- m_WorkerMap = Reader->ReadWorkers();
- m_ChunkResolver = Reader.get();
- m_RecordingReader = Reader.get();
- }
- else
- {
- UeReader = std::make_unique<zen::compute::UeRecordingReader>(RecordingPath);
- m_WorkerMap = UeReader->ReadWorkers();
- m_ChunkResolver = UeReader.get();
- m_RecordingReader = UeReader.get();
+ switch (C)
+ {
+ case '&':
+ Out += "&amp;";
+ break;
+ case '<':
+ Out += "&lt;";
+ break;
+ case '>':
+ Out += "&gt;";
+ break;
+ case '"':
+ Out += "&quot;";
+ break;
+ case '\'':
+ Out += "&#39;";
+ break;
+ default:
+ Out += C;
+ }
}
+ return Out;
}
- ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount());
-
- for (auto& Kv : m_WorkerMap)
+ static std::string EscapeJson(std::string_view Input)
{
- CbObject WorkerDesc = Kv.second.GetObject();
- const IoHash& WorkerId = Kv.first;
-
- RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId);
-
- if (m_VerboseLogging)
+ std::string Out;
+ Out.reserve(Input.size());
+ for (char C : Input)
{
- zen::ExtendableStringBuilder<1024> ObjStr;
-# if 0
- zen::CompactBinaryToJson(WorkerDesc, ObjStr);
- ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr);
-# else
- zen::CompactBinaryToYaml(WorkerDesc, ObjStr);
- ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr);
-# endif
+ switch (C)
+ {
+ case '"':
+ Out += "\\\"";
+ break;
+ case '\\':
+ Out += "\\\\";
+ break;
+ case '\n':
+ Out += "\\n";
+ break;
+ case '\r':
+ Out += "\\r";
+ break;
+ case '\t':
+ Out += "\\t";
+ break;
+ default:
+ if (static_cast<unsigned char>(C) < 0x20)
+ {
+ Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C)));
+ }
+ else
+ {
+ Out += C;
+ }
+ }
}
+ return Out;
}
- if (m_VerboseLogging)
- {
- EmitFunctionList(m_FunctionList);
- }
-
- // Iterate over work items and dispatch or log them
-
- int ReturnValue = 0;
-
- Stopwatch ExecTimer;
-
- switch (Mode)
- {
- case kHttp:
- // Forward requests to HTTP function service
- ReturnValue = HttpExecute();
- break;
-
- case kDirect:
- // Not currently supported
- ReturnValue = LocalMessagingExecute();
- break;
-
- case kInproc:
- // Handle execution in-core (by spawning child processes)
- ReturnValue = InProcessExecute();
- break;
-
- case kDump:
- // Dump high level information about actions to console
- ReturnValue = DumpWorkItems();
- break;
-
- case kBeacon:
- ReturnValue = BeaconExecute();
- break;
-
- case kBuildLog:
- ReturnValue = BuildActionsLog();
- break;
-
- default:
- ZEN_ERROR("Unknown operating mode! No work submitted");
-
- ReturnValue = 1;
- }
+} // namespace
- ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs()));
-
- if (!ReturnValue)
- {
- ZEN_CONSOLE("all work items completed successfully");
- }
- else
- {
- ZEN_CONSOLE("some work items failed (code {})", ReturnValue);
- }
-}
+//////////////////////////////////////////////////////////////////////////
+// ExecSessionConfig — read-only configuration for a session run
-int
-ExecCommand::InProcessExecute()
+struct ExecSessionConfig
{
- ZEN_ASSERT(m_ChunkResolver);
- ChunkResolver& Resolver = *m_ChunkResolver;
+ zen::ChunkResolver& Resolver;
+ zen::compute::RecordingReaderBase& RecordingReader;
+ const std::unordered_map<zen::IoHash, zen::CbPackage>& WorkerMap;
+ std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce
+ std::string_view OrchestratorUrl;
+ const std::filesystem::path& OutputPath;
+ int Offset = 0;
+ int Stride = 1;
+ int Limit = 0;
+ bool Verbose = false;
+ bool Quiet = false;
+ bool DumpActions = false;
+ bool Binary = false;
+};
- zen::compute::ComputeServiceSession ComputeSession(Resolver);
+//////////////////////////////////////////////////////////////////////////
+// ExecSessionRunner — owns per-run state, drives the session lifecycle
- std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
- ComputeSession.AddLocalRunner(Resolver, TempPath);
+class ExecSessionRunner
+{
+public:
+ ExecSessionRunner(zen::compute::ComputeServiceSession& Session, const ExecSessionConfig& Config);
+ int Run();
- return ExecUsingSession(ComputeSession);
-}
+private:
+ // Types
-int
-ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession)
-{
struct JobTracker
{
public:
- inline void Insert(int LsnField)
- {
- RwLock::ExclusiveLockScope _(Lock);
- PendingJobs.insert(LsnField);
- }
-
- inline bool IsEmpty() const
- {
- RwLock::SharedLockScope _(Lock);
- return PendingJobs.empty();
- }
-
- inline void Remove(int CompleteLsn)
- {
- RwLock::ExclusiveLockScope _(Lock);
- PendingJobs.erase(CompleteLsn);
- }
-
- inline size_t GetSize() const
- {
- RwLock::SharedLockScope _(Lock);
- return PendingJobs.size();
- }
+ void Insert(int LsnField);
+ bool IsEmpty() const;
+ void Remove(int CompleteLsn);
+ size_t GetSize() const;
private:
mutable RwLock Lock;
std::unordered_set<int> PendingJobs;
};
- JobTracker PendingJobs;
-
struct ActionSummaryEntry
{
int32_t Lsn = 0;
@@ -307,664 +175,471 @@ ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSessio
std::string ExecutionLocation;
};
- std::mutex SummaryLock;
- std::unordered_map<int32_t, ActionSummaryEntry> SummaryEntries;
+ // Methods
+
+ std::string RegisterOrchestratorClient();
+ void SendOrchestratorHeartbeat();
+ void CompleteOrchestratorClient();
+ void DrainCompletedJobs();
+ void SaveResultOutput(int32_t CompleteLsn, CbPackage& ResultPackage);
+ void SaveActionOutput(int32_t Lsn, int RecordingIndex, const IoHash& ActionId, const CbObject& ActionObject);
+ void WriteSummaryFiles();
+ void EmitFunctionListOnce();
+
+ // State
+
+ zen::compute::ComputeServiceSession& m_Session;
+ ExecSessionConfig m_Config;
+ JobTracker m_PendingJobs;
+ std::mutex m_SummaryLock;
+ std::unordered_map<int32_t, ActionSummaryEntry> m_SummaryEntries;
+ int m_QueueId = 0;
+ std::string m_OrchestratorClientId;
+ Stopwatch m_OrchestratorHeartbeatTimer;
+ bool m_FunctionListEmittedOnce = false;
+ std::atomic<int> m_IsDraining{0};
+};
- ComputeSession.WaitUntilReady();
+//////////////////////////////////////////////////////////////////////////
+// ExecSessionRunner::JobTracker
- // Register as a client with the orchestrator (best-effort)
+void
+ExecSessionRunner::JobTracker::Insert(int LsnField)
+{
+ RwLock::ExclusiveLockScope _(Lock);
+ PendingJobs.insert(LsnField);
+}
- std::string OrchestratorClientId;
+bool
+ExecSessionRunner::JobTracker::IsEmpty() const
+{
+ RwLock::SharedLockScope _(Lock);
+ return PendingJobs.empty();
+}
- if (!m_OrchestratorUrl.empty())
+void
+ExecSessionRunner::JobTracker::Remove(int CompleteLsn)
+{
+ RwLock::ExclusiveLockScope _(Lock);
+ PendingJobs.erase(CompleteLsn);
+}
+
+size_t
+ExecSessionRunner::JobTracker::GetSize() const
+{
+ RwLock::SharedLockScope _(Lock);
+ return PendingJobs.size();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// ExecSessionRunner implementation
+
+ExecSessionRunner::ExecSessionRunner(zen::compute::ComputeServiceSession& Session, const ExecSessionConfig& Config)
+: m_Session(Session)
+, m_Config(Config)
+{
+}
+
+std::string
+ExecSessionRunner::RegisterOrchestratorClient()
+{
+ if (m_Config.OrchestratorUrl.empty())
{
- try
- {
- HttpClient OrchestratorClient(m_OrchestratorUrl);
+ return {};
+ }
- CbObjectWriter Ann;
- Ann << "session_id"sv << GetSessionId();
- Ann << "hostname"sv << std::string_view(GetMachineName());
+ try
+ {
+ HttpClient OrchestratorClient(m_Config.OrchestratorUrl);
- CbObjectWriter Meta;
- Meta << "source"sv
- << "zen-exec"sv;
- Ann << "metadata"sv << Meta.Save();
+ CbObjectWriter Ann;
+ Ann << "session_id"sv << GetSessionId();
+ Ann << "hostname"sv << std::string_view(GetMachineName());
- auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save());
- if (Resp.IsSuccess())
- {
- OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString());
- ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId);
- }
- else
- {
- ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode));
- }
+ CbObjectWriter Meta;
+ Meta << "source"sv
+ << "zen-exec"sv;
+ Ann << "metadata"sv << Meta.Save();
+
+ auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save());
+ if (Resp.IsSuccess())
+ {
+ std::string ClientId = std::string(Resp.AsObject()["id"].AsString());
+ ZEN_CONSOLE_INFO("registered with orchestrator as {}", ClientId);
+ return ClientId;
}
- catch (const std::exception& Ex)
+ else
{
- ZEN_WARN("failed to register with orchestrator: {}", Ex.what());
+ ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode));
}
}
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("failed to register with orchestrator: {}", Ex.what());
+ }
- Stopwatch OrchestratorHeartbeatTimer;
+ return {};
+}
- auto SendOrchestratorHeartbeat = [&] {
- if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000)
- {
- return;
- }
- OrchestratorHeartbeatTimer.Reset();
+void
+ExecSessionRunner::SendOrchestratorHeartbeat()
+{
+ if (m_OrchestratorClientId.empty() || m_OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000)
+ {
+ return;
+ }
+ m_OrchestratorHeartbeatTimer.Reset();
+ try
+ {
+ HttpClient OrchestratorClient(m_Config.OrchestratorUrl);
+ std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", m_OrchestratorClientId));
+ }
+ catch (...)
+ {
+ }
+}
+
+void
+ExecSessionRunner::CompleteOrchestratorClient()
+{
+ if (!m_OrchestratorClientId.empty())
+ {
try
{
- HttpClient OrchestratorClient(m_OrchestratorUrl);
- std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId));
+ HttpClient OrchestratorClient(m_Config.OrchestratorUrl);
+ std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", m_OrchestratorClientId));
}
catch (...)
{
}
- };
-
- auto ClientCleanup = MakeGuard([&] {
- if (!OrchestratorClientId.empty())
- {
- try
- {
- HttpClient OrchestratorClient(m_OrchestratorUrl);
- std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId));
- }
- catch (...)
- {
- }
- }
- });
-
- // Create a queue to group all actions from this exec session
-
- CbObjectWriter Metadata;
- Metadata << "source"sv
- << "zen-exec"sv;
-
- auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save());
- const int QueueId = QueueResult.QueueId;
- if (!QueueId)
- {
- ZEN_ERROR("failed to create compute queue");
- return 1;
}
+}
- auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); });
-
- if (!m_OutputPath.empty())
+void
+ExecSessionRunner::DrainCompletedJobs()
+{
+ if (m_IsDraining.exchange(1))
{
- zen::CreateDirectories(m_OutputPath);
+ return;
}
- std::atomic<int> IsDraining{0};
+ auto _ = MakeGuard([&] { m_IsDraining.store(0, std::memory_order_release); });
- auto DrainCompletedJobs = [&] {
- if (IsDraining.exchange(1))
- {
- return;
- }
-
- auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); });
+ CbObjectWriter Cbo;
+ m_Session.GetQueueCompleted(m_QueueId, Cbo);
- CbObjectWriter Cbo;
- ComputeSession.GetQueueCompleted(QueueId, Cbo);
-
- if (CbObject Completed = Cbo.Save())
+ if (CbObject Completed = Cbo.Save())
+ {
+ for (auto& It : Completed["completed"sv])
{
- for (auto& It : Completed["completed"sv])
- {
- int32_t CompleteLsn = It.AsInt32();
+ int32_t CompleteLsn = It.AsInt32();
- CbPackage ResultPackage;
- HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage);
+ CbPackage ResultPackage;
+ HttpResponseCode Response = m_Session.GetActionResult(CompleteLsn, /* out */ ResultPackage);
- if (Response == HttpResponseCode::OK)
+ if (Response == HttpResponseCode::OK)
+ {
+ if (!m_Config.OutputPath.empty() && ResultPackage)
{
- if (!m_OutputPath.empty() && ResultPackage)
- {
- int OutputAttachments = 0;
- uint64_t OutputBytes = 0;
-
- if (!m_Binary)
- {
- // Write the root object as YAML
- ExtendableStringBuilder<4096> YamlStr;
- CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr);
-
- std::string_view Yaml = YamlStr;
- zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn),
- IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size()));
-
- // Write decompressed attachments
- auto Attachments = ResultPackage.GetAttachments();
-
- if (!Attachments.empty())
- {
- std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn);
- zen::CreateDirectories(AttDir);
-
- for (const CbAttachment& Att : Attachments)
- {
- ++OutputAttachments;
-
- IoHash AttHash = Att.GetHash();
-
- if (Att.IsCompressedBinary())
- {
- SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress();
- OutputBytes += Decompressed.GetSize();
- zen::WriteFile(AttDir / AttHash.ToHexString(),
- IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize()));
- }
- else
- {
- SharedBuffer Binary = Att.AsBinary();
- OutputBytes += Binary.GetSize();
- zen::WriteFile(AttDir / AttHash.ToHexString(),
- IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize()));
- }
- }
- }
-
- if (!m_QuietLogging)
- {
- ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)",
- m_OutputPath.string(),
- CompleteLsn,
- OutputAttachments);
- }
- }
- else
- {
- CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage);
- zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized));
-
- for (const CbAttachment& Att : ResultPackage.GetAttachments())
- {
- ++OutputAttachments;
- OutputBytes += Att.AsBinary().GetSize();
- }
-
- if (!m_QuietLogging)
- {
- ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn);
- }
- }
-
- std::lock_guard Lock(SummaryLock);
- if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end())
- {
- It2->second.OutputAttachments = OutputAttachments;
- It2->second.OutputBytes = OutputBytes;
- }
- }
+ SaveResultOutput(CompleteLsn, ResultPackage);
+ }
- PendingJobs.Remove(CompleteLsn);
+ m_PendingJobs.Remove(CompleteLsn);
- ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, PendingJobs.GetSize());
- }
+ ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize());
}
}
- };
-
- // Describe workers
+ }
+}
- ZEN_CONSOLE("describing {} workers", m_WorkerMap.size());
+void
+ExecSessionRunner::SaveResultOutput(int32_t CompleteLsn, CbPackage& ResultPackage)
+{
+ int OutputAttachments = 0;
+ uint64_t OutputBytes = 0;
- for (auto Kv : m_WorkerMap)
+ if (!m_Config.Binary)
{
- CbPackage WorkerDesc = Kv.second;
+ // Write the root object as YAML
+ ExtendableStringBuilder<4096> YamlStr;
+ CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr);
- ComputeSession.RegisterWorker(WorkerDesc);
- }
+ std::string_view Yaml = YamlStr;
+ zen::WriteFile(m_Config.OutputPath / fmt::format("{}.result.yaml", CompleteLsn),
+ IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size()));
- // Then submit work items
+ // Write decompressed attachments
+ auto Attachments = ResultPackage.GetAttachments();
- int FailedWorkCounter = 0;
- size_t RemainingWorkItems = m_RecordingReader->GetActionCount();
- int SubmittedWorkItems = 0;
-
- ZEN_CONSOLE("submitting {} work items", RemainingWorkItems);
+ if (!Attachments.empty())
+ {
+ std::filesystem::path AttDir = m_Config.OutputPath / fmt::format("{}.result.attachments", CompleteLsn);
+ zen::CreateDirectories(AttDir);
- int OffsetCounter = m_Offset;
- int StrideCounter = m_Stride;
+ for (const CbAttachment& Att : Attachments)
+ {
+ ++OutputAttachments;
- auto ShouldSchedule = [&]() -> bool {
- if (m_Limit && SubmittedWorkItems >= m_Limit)
- {
- // Limit reached, ignore
+ IoHash AttHash = Att.GetHash();
- return false;
+ if (Att.IsCompressedBinary())
+ {
+ SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress();
+ OutputBytes += Decompressed.GetSize();
+ zen::WriteFile(AttDir / AttHash.ToHexString(),
+ IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize()));
+ }
+ else
+ {
+ SharedBuffer Binary = Att.AsBinary();
+ OutputBytes += Binary.GetSize();
+ zen::WriteFile(AttDir / AttHash.ToHexString(), IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize()));
+ }
+ }
}
- if (OffsetCounter && OffsetCounter--)
+ if (!m_Config.Quiet)
{
- // Still in offset, ignore
-
- return false;
+ ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", m_Config.OutputPath.string(), CompleteLsn, OutputAttachments);
}
+ }
+ else
+ {
+ CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage);
+ zen::WriteFile(m_Config.OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized));
- if (--StrideCounter == 0)
+ for (const CbAttachment& Att : ResultPackage.GetAttachments())
{
- StrideCounter = m_Stride;
-
- return true;
+ ++OutputAttachments;
+ OutputBytes += Att.AsBinary().GetSize();
}
- return false;
- };
-
- int TargetParallelism = 8;
+ if (!m_Config.Quiet)
+ {
+ ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_Config.OutputPath.string(), CompleteLsn);
+ }
+ }
- if (OffsetCounter || StrideCounter || m_Limit)
+ std::lock_guard Lock(m_SummaryLock);
+ if (auto It2 = m_SummaryEntries.find(CompleteLsn); It2 != m_SummaryEntries.end())
{
- TargetParallelism = 1;
+ It2->second.OutputAttachments = OutputAttachments;
+ It2->second.OutputBytes = OutputBytes;
}
+}
- std::atomic<int> RecordingIndex{0};
+void
+ExecSessionRunner::SaveActionOutput(int32_t Lsn, int RecordingIndex, const IoHash& ActionId, const CbObject& ActionObject)
+{
+ ActionSummaryEntry Entry;
+ Entry.Lsn = Lsn;
+ Entry.RecordingIndex = RecordingIndex;
+ Entry.ActionId = ActionId;
+ Entry.FunctionName = std::string(ActionObject["Function"sv].AsString());
- m_RecordingReader->IterateActions(
- [&](CbObject ActionObject, const IoHash& ActionId) {
- // Enqueue job
+ if (!m_Config.Binary)
+ {
+ // Write action object as YAML
+ ExtendableStringBuilder<4096> YamlStr;
+ CompactBinaryToYaml(ActionObject, YamlStr);
- const int CurrentRecordingIndex = RecordingIndex++;
+ std::string_view Yaml = YamlStr;
+ zen::WriteFile(m_Config.OutputPath / fmt::format("{}.action.yaml", Lsn), IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size()));
- Stopwatch SubmitTimer;
+ // Write decompressed input attachments
+ std::filesystem::path AttDir = m_Config.OutputPath / fmt::format("{}.action.attachments", Lsn);
+ bool AttDirCreated = false;
- const int Priority = 0;
+ ActionObject.IterateAttachments([&](CbFieldView Field) {
+ IoHash AttachCid = Field.AsAttachment();
+ ++Entry.InputAttachments;
- if (ShouldSchedule())
+ if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachCid))
{
- if (m_VerboseLogging)
- {
- int AttachmentCount = 0;
- uint64_t AttachmentBytes = 0;
- eastl::hash_set<IoHash> ReferencedChunks;
+ IoHash RawHash;
+ uint64_t RawSize = 0;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize);
+ SharedBuffer Decompressed = Compressed.Decompress();
- ActionObject.IterateAttachments([&](CbFieldView Field) {
- IoHash AttachData = Field.AsAttachment();
+ Entry.InputBytes += Decompressed.GetSize();
- ReferencedChunks.insert(AttachData);
- ++AttachmentCount;
-
- if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData))
- {
- AttachmentBytes += ChunkData.GetSize();
- }
- });
-
- zen::ExtendableStringBuilder<1024> ObjStr;
- zen::CompactBinaryToJson(ActionObject, ObjStr);
- ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}",
- ActionId,
- AttachmentCount,
- NiceBytes(AttachmentBytes),
- ObjStr);
- }
-
- if (m_DumpActions)
+ if (!AttDirCreated)
{
- int AttachmentCount = 0;
- uint64_t AttachmentBytes = 0;
-
- ActionObject.IterateAttachments([&](CbFieldView Field) {
- IoHash AttachData = Field.AsAttachment();
-
- ++AttachmentCount;
-
- if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData))
- {
- AttachmentBytes += ChunkData.GetSize();
- }
- });
-
- zen::ExtendableStringBuilder<1024> ObjStr;
- zen::CompactBinaryToYaml(ActionObject, ObjStr);
- ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr);
+ zen::CreateDirectories(AttDir);
+ AttDirCreated = true;
}
- if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult =
- ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority))
- {
- const int32_t LsnField = EnqueueResult.Lsn;
-
- --RemainingWorkItems;
- ++SubmittedWorkItems;
-
- if (!m_QuietLogging)
- {
- ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining",
- SubmittedWorkItems,
- LsnField,
- NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
- RemainingWorkItems);
- }
-
- if (!m_OutputPath.empty())
- {
- ActionSummaryEntry Entry;
- Entry.Lsn = LsnField;
- Entry.RecordingIndex = CurrentRecordingIndex;
- Entry.ActionId = ActionId;
- Entry.FunctionName = std::string(ActionObject["Function"sv].AsString());
-
- if (!m_Binary)
- {
- // Write action object as YAML
- ExtendableStringBuilder<4096> YamlStr;
- CompactBinaryToYaml(ActionObject, YamlStr);
-
- std::string_view Yaml = YamlStr;
- zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField),
- IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size()));
-
- // Write decompressed input attachments
- std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.action.attachments", LsnField);
- bool AttDirCreated = false;
-
- ActionObject.IterateAttachments([&](CbFieldView Field) {
- IoHash AttachCid = Field.AsAttachment();
- ++Entry.InputAttachments;
-
- if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid))
- {
- IoHash RawHash;
- uint64_t RawSize = 0;
- CompressedBuffer Compressed =
- CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize);
- SharedBuffer Decompressed = Compressed.Decompress();
-
- Entry.InputBytes += Decompressed.GetSize();
-
- if (!AttDirCreated)
- {
- zen::CreateDirectories(AttDir);
- AttDirCreated = true;
- }
-
- zen::WriteFile(AttDir / AttachCid.ToHexString(),
- IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize()));
- }
- });
-
- if (!m_QuietLogging)
- {
- ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)",
- m_OutputPath.string(),
- LsnField,
- Entry.InputAttachments);
- }
- }
- else
- {
- // Build a CbPackage from the action and write as .pkg
- CbPackage ActionPackage;
- ActionPackage.SetObject(ActionObject);
-
- ActionObject.IterateAttachments([&](CbFieldView Field) {
- IoHash AttachCid = Field.AsAttachment();
- ++Entry.InputAttachments;
-
- if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid))
- {
- IoHash RawHash;
- uint64_t RawSize = 0;
- CompressedBuffer Compressed =
- CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize);
-
- Entry.InputBytes += ChunkData.GetSize();
- ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash));
- }
- });
-
- CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage);
- zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized));
-
- if (!m_QuietLogging)
- {
- ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField);
- }
- }
-
- std::lock_guard Lock(SummaryLock);
- SummaryEntries.emplace(LsnField, std::move(Entry));
- }
+ zen::WriteFile(AttDir / AttachCid.ToHexString(), IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize()));
+ }
+ });
- PendingJobs.Insert(LsnField);
- }
- else
- {
- if (!m_QuietLogging)
- {
- std::string_view FunctionName = ActionObject["Function"sv].AsString();
- const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid();
- const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid();
+ if (!m_Config.Quiet)
+ {
+ ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", m_Config.OutputPath.string(), Lsn, Entry.InputAttachments);
+ }
+ }
+ else
+ {
+ // Build a CbPackage from the action and write as .pkg
+ CbPackage ActionPackage;
+ ActionPackage.SetObject(ActionObject);
- ZEN_ERROR(
- "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). Work "
- "descriptor "
- "at: 'file://{}'",
- std::string(FunctionName),
- FunctionVersion,
- BuildSystemVersion,
- "<null>");
+ ActionObject.IterateAttachments([&](CbFieldView Field) {
+ IoHash AttachCid = Field.AsAttachment();
+ ++Entry.InputAttachments;
- EmitFunctionListOnce(m_FunctionList);
- }
+ if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachCid))
+ {
+ IoHash RawHash;
+ uint64_t RawSize = 0;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize);
- ++FailedWorkCounter;
- }
+ Entry.InputBytes += ChunkData.GetSize();
+ ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash));
}
+ });
- // Check for completed work
+ CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage);
+ zen::WriteFile(m_Config.OutputPath / fmt::format("{}.action.pkg", Lsn), std::move(Serialized));
- DrainCompletedJobs();
- SendOrchestratorHeartbeat();
- },
- TargetParallelism);
+ if (!m_Config.Quiet)
+ {
+ ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_Config.OutputPath.string(), Lsn);
+ }
+ }
- // Wait until all pending work is complete
+ std::lock_guard Lock(m_SummaryLock);
+ m_SummaryEntries.emplace(Lsn, std::move(Entry));
+}
- while (!PendingJobs.IsEmpty())
+void
+ExecSessionRunner::WriteSummaryFiles()
+{
+ if (m_Config.OutputPath.empty() || m_SummaryEntries.empty())
{
- // TODO: improve this logic
- zen::Sleep(500);
-
- DrainCompletedJobs();
- SendOrchestratorHeartbeat();
+ return;
}
// Merge timing data from queue history into summary entries
- if (!SummaryEntries.empty())
+ // RunnerAction::State indices (can't include functionrunner.h from here)
+ constexpr int kStateNew = 0;
+ constexpr int kStatePending = 1;
+ constexpr int kStateRunning = 3;
+ constexpr int kStateCompleted = 4; // first terminal state
+ constexpr int kStateCount = 8;
+
+ for (const auto& HistEntry : m_Session.GetQueueHistory(m_QueueId, 0))
{
- // RunnerAction::State indices (can't include functionrunner.h from here)
- constexpr int kStateNew = 0;
- constexpr int kStatePending = 1;
- constexpr int kStateRunning = 3;
- constexpr int kStateCompleted = 4; // first terminal state
- constexpr int kStateCount = 8;
-
- for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0))
+ std::lock_guard Lock(m_SummaryLock);
+ if (auto It = m_SummaryEntries.find(HistEntry.Lsn); It != m_SummaryEntries.end())
{
- std::lock_guard Lock(SummaryLock);
- if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end())
+ // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled)
+ uint64_t EndTick = 0;
+ for (int S = kStateCompleted; S < kStateCount; ++S)
{
- // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled)
- uint64_t EndTick = 0;
- for (int S = kStateCompleted; S < kStateCount; ++S)
+ if (HistEntry.Timestamps[S] != 0)
{
- if (HistEntry.Timestamps[S] != 0)
- {
- EndTick = HistEntry.Timestamps[S];
- break;
- }
+ EndTick = HistEntry.Timestamps[S];
+ break;
}
- uint64_t StartTick = HistEntry.Timestamps[kStateNew];
- if (EndTick > StartTick)
- {
- It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond));
- }
- It->second.CpuSeconds = HistEntry.CpuSeconds;
- It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending];
- It->second.StartedTicks = HistEntry.Timestamps[kStateRunning];
- It->second.ExecutionLocation = HistEntry.ExecutionLocation;
}
+ uint64_t StartTick = HistEntry.Timestamps[kStateNew];
+ if (EndTick > StartTick)
+ {
+ It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond));
+ }
+ It->second.CpuSeconds = HistEntry.CpuSeconds;
+ It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending];
+ It->second.StartedTicks = HistEntry.Timestamps[kStateRunning];
+ It->second.ExecutionLocation = HistEntry.ExecutionLocation;
}
}
- // Write summary file if output path is set
+ // Sort entries by recording index
- if (!m_OutputPath.empty() && !SummaryEntries.empty())
+ std::vector<ActionSummaryEntry> Sorted;
+ Sorted.reserve(m_SummaryEntries.size());
+ for (auto& [_, Entry] : m_SummaryEntries)
{
- std::vector<ActionSummaryEntry> Sorted;
- Sorted.reserve(SummaryEntries.size());
- for (auto& [_, Entry] : SummaryEntries)
- {
- Sorted.push_back(std::move(Entry));
- }
-
- std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) {
- return A.RecordingIndex < B.RecordingIndex;
- });
+ Sorted.push_back(std::move(Entry));
+ }
- auto FormatTimestamp = [](uint64_t Ticks) -> std::string {
- if (Ticks == 0)
- {
- return "-";
- }
- return DateTime(Ticks).ToString("%H:%M:%S.%s");
- };
-
- ExtendableStringBuilder<4096> Summary;
- Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n",
- "LSN",
- "Index",
- "ActionId",
- "Function",
- "InAtt",
- "InBytes",
- "OutAtt",
- "OutBytes",
- "Wall(s)",
- "CPU(s)",
- "Submitted",
- "Started",
- "Location"));
- Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- ""));
+ std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) {
+ return A.RecordingIndex < B.RecordingIndex;
+ });
- for (const ActionSummaryEntry& Entry : Sorted)
+ auto FormatTimestamp = [](uint64_t Ticks) -> std::string {
+ if (Ticks == 0)
{
- Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n",
- Entry.Lsn,
- Entry.RecordingIndex,
- Entry.ActionId,
- Entry.FunctionName,
- Entry.InputAttachments,
- NiceBytes(Entry.InputBytes),
- Entry.OutputAttachments,
- NiceBytes(Entry.OutputBytes),
- Entry.WallSeconds,
- Entry.CpuSeconds,
- FormatTimestamp(Entry.SubmittedTicks),
- FormatTimestamp(Entry.StartedTicks),
- Entry.ExecutionLocation));
+ return "-";
}
+ return DateTime(Ticks).ToString("%H:%M:%S.%s");
+ };
- std::filesystem::path SummaryPath = m_OutputPath / "summary.txt";
- std::string_view SummaryStr = Summary;
- zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size()));
+ // Write summary.txt
+
+ ExtendableStringBuilder<4096> Summary;
+ Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n",
+ "LSN",
+ "Index",
+ "ActionId",
+ "Function",
+ "InAtt",
+ "InBytes",
+ "OutAtt",
+ "OutBytes",
+ "Wall(s)",
+ "CPU(s)",
+ "Submitted",
+ "Started",
+ "Location"));
+ Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ""));
+
+ for (const ActionSummaryEntry& Entry : Sorted)
+ {
+ Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n",
+ Entry.Lsn,
+ Entry.RecordingIndex,
+ Entry.ActionId,
+ Entry.FunctionName,
+ Entry.InputAttachments,
+ NiceBytes(Entry.InputBytes),
+ Entry.OutputAttachments,
+ NiceBytes(Entry.OutputBytes),
+ Entry.WallSeconds,
+ Entry.CpuSeconds,
+ FormatTimestamp(Entry.SubmittedTicks),
+ FormatTimestamp(Entry.StartedTicks),
+ Entry.ExecutionLocation));
+ }
- ZEN_CONSOLE("wrote summary to {}", SummaryPath.string());
+ std::filesystem::path SummaryPath = m_Config.OutputPath / "summary.txt";
+ std::string_view SummaryStr = Summary;
+ zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size()));
- if (!m_Binary)
- {
- auto EscapeHtml = [](std::string_view Input) -> std::string {
- std::string Out;
- Out.reserve(Input.size());
- for (char C : Input)
- {
- switch (C)
- {
- case '&':
- Out += "&amp;";
- break;
- case '<':
- Out += "&lt;";
- break;
- case '>':
- Out += "&gt;";
- break;
- case '"':
- Out += "&quot;";
- break;
- case '\'':
- Out += "&#39;";
- break;
- default:
- Out += C;
- }
- }
- return Out;
- };
+ ZEN_CONSOLE("wrote summary to {}", SummaryPath.string());
- auto EscapeJson = [](std::string_view Input) -> std::string {
- std::string Out;
- Out.reserve(Input.size());
- for (char C : Input)
- {
- switch (C)
- {
- case '"':
- Out += "\\\"";
- break;
- case '\\':
- Out += "\\\\";
- break;
- case '\n':
- Out += "\\n";
- break;
- case '\r':
- Out += "\\r";
- break;
- case '\t':
- Out += "\\t";
- break;
- default:
- if (static_cast<unsigned char>(C) < 0x20)
- {
- Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C)));
- }
- else
- {
- Out += C;
- }
- }
- }
- return Out;
- };
+ // Write summary.html
- ExtendableStringBuilder<8192> Html;
+ if (!m_Config.Binary)
+ {
+ ExtendableStringBuilder<8192> Html;
- Html.Append(std::string_view(R"(<!DOCTYPE html>
+ Html.Append(std::string_view(R"(<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Exec Summary</title>
<style>
body{font-family:system-ui,sans-serif;margin:20px;background:#fafafa}
@@ -1007,51 +682,50 @@ a:hover{text-decoration:underline}
const DATA=[
)"));
- std::string_view ResultExt = ".result.yaml";
- std::string_view ActionExt = ".action.yaml";
+ std::string_view ResultExt = ".result.yaml";
+ std::string_view ActionExt = ".action.yaml";
- for (const ActionSummaryEntry& Entry : Sorted)
+ for (const ActionSummaryEntry& Entry : Sorted)
+ {
+ std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName));
+ std::string ActionIdStr = Entry.ActionId.ToHexString();
+ std::string ActionLink;
+ if (!ActionExt.empty())
{
- std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName));
- std::string ActionIdStr = Entry.ActionId.ToHexString();
- std::string ActionLink;
- if (!ActionExt.empty())
- {
- ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt));
- }
-
- // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes,
- // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink,
- // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay,
- // 17=location
- Html.Append(
- fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n",
- Entry.Lsn,
- Entry.RecordingIndex,
- ActionIdStr,
- SafeName,
- Entry.InputAttachments,
- Entry.InputBytes,
- Entry.OutputAttachments,
- Entry.OutputBytes,
- Entry.WallSeconds,
- Entry.CpuSeconds,
- EscapeJson(NiceBytes(Entry.InputBytes)),
- EscapeJson(NiceBytes(Entry.OutputBytes)),
- ActionLink,
- Entry.SubmittedTicks,
- Entry.StartedTicks,
- FormatTimestamp(Entry.SubmittedTicks),
- FormatTimestamp(Entry.StartedTicks),
- EscapeJson(EscapeHtml(Entry.ExecutionLocation))));
+ ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt));
}
- Html.Append(fmt::format(R"(];
+ // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes,
+ // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink,
+ // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay,
+ // 17=location
+ Html.Append(fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n",
+ Entry.Lsn,
+ Entry.RecordingIndex,
+ ActionIdStr,
+ SafeName,
+ Entry.InputAttachments,
+ Entry.InputBytes,
+ Entry.OutputAttachments,
+ Entry.OutputBytes,
+ Entry.WallSeconds,
+ Entry.CpuSeconds,
+ EscapeJson(NiceBytes(Entry.InputBytes)),
+ EscapeJson(NiceBytes(Entry.OutputBytes)),
+ ActionLink,
+ Entry.SubmittedTicks,
+ Entry.StartedTicks,
+ FormatTimestamp(Entry.SubmittedTicks),
+ FormatTimestamp(Entry.StartedTicks),
+ EscapeJson(EscapeHtml(Entry.ExecutionLocation))));
+ }
+
+ Html.Append(fmt::format(R"(];
const RESULT_EXT="{}";
)",
- ResultExt));
+ ResultExt));
- Html.Append(std::string_view(R"JS((function(){
+ Html.Append(std::string_view(R"JS((function(){
const ROW_H=33,BUF=20;
const container=document.getElementById("container");
const tbody=container.querySelector("tbody");
@@ -1158,14 +832,244 @@ document.getElementById("csvBtn").addEventListener("click",()=>{
</script></body></html>
)JS"));
- std::filesystem::path HtmlPath = m_OutputPath / "summary.html";
- std::string_view HtmlStr = Html;
- zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size()));
+ std::filesystem::path HtmlPath = m_Config.OutputPath / "summary.html";
+ std::string_view HtmlStr = Html;
+ zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size()));
+
+ ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string());
+ }
+}
+
+void
+ExecSessionRunner::EmitFunctionListOnce()  // Calls ExecCommand::EmitFunctionList(m_Config.FunctionList) on the first call only.
+{
+ if (!m_FunctionListEmittedOnce)  // NOTE(review): flag is read/written without visible synchronisation - confirm callers are serialised
+ {
+ ExecCommand::EmitFunctionList(m_Config.FunctionList);
+ m_FunctionListEmittedOnce = true;  // every later call becomes a no-op
+ }
+}
+
+int
+ExecSessionRunner::Run()  // Replays the recording through m_Session; returns 1 if queue creation or any enqueue failed, otherwise 0.
+{
+ m_Session.WaitUntilReady();
+
+ // Register as a client with the orchestrator (best-effort)
+
+ m_OrchestratorClientId = RegisterOrchestratorClient();
+
+ auto ClientCleanup = MakeGuard([&] { CompleteOrchestratorClient(); });  // presumably invoked when the guard goes out of scope - confirm MakeGuard semantics
+
+ // Create a queue to group all actions from this exec session
+
+ CbObjectWriter Metadata;
+ Metadata << "source"sv
+ << "zen-exec"sv;
+
+ auto QueueResult = m_Session.CreateQueue("zen-exec", Metadata.Save());
+ const int QueueId = QueueResult.QueueId;
+ if (!QueueId)  // a zero QueueId is treated as "queue creation failed"
+ {
+ ZEN_ERROR("failed to create compute queue");
+ return 1;
+ }
+
+ m_QueueId = QueueId;  // stored in the member so other ExecSessionRunner methods can refer to the queue
+
+ auto QueueCleanup = MakeGuard([&] { m_Session.DeleteQueue(QueueId); });
+
+ if (!m_Config.OutputPath.empty())
+ {
+ zen::CreateDirectories(m_Config.OutputPath);
+ }
+
+ // Describe workers
+
+ ZEN_CONSOLE("describing {} workers", m_Config.WorkerMap.size());
+
+ for (auto Kv : m_Config.WorkerMap)  // NOTE(review): iterates by value, copying every map entry; `const auto&` would avoid the copy
+ {
+ CbPackage WorkerDesc = Kv.second;
+
+ m_Session.RegisterWorker(WorkerDesc);
+ }
+
+ // Then submit work items
+
+ int FailedWorkCounter = 0;
+ size_t RemainingWorkItems = m_Config.RecordingReader.GetActionCount();
+ int SubmittedWorkItems = 0;
+
+ ZEN_CONSOLE("submitting {} work items", RemainingWorkItems);
+
+ int OffsetCounter = m_Config.Offset;
+ int StrideCounter = m_Config.Stride;
- ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string());
+ auto ShouldSchedule = [&]() -> bool {  // applies the Limit / Offset / Stride filters, in that order, to the current action
+ if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit)  // a Limit of 0 means "no limit"
+ {
+ // Limit reached, ignore
+
+ return false;
+ }
+
+ if (OffsetCounter && OffsetCounter--)  // skips one action per call while the counter is non-zero; short-circuit stops the decrement at 0
+ {
+ // Still in offset, ignore
+
+ return false;
+ }
+
+ if (--StrideCounter == 0)  // NOTE(review): with Stride == 0 the counter goes negative and this never fires - confirm Stride is always >= 1
+ {
+ StrideCounter = m_Config.Stride;  // re-arm for the next stride window
+
+ return true;
+ }
+
+ return false;
+ };
+
+ int TargetParallelism = 8;
+
+ if (OffsetCounter || StrideCounter || m_Config.Limit)  // NOTE(review): StrideCounter is non-zero for any Stride != 0 (including 1), so this is taken in that case too - confirm intended
+ {
+ TargetParallelism = 1;
+ }
+
+ std::atomic<int> RecordingIndex{0};
+
+ m_Config.RecordingReader.IterateActions(
+ [&](CbObject ActionObject, const IoHash& ActionId) {
+ // Enqueue job
+
+ const int CurrentRecordingIndex = RecordingIndex++;  // position of this action in iteration order, counted even when the action is filtered out
+
+ Stopwatch SubmitTimer;
+
+ const int Priority = 0;
+
+ if (ShouldSchedule())
+ {
+ if (m_Config.Verbose)
+ {
+ int AttachmentCount = 0;
+ uint64_t AttachmentBytes = 0;
+ eastl::hash_set<IoHash> ReferencedChunks;  // filled below but not read again in this function
+
+ ActionObject.IterateAttachments([&](CbFieldView Field) {
+ IoHash AttachData = Field.AsAttachment();
+
+ ReferencedChunks.insert(AttachData);
+ ++AttachmentCount;
+
+ if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachData))
+ {
+ AttachmentBytes += ChunkData.GetSize();
+ }
+ });
+
+ zen::ExtendableStringBuilder<1024> ObjStr;
+ zen::CompactBinaryToJson(ActionObject, ObjStr);
+ ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}",
+ ActionId,
+ AttachmentCount,
+ NiceBytes(AttachmentBytes),
+ ObjStr);
+ }
+
+ if (m_Config.DumpActions)
+ {
+ int AttachmentCount = 0;
+ uint64_t AttachmentBytes = 0;
+
+ ActionObject.IterateAttachments([&](CbFieldView Field) {
+ IoHash AttachData = Field.AsAttachment();
+
+ ++AttachmentCount;
+
+ if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachData))
+ {
+ AttachmentBytes += ChunkData.GetSize();
+ }
+ });
+
+ zen::ExtendableStringBuilder<1024> ObjStr;
+ zen::CompactBinaryToYaml(ActionObject, ObjStr);
+ ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr);
+ }
+
+ if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult =
+ m_Session.EnqueueActionToQueue(QueueId, ActionObject, Priority))
+ {
+ const int32_t LsnField = EnqueueResult.Lsn;
+
+ --RemainingWorkItems;  // NOTE(review): plain counters mutated inside the IterateActions callback - safe only if the callback never runs concurrently; verify
+ ++SubmittedWorkItems;
+
+ if (!m_Config.Quiet)
+ {
+ ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining",
+ SubmittedWorkItems,
+ LsnField,
+ NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
+ RemainingWorkItems);
+ }
+
+ if (!m_Config.OutputPath.empty())
+ {
+ SaveActionOutput(LsnField, CurrentRecordingIndex, ActionId, ActionObject);
+ }
+
+ m_PendingJobs.Insert(LsnField);  // tracked until m_PendingJobs is empty again (see the wait loop below)
+ }
+ else
+ {
+ if (!m_Config.Quiet)
+ {
+ std::string_view FunctionName = ActionObject["Function"sv].AsString();
+ const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid();
+ const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid();
+
+ ZEN_ERROR(
+ "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). Work "
+ "descriptor "
+ "at: 'file://{}'",
+ std::string(FunctionName),
+ FunctionVersion,
+ BuildSystemVersion,
+ "<null>");
+
+ EmitFunctionListOnce();
+ }
+
+ ++FailedWorkCounter;  // counted even in quiet mode; a non-zero count makes Run() return 1
+ }
+ }
+
+ // Check for completed work
+
+ DrainCompletedJobs();
+ SendOrchestratorHeartbeat();
+ },
+ TargetParallelism);
+
+ // Wait until all pending work is complete
+
+ while (!m_PendingJobs.IsEmpty())
+ {
+ // TODO: improve this logic
+ zen::Sleep(500);  // presumably milliseconds - confirm the unit of zen::Sleep
+
+ DrainCompletedJobs();
+ SendOrchestratorHeartbeat();
}
+ // Write summary files
+
+ WriteSummaryFiles();  // that function returns early when no output path is set or nothing was recorded
+
if (FailedWorkCounter)
{
return 1;
@@ -1174,37 +1078,91 @@ document.getElementById("csvBtn").addEventListener("click",()=>{
return 0;
}
-int
-ExecCommand::LocalMessagingExecute()
+//////////////////////////////////////////////////////////////////////////
+// ExecHttpSubCmd
+
+ExecHttpSubCmd::ExecHttpSubCmd(ExecCommand& Parent) : ZenSubCmdBase("http", "Forward requests to HTTP compute service"), m_Parent(Parent)  // "http" subcommand; keeps a reference to the owning ExecCommand
{
- // Non-HTTP work submission path
+ SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>");  // binds option "hosturl" (short form "u") to m_HostName
+}
+
+void
+ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)  // Runs the parent's session against a remote runner at m_HostName and logs the outcome.
+{
+ m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName);
- // To be reimplemented using final transport
+ ZEN_ASSERT(m_Parent.m_ChunkResolver);  // set by ExecCommand::OnParentOptionsParsed
+ ChunkResolver& Resolver = *m_Parent.m_ChunkResolver;
- return 0;
+ std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");  // resolved against the current working directory
+
+ zen::compute::ComputeServiceSession ComputeSession(Resolver);
+ ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName);
+
+ Stopwatch ExecTimer;
+ int ReturnValue = m_Parent.RunSession(ComputeSession);  // NOTE(review): Run() returns void, so this value is only logged below - confirm failures still reach the process exit code
+
+ ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs()));
+
+ if (!ReturnValue)
+ {
+ ZEN_CONSOLE("all work items completed successfully");
+ }
+ else
+ {
+ ZEN_CONSOLE("some work items failed (code {})", ReturnValue);
+ }
}
//////////////////////////////////////////////////////////////////////////
+// ExecInprocSubCmd
-int
-ExecCommand::HttpExecute()
+ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent)  // "inproc" subcommand; keeps a reference to the owning ExecCommand
{
- ZEN_ASSERT(m_ChunkResolver);
- ChunkResolver& Resolver = *m_ChunkResolver;
+}  // no subcommand-specific options are registered
- std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
+void
+ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)  // Runs the parent's session with a local runner added to the compute session and logs the outcome.
+{
+ ZEN_ASSERT(m_Parent.m_ChunkResolver);  // set by ExecCommand::OnParentOptionsParsed
+ ChunkResolver& Resolver = *m_Parent.m_ChunkResolver;
zen::compute::ComputeServiceSession ComputeSession(Resolver);
- ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName);
- return ExecUsingSession(ComputeSession);
+ std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");  // resolved against the current working directory
+ ComputeSession.AddLocalRunner(Resolver, TempPath);
+
+ Stopwatch ExecTimer;
+ int ReturnValue = m_Parent.RunSession(ComputeSession);  // NOTE(review): Run() returns void, so this value is only logged below - confirm failures still reach the process exit code
+
+ ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs()));
+
+ if (!ReturnValue)
+ {
+ ZEN_CONSOLE("all work items completed successfully");
+ }
+ else
+ {
+ ZEN_CONSOLE("some work items failed (code {})", ReturnValue);
+ }
}
-int
-ExecCommand::BeaconExecute()
+//////////////////////////////////////////////////////////////////////////
+// ExecBeaconSubCmd
+
+ExecBeaconSubCmd::ExecBeaconSubCmd(ExecCommand& Parent)  // "beacon" subcommand; keeps a reference to the owning ExecCommand
+: ZenSubCmdBase("beacon", "Execute via beacon/orchestrator discovery")
+, m_Parent(Parent)
{
- ZEN_ASSERT(m_ChunkResolver);
- ChunkResolver& Resolver = *m_ChunkResolver;
+ SubOptions().add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>");  // long option only; value stored in m_OrchestratorUrl
+ SubOptions().add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>");  // long option only; value stored in m_BeaconPath
+}
+
+void
+ExecBeaconSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
+{
+ ZEN_ASSERT(m_Parent.m_ChunkResolver);
+ ChunkResolver& Resolver = *m_Parent.m_ChunkResolver;
std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
zen::compute::ComputeServiceSession ComputeSession(Resolver);
@@ -1221,49 +1179,36 @@ ExecCommand::BeaconExecute()
ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558");
}
- return ExecUsingSession(ComputeSession);
-}
-
-//////////////////////////////////////////////////////////////////////////
+ Stopwatch ExecTimer;
+ int ReturnValue = m_Parent.RunSession(ComputeSession, m_OrchestratorUrl);
-void
-ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId)
-{
- const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid();
+ ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs()));
- for (auto& Item : WorkerDesc["functions"sv])
+ if (!ReturnValue)
{
- CbObjectView Function = Item.AsObjectView();
-
- std::string_view FunctionName = Function["name"sv].AsString();
- const Guid FunctionVersion = Function["version"sv].AsUuid();
-
- m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName},
- .FunctionVersion = FunctionVersion,
- .BuildSystemVersion = WorkerBuildSystemVersion,
- .WorkerId = WorkerId});
+ ZEN_CONSOLE("all work items completed successfully");
+ }
+ else
+ {
+ ZEN_CONSOLE("some work items failed (code {})", ReturnValue);
}
}
-void
-ExecCommand::EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList)
-{
- if (m_FunctionListEmittedOnce == false)
- {
- EmitFunctionList(FunctionList);
+//////////////////////////////////////////////////////////////////////////
+// ExecDumpSubCmd
- m_FunctionListEmittedOnce = true;
- }
+ExecDumpSubCmd::ExecDumpSubCmd(ExecCommand& Parent) : ZenSubCmdBase("dump", "Dump high level information about actions"), m_Parent(Parent)  // "dump" subcommand; keeps a reference to the owning ExecCommand
+{  // no subcommand-specific options are registered
}
-int
-ExecCommand::DumpWorkItems()
+void
+ExecDumpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)
{
std::atomic<int> EmittedCount{0};
eastl::hash_map<IoHash, uint64_t> SeenAttachments; // Attachment CID -> count of references
- m_RecordingReader->IterateActions(
+ m_Parent.m_RecordingReader->IterateActions(
[&](CbObject ActionObject, const IoHash& ActionId) {
eastl::hash_map<IoHash, CompressedBuffer> Attachments;
@@ -1272,7 +1217,7 @@ ExecCommand::DumpWorkItems()
ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) {
const IoHash AttachmentCid = AttachmentField.GetValue().AsHash();
- IoBuffer AttachmentData = m_ChunkResolver->FindChunkByCid(AttachmentCid);
+ IoBuffer AttachmentData = m_Parent.m_ChunkResolver->FindChunkByCid(AttachmentCid);
IoHash RawHash;
uint64_t RawSize = 0;
CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize);
@@ -1322,36 +1267,191 @@ ExecCommand::DumpWorkItems()
{
ZEN_CONSOLE("{} attachments with {} references", Cids.size(), RefCount);
}
+}
- return 0;
+//////////////////////////////////////////////////////////////////////////
+// ExecBuildlogSubCmd
+
+ExecBuildlogSubCmd::ExecBuildlogSubCmd(ExecCommand& Parent) : ZenSubCmdBase("buildlog", "Generate build actions log"), m_Parent(Parent)  // "buildlog" subcommand; keeps a reference to the owning ExecCommand
+{  // no subcommand-specific options; Run() reads the parent's m_RecordingLogPath
+}
+
+void
+ExecBuildlogSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/)  // Validates the log path, then hits ZEN_NOT_IMPLEMENTED; the code after it starts a recording and runs the session.
+{
+ ZEN_ASSERT(m_Parent.m_ChunkResolver);  // set by ExecCommand::OnParentOptionsParsed
+ ChunkResolver& Resolver = *m_Parent.m_ChunkResolver;
+
+ if (std::filesystem::exists(m_Parent.m_RecordingLogPath))  // refuse to reuse an existing log location
+ {
+ m_Parent.ThrowOptionError(fmt::format("recording log directory '{}' already exists!", m_Parent.m_RecordingLogPath));
+ }
+
+ ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!");  // NOTE(review): presumably does not return, which would make everything below unreachable - confirm
+
+ std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");  // NOTE(review): TempPath is not used anywhere in this function
+
+ zen::compute::ComputeServiceSession ComputeSession(Resolver);
+ ComputeSession.StartRecording(Resolver, m_Parent.m_RecordingLogPath);
+
+ Stopwatch ExecTimer;
+ int ReturnValue = m_Parent.RunSession(ComputeSession);  // NOTE(review): Run() returns void, so this value is only logged below - confirm failures still reach the process exit code
+
+ ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs()));
+
+ if (!ReturnValue)
+ {
+ ZEN_CONSOLE("all work items completed successfully");
+ }
+ else
+ {
+ ZEN_CONSOLE("some work items failed (code {})", ReturnValue);
+ }
}
//////////////////////////////////////////////////////////////////////////
+// ExecCommand
-int
-ExecCommand::BuildActionsLog()
+ExecCommand::ExecCommand()  // Registers the options shared by all exec subcommands, then the five subcommands themselves.
{
- ZEN_ASSERT(m_ChunkResolver);
- ChunkResolver& Resolver = *m_ChunkResolver;
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("replay", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>");  // NOTE(review): help mentions .actionlog files, but OnParentOptionsParsed rejects any non-directory path
+ m_Options.add_option("replay", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>");  // read by ExecBuildlogSubCmd::Run
+ m_Options.add_option("replay", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>");
+ m_Options.add_option("replay", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>");
+ m_Options.add_option("replay", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>");
+ m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>");  // overridden by global verbose mode in OnParentOptionsParsed
+ m_Options.add_option("output",
+ "",
+ "dump-actions",
+ "Dump each action to console as it is dispatched",
+ cxxopts::value(m_DumpActions),
+ "<bool>");
+ m_Options.add_option("output", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>");
+ m_Options.add_option("output", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>");
+ m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), "");  // placed in its own "__hidden__" group; receives the subcommand name
+ m_Options.parse_positional({"subcommand"});  // the first positional argument fills "subcommand"
+
+ AddSubCommand(m_HttpSubCmd);
+ AddSubCommand(m_InprocSubCmd);
+ AddSubCommand(m_BeaconSubCmd);
+ AddSubCommand(m_DumpSubCmd);
+ AddSubCommand(m_BuildlogSubCmd);
+}
+ExecCommand::~ExecCommand()  // Empty body, defined out of line; members clean themselves up.
+{
+}
+
+bool
+ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions)  // Validates the replay path, opens the matching reader and collects worker functions; returns true on success.
+{
if (m_RecordingPath.empty())
{
- throw OptionParseException("need to specify recording path", m_Options.help());
+ ThrowOptionError("replay path is required!");  // presumably throws (it replaces the removed `throw` above) - confirm
}
- if (std::filesystem::exists(m_RecordingLogPath))
+ m_VerboseLogging = GlobalOptions.IsVerbose;
+ m_QuietLogging = m_Quiet && !m_VerboseLogging;  // verbose wins when both verbose and quiet are requested
+
+ // Gather information from recording path
+
+ std::filesystem::path RecordingPath{m_RecordingPath};
+
+ if (!std::filesystem::is_directory(RecordingPath))  // NOTE(review): the "path" option help text also advertises .actionlog files, which this check rejects
{
- throw OptionParseException(fmt::format("recording log directory '{}' already exists!", m_RecordingLogPath), m_Options.help());
+ ThrowOptionError("replay path should be a directory path!");
+ }
+ else
+ {
+ if (std::filesystem::is_directory(RecordingPath / "cid"))  // a "cid" subdirectory selects RecordingReader; anything else falls back to UeRecordingReader
+ {
+ m_Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath);
+ m_WorkerMap = m_Reader->ReadWorkers();
+ m_ChunkResolver = m_Reader.get();  // non-owning pointers into m_Reader
+ m_RecordingReader = m_Reader.get();
+ }
+ else
+ {
+ m_UeReader = std::make_unique<zen::compute::UeRecordingReader>(RecordingPath);
+ m_WorkerMap = m_UeReader->ReadWorkers();
+ m_ChunkResolver = m_UeReader.get();  // non-owning pointers into m_UeReader
+ m_RecordingReader = m_UeReader.get();
+ }
}
- ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!");
+ ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount());
- std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp");
+ for (auto& Kv : m_WorkerMap)  // key is the worker id, value is the worker package
+ {
+ CbObject WorkerDesc = Kv.second.GetObject();
+ const IoHash& WorkerId = Kv.first;
- zen::compute::ComputeServiceSession ComputeSession(Resolver);
- ComputeSession.StartRecording(Resolver, m_RecordingLogPath);
+ RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId);  // appends this worker's functions to m_FunctionList
+
+ if (m_VerboseLogging)
+ {
+ zen::ExtendableStringBuilder<1024> ObjStr;
+# if 0
+ zen::CompactBinaryToJson(WorkerDesc, ObjStr);
+ ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr);
+# else
+ zen::CompactBinaryToYaml(WorkerDesc, ObjStr);  // YAML is the active variant; the JSON branch above is compiled out
+ ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr);
+# endif
+ }
+ }
+
+ if (m_VerboseLogging)
+ {
+ EmitFunctionList(m_FunctionList);
+ }
+
+ return true;
+}
- return ExecUsingSession(ComputeSession);
+int
+ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl)  // Packs the parsed options into an ExecSessionConfig and returns ExecSessionRunner::Run()'s result.
+{
+ ExecSessionConfig Config{
+ .Resolver = *m_ChunkResolver,  // NOTE(review): dereferenced without a check here; the visible callers ZEN_ASSERT on m_ChunkResolver first
+ .RecordingReader = *m_RecordingReader,  // NOTE(review): no caller-side assert is visible for m_RecordingReader - it is set together with m_ChunkResolver in OnParentOptionsParsed
+ .WorkerMap = m_WorkerMap,
+ .FunctionList = m_FunctionList,
+ .OrchestratorUrl = OrchestratorUrl,
+ .OutputPath = m_OutputPath,
+ .Offset = m_Offset,
+ .Stride = m_Stride,
+ .Limit = m_Limit,
+ .Verbose = m_VerboseLogging,
+ .Quiet = m_QuietLogging,  // already folded with verbose mode: m_Quiet && !m_VerboseLogging
+ .DumpActions = m_DumpActions,
+ .Binary = m_Binary,
+ };
+
+ ExecSessionRunner Runner(ComputeSession, Config);  // Config and Runner both live only for the duration of this call
+ return Runner.Run();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId)  // Appends one FunctionDefinition to m_FunctionList per entry of WorkerDesc["functions"].
+{
+ const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid();  // read once per worker and copied into every function entry
+
+ for (auto& Item : WorkerDesc["functions"sv])
+ {
+ CbObjectView Function = Item.AsObjectView();
+
+ std::string_view FunctionName = Function["name"sv].AsString();
+ const Guid FunctionVersion = Function["version"sv].AsUuid();
+
+ m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName},  // the name is copied into an owning std::string
+ .FunctionVersion = FunctionVersion,
+ .BuildSystemVersion = WorkerBuildSystemVersion,
+ .WorkerId = WorkerId});
+ }
}
void
diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h
index 6311354c0..c55412780 100644
--- a/src/zen/cmds/exec_cmd.h
+++ b/src/zen/cmds/exec_cmd.h
@@ -11,6 +11,7 @@
#include <filesystem>
#include <functional>
+#include <memory>
#include <unordered_map>
namespace zen {
@@ -28,13 +29,79 @@ class ComputeServiceSession;
namespace zen {
+class ExecCommand;
+
+struct ExecFunctionDefinition // One function exposed by a worker, paired with the worker that provides it
+{
+ std::string FunctionName; // "name" field of the worker's function entry
+ zen::Guid FunctionVersion; // "version" field of the function entry
+ zen::Guid BuildSystemVersion; // worker-level "buildsystem_version", shared by all of that worker's functions
+ zen::IoHash WorkerId; // id of the worker description the entry was read from
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Subcommands
+
+class ExecHttpSubCmd : public ZenSubCmdBase // Subcommand of ExecCommand. NOTE(review): looks like the successor of ExecCommand::HttpExecute() — confirm in exec_cmd.cpp
+{
+public:
+ explicit ExecHttpSubCmd(ExecCommand& Parent);
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ ExecCommand& m_Parent; // non-owning reference to the parent command, used to reach its shared public state
+ std::string m_HostName; // presumably the target host for this subcommand — verify against the option registration
+};
+
+class ExecInprocSubCmd : public ZenSubCmdBase // Subcommand of ExecCommand. NOTE(review): looks like the successor of ExecCommand::InProcessExecute() — confirm in exec_cmd.cpp
+{
+public:
+ explicit ExecInprocSubCmd(ExecCommand& Parent);
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ ExecCommand& m_Parent; // non-owning reference to the parent command, used to reach its shared public state
+};
+
+class ExecBeaconSubCmd : public ZenSubCmdBase // Subcommand of ExecCommand. NOTE(review): looks like the successor of ExecCommand::BeaconExecute() — confirm in exec_cmd.cpp
+{
+public:
+ explicit ExecBeaconSubCmd(ExecCommand& Parent);
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ ExecCommand& m_Parent; // non-owning reference to the parent command, used to reach its shared public state
+ std::string m_OrchestratorUrl; // subcommand-local option storage (this member used to live on ExecCommand)
+ std::filesystem::path m_BeaconPath; // subcommand-local option storage (this member used to live on ExecCommand)
+};
+
+class ExecDumpSubCmd : public ZenSubCmdBase // Subcommand of ExecCommand. NOTE(review): looks like the successor of ExecCommand::DumpWorkItems() — confirm in exec_cmd.cpp
+{
+public:
+ explicit ExecDumpSubCmd(ExecCommand& Parent);
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ ExecCommand& m_Parent; // non-owning reference to the parent command, used to reach its shared public state
+};
+
+class ExecBuildlogSubCmd : public ZenSubCmdBase // Subcommand of ExecCommand. NOTE(review): looks like the successor of ExecCommand::BuildActionsLog() — confirm in exec_cmd.cpp
+{
+public:
+ explicit ExecBuildlogSubCmd(ExecCommand& Parent);
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ ExecCommand& m_Parent; // non-owning reference to the parent command, used to reach its shared public state
+};
+
/**
* Zen CLI command for executing functions from a recording
*
* Mostly for testing and debugging purposes
*/
-class ExecCommand : public ZenCmdBase
+class ExecCommand : public ZenCmdWithSubCommands
{
public:
ExecCommand();
@@ -43,57 +110,47 @@ public:
static constexpr char Name[] = "exec";
static constexpr char Description[] = "Execute functions from a recording";
- virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
- virtual cxxopts::Options& Options() override { return m_Options; }
+ cxxopts::Options& Options() override { return m_Options; }
+
+ // Shared state & helpers (public for subcommand access)
+
+ using FunctionDefinition = ExecFunctionDefinition;
+
+ static void EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList);
+ void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId);
+
+ int RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl = {});
+
+ std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap;
+ std::vector<FunctionDefinition> m_FunctionList;
+ zen::ChunkResolver* m_ChunkResolver = nullptr;
+ zen::compute::RecordingReaderBase* m_RecordingReader = nullptr;
+ bool m_VerboseLogging = false;
+ bool m_QuietLogging = false;
+ bool m_DumpActions = false;
+ std::filesystem::path m_OutputPath;
+ bool m_Binary = false;
+ std::filesystem::path m_RecordingLogPath;
private:
+ bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) override;
+
cxxopts::Options m_Options{Name, Description};
- std::string m_HostName;
- std::string m_OrchestratorUrl;
- std::filesystem::path m_BeaconPath;
+ std::string m_SubCommand;
std::filesystem::path m_RecordingPath;
- std::filesystem::path m_RecordingLogPath;
int m_Offset = 0;
int m_Stride = 1;
int m_Limit = 0;
bool m_Quiet = false;
- std::string m_Mode{"http"};
- std::filesystem::path m_OutputPath;
- bool m_Binary = false;
-
- struct FunctionDefinition
- {
- std::string FunctionName;
- zen::Guid FunctionVersion;
- zen::Guid BuildSystemVersion;
- zen::IoHash WorkerId;
- };
-
- bool m_FunctionListEmittedOnce = false;
- void EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList);
- void EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList);
-
- std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap;
- std::vector<FunctionDefinition> m_FunctionList;
- bool m_VerboseLogging = false;
- bool m_QuietLogging = false;
- bool m_DumpActions = false;
-
- zen::ChunkResolver* m_ChunkResolver = nullptr;
- zen::compute::RecordingReaderBase* m_RecordingReader = nullptr;
-
- void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId);
-
- int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession);
- // Execution modes
+ std::unique_ptr<zen::compute::RecordingReader> m_Reader;
+ std::unique_ptr<zen::compute::UeRecordingReader> m_UeReader;
- int DumpWorkItems();
- int HttpExecute();
- int InProcessExecute();
- int LocalMessagingExecute();
- int BeaconExecute();
- int BuildActionsLog();
+ ExecHttpSubCmd m_HttpSubCmd{*this};
+ ExecInprocSubCmd m_InprocSubCmd{*this};
+ ExecBeaconSubCmd m_BeaconSubCmd{*this};
+ ExecDumpSubCmd m_DumpSubCmd{*this};
+ ExecBuildlogSubCmd m_BuildlogSubCmd{*this};
};
} // namespace zen
diff --git a/src/zen/cmds/service_cmd.cpp b/src/zen/cmds/service_cmd.cpp
index 3347f1afe..37baf5483 100644
--- a/src/zen/cmds/service_cmd.cpp
+++ b/src/zen/cmds/service_cmd.cpp
@@ -310,17 +310,34 @@ ServiceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
std::vector<char*> SubCommandArguments;
cxxopts::Options* SubOption = nullptr;
int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments);
- if (!ParseOptions(ParentCommandArgCount, argv))
+
+ if (SubOption == nullptr)
+ {
+ if (!ParseOptions(ParentCommandArgCount, argv))
+ {
+ return;
+ }
+ throw OptionParseException("'verb' option is required", m_Options.help());
+ }
+
+ // Parse subcommand permissively — forward unrecognised options to the parent parser.
+ std::vector<std::string> SubUnmatched;
+ if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
{
return;
}
- if (SubOption == nullptr)
+ // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched.
+ std::vector<char*> ParentArgs;
+ ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size());
+ ParentArgs.push_back(argv[0]);
+ std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs));
+ for (std::string& Arg : SubUnmatched)
{
- throw OptionParseException("'verb' option is required", m_Options.help());
+ ParentArgs.push_back(Arg.data());
}
- if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data()))
+ if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data()))
{
return;
}
diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp
index 220ef6a9e..9e49b464e 100644
--- a/src/zen/cmds/workspaces_cmd.cpp
+++ b/src/zen/cmds/workspaces_cmd.cpp
@@ -127,14 +127,36 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
std::vector<char*> SubCommandArguments;
cxxopts::Options* SubOption = nullptr;
int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments);
- if (!ParseOptions(ParentCommandArgCount, argv))
+
+ if (SubOption == nullptr)
+ {
+ if (!ParseOptions(ParentCommandArgCount, argv))
+ {
+ return;
+ }
+ throw OptionParseException("'verb' option is required", m_Options.help());
+ }
+
+ // Parse subcommand permissively — forward unrecognised options to the parent parser.
+ std::vector<std::string> SubUnmatched;
+ if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
{
return;
}
- if (SubOption == nullptr)
+ // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched.
+ std::vector<char*> ParentArgs;
+ ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size());
+ ParentArgs.push_back(argv[0]);
+ std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs));
+ for (std::string& Arg : SubUnmatched)
{
- throw OptionParseException("'verb' option is required", m_Options.help());
+ ParentArgs.push_back(Arg.data());
+ }
+
+ if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data()))
+ {
+ return;
}
m_HostName = ResolveTargetHostSpec(m_HostName);
@@ -150,11 +172,6 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
std::filesystem::path StatePath = m_SystemRootDir / "workspaces";
- if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data()))
- {
- return;
- }
-
if (SubOption == &m_CreateOptions)
{
if (m_Path.empty())
@@ -376,14 +393,36 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char**
std::vector<char*> SubCommandArguments;
cxxopts::Options* SubOption = nullptr;
int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments);
- if (!ParseOptions(ParentCommandArgCount, argv))
+
+ if (SubOption == nullptr)
+ {
+ if (!ParseOptions(ParentCommandArgCount, argv))
+ {
+ return;
+ }
+ throw OptionParseException("'verb' option is required", m_Options.help());
+ }
+
+ // Parse subcommand permissively — forward unrecognised options to the parent parser.
+ std::vector<std::string> SubUnmatched;
+ if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
{
return;
}
- if (SubOption == nullptr)
+ // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched.
+ std::vector<char*> ParentArgs;
+ ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size());
+ ParentArgs.push_back(argv[0]);
+ std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs));
+ for (std::string& Arg : SubUnmatched)
{
- throw OptionParseException("'verb' option is required", m_Options.help());
+ ParentArgs.push_back(Arg.data());
+ }
+
+ if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data()))
+ {
+ return;
}
m_HostName = ResolveTargetHostSpec(m_HostName);
@@ -403,11 +442,6 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char**
std::filesystem::path StatePath = m_SystemRootDir / "workspaces";
- if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data()))
- {
- return;
- }
-
if (SubOption == &m_CreateOptions)
{
if (m_WorkspaceRoot.empty())
diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp
index 86154c291..849d7a075 100644
--- a/src/zen/zen.cpp
+++ b/src/zen/zen.cpp
@@ -115,6 +115,21 @@ ZenCmdBase::CommandCategory() const
return DefaultCategory;
}
+std::string
+ZenCmdBase::HelpText()
+{
+    // Build the help text for this command's options, leaving out the
+    // "__hidden__" group so it never shows up in the output.
+    std::vector<std::string> VisibleGroups;
+    for (const std::string& Group : Options().groups())
+    {
+        if (Group != "__hidden__")
+        {
+            VisibleGroups.push_back(Group);
+        }
+    }
+
+    // Width is taken from TuiConsoleColumns(80) before the text is rendered.
+    Options().set_width(TuiConsoleColumns(80));
+    return Options().help(VisibleGroups);
+}
+
+void // Raises OptionParseException carrying Message plus this command's help text; every call throws
+ZenCmdBase::ThrowOptionError(std::string_view Message)
+{
+ throw OptionParseException(std::string(Message), HelpText()); // help text comes from HelpText(), so the hidden option group is excluded
+}
+
bool
ZenCmdBase::ParseOptions(int argc, char** argv)
{
@@ -166,6 +181,35 @@ ZenCmdBase::ParseOptions(cxxopts::Options& CmdOptions, int argc, char** argv)
return true;
}
+bool
+ZenCmdBase::ParseOptionsPermissive(cxxopts::Options& CmdOptions, int argc, char** argv, std::vector<std::string>& OutUnmatched)
+{
+    // Parses argv against CmdOptions while tolerating options it does not know.
+    // Arguments the parser did not match are copied into OutUnmatched so the caller can
+    // forward them to another parser.
+    //
+    // Returns false after printing help when "help" was given, true otherwise.
+    // Any parser failure is rethrown as zen::OptionParseException with the
+    // options' help text attached.
+    //
+    // Note: allow_unrecognised_options() is not reverted here, so CmdOptions
+    // stays permissive after this call.
+    CmdOptions.set_width(TuiConsoleColumns(80));
+    CmdOptions.allow_unrecognised_options();
+
+    cxxopts::ParseResult Parsed;
+
+    try
+    {
+        Parsed = CmdOptions.parse(argc, argv);
+    }
+    catch (const std::exception& ParseError)
+    {
+        throw zen::OptionParseException(ParseError.what(), CmdOptions.help());
+    }
+
+    CmdOptions.show_positional_help();
+
+    const bool HelpRequested = Parsed.count("help") != 0;
+    if (HelpRequested)
+    {
+        printf("%s\n", CmdOptions.help().c_str());
+        return false;
+    }
+
+    OutUnmatched = Parsed.unmatched();
+    return true;
+}
+
// Get the number of args including the sub command
// Build an array for sub command to parse
int
@@ -243,6 +287,23 @@ ZenCmdWithSubCommands::PrintHelp()
}
void
+ZenCmdWithSubCommands::PrintSubCommandHelp(cxxopts::Options& SubCmdOptions)
+{
+    // Prints combined help in three parts: the subcommand's own options, the
+    // parent command's options, and a pointer to the global help.
+
+    // Part 1: the subcommand's options.
+    SubCmdOptions.set_width(TuiConsoleColumns(80));
+    const std::string SubCmdHelp = SubCmdOptions.help();
+    printf("%s\n", SubCmdHelp.c_str());
+
+    // Part 2: the parent's options, minus the hidden positional group.
+    std::vector<std::string> VisibleParentGroups;
+    for (const std::string& Group : Options().groups())
+    {
+        if (Group != "__hidden__")
+        {
+            VisibleParentGroups.push_back(Group);
+        }
+    }
+
+    Options().set_width(TuiConsoleColumns(80));
+    const std::string ParentHelp = Options().help(VisibleParentGroups);
+    printf("%s options:\n%s\n", Options().program().c_str(), ParentHelp.c_str());
+
+    // Part 3: where to find the global options.
+    printf("For global options run: zen --help\n");
+}
+
+void
ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
std::vector<cxxopts::Options*> SubOptionPtrs;
@@ -270,42 +331,15 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char**
}
}
- // Parse parent options. When a subcommand was matched we strip its name from
- // the arg list so the parent parser does not see it as an unmatched positional.
- if (MatchedSubOption != nullptr)
- {
- std::vector<char*> ParentArgs;
- ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1));
- ParentArgs.push_back(argv[0]);
- std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs));
- if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data()))
- {
- return;
- }
- }
- else
+ if (MatchedSubOption == nullptr)
{
if (!ParseOptions(Options(), ParentArgc, argv))
{
return;
}
- }
-
- if (MatchedSubOption == nullptr)
- {
PrintHelp();
-
- ExtendableStringBuilder<128> VerbList;
- for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands)
- {
- if (!First)
- {
- VerbList.Append(", ");
- }
- VerbList.Append(SubCmd->SubOptions().program());
- First = false;
- }
- throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), {});
+ fflush(stdout);
+ throw OptionParseException("No subcommand specified", {});
}
ZenSubCmdBase* MatchedSubCmd = nullptr;
@@ -319,9 +353,40 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char**
}
ZEN_ASSERT(MatchedSubCmd != nullptr);
- // Parse subcommand args before OnParentOptionsParsed so --help on the subcommand
- // works without requiring parent options like --hosturl to be populated.
- if (!ParseOptions(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data()))
+ // Intercept --help/-h in the subcommand args so we can show combined help
+ // (subcommand options + parent options) without requiring parent options to
+ // be populated.
+ for (size_t i = 1; i < SubCommandArguments.size(); ++i)
+ {
+ std::string_view Arg(SubCommandArguments[i]);
+ if (Arg == "--help" || Arg == "-h")
+ {
+ PrintSubCommandHelp(*MatchedSubOption);
+ return;
+ }
+ }
+
+ // Parse subcommand args permissively — unrecognised options are collected
+ // and forwarded to the parent parser so that parent options (e.g. --path)
+ // can appear after the subcommand name on the command line.
+ std::vector<std::string> SubUnmatched;
+ if (!ParseOptionsPermissive(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched))
+ {
+ return;
+ }
+
+ // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched.
+ std::vector<std::string> UnmatchedStorage = std::move(SubUnmatched);
+ std::vector<char*> ParentArgs;
+ ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1) + UnmatchedStorage.size());
+ ParentArgs.push_back(argv[0]);
+ std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs));
+ for (std::string& Arg : UnmatchedStorage)
+ {
+ ParentArgs.push_back(Arg.data());
+ }
+
+ if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data()))
{
return;
}
diff --git a/src/zen/zen.h b/src/zen/zen.h
index 05ce32d0a..97cc9af6f 100644
--- a/src/zen/zen.h
+++ b/src/zen/zen.h
@@ -60,6 +60,7 @@ public:
bool ParseOptions(int argc, char** argv);
static bool ParseOptions(cxxopts::Options& Options, int argc, char** argv);
+ static bool ParseOptionsPermissive(cxxopts::Options& Options, int argc, char** argv, std::vector<std::string>& OutUnmatched);
static int GetSubCommand(cxxopts::Options& Options,
int argc,
char** argv,
@@ -74,7 +75,9 @@ public:
static constexpr const char* kHostUrlHelp = "Host URL or unix:///path/to/socket";
- static void LogExecutableVersionAndPid();
+ std::string HelpText();
+ [[noreturn]] void ThrowOptionError(std::string_view Message);
+ static void LogExecutableVersionAndPid();
};
class StorageCommand : public ZenCmdBase
@@ -114,6 +117,7 @@ protected:
void AddSubCommand(ZenSubCmdBase& SubCmd);
virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions);
void PrintHelp();
+ void PrintSubCommandHelp(cxxopts::Options& SubCmdOptions);
private:
std::vector<ZenSubCmdBase*> m_SubCommands;
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
index f5188123f..a1a39fc3c 100644
--- a/src/zencompute/CLAUDE.md
+++ b/src/zencompute/CLAUDE.md
@@ -46,9 +46,12 @@ Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns:
- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap`
- Queue map: `m_Queues` (QueueEntry objects)
- Action history ring: `m_ActionHistory` (bounded deque, default 1000)
+- WebSocket client (`m_OrchestratorWsClient`) subscribed to the orchestrator's `/orch/ws` push for instant worker discovery
**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset.
+**Convenience helpers:** `Ready()`, `Abandon()`, and `SetOrchestrator(Endpoint, BasePath)` are inline wrappers for common state transitions and orchestrator configuration.
+
### `RunnerAction` (runners/functionrunner.h)
Shared ref-counted struct representing one action through its lifecycle.
@@ -67,8 +70,11 @@ New → Pending → Submitting → Running → Completed
→ Failed
→ Abandoned
→ Cancelled
+ → Retracted
```
-`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline.
+`SetActionState()` rejects non-forward transitions (Retracted has the highest ordinal so runner-side transitions cannot override it). `ResetActionStateToPending()` uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling — it clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. The retraction path also resets the action to Pending but leaves `RetryCount` unchanged (see **Retracted state** below).
+
+**Retracted state:** An explicit, instigator-initiated request to pull an action back and reschedule it on a different runner (e.g. capacity opened up elsewhere). Unlike Failed/Abandoned auto-retry, rescheduling from Retracted does not increment `RetryCount` since nothing went wrong. Retraction is idempotent and can target Pending, Submitting, or Running actions.
### `LocalProcessRunner` (runners/localrunner.h)
Base for all local execution. Platform runners subclass this and override:
@@ -90,10 +96,29 @@ Base for all local execution. Platform runners subclass this and override:
- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000
### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h)
-Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity.
+Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` supports batch submission: it distributes a batch proportionally based on per-runner capacity, grouping the actions and forwarding them in chunks.
+
+### `RemoteHttpRunner` (runners/remotehttprunner.h)
+Submits actions to remote zenserver instances over HTTP. Key features:
+- **WebSocket completion notifications**: connects a WS client to `/compute/ws` on the remote. When a message arrives (action completed), the monitor thread wakes immediately instead of polling. Falls back to adaptive polling (200ms→50ms) when WS is unavailable.
+- **Batch submission**: groups actions by queue and submits in configurable chunks (`m_MaxBatchSize`, default 50), falling back to individual submission on failure.
+- **Queue cancellation**: `CancelRemoteQueue()` sends cancel requests to the remote.
+- **Graceful shutdown**: `Shutdown()` closes the WS client, cancels all remote queues, stops the monitor thread, and marks remaining actions as Failed.
### `HttpComputeService` (include/zencompute/httpcomputeservice.h)
-Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service.
+Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. Supports both single-action and batch (actions array) payloads via a shared `HandleSubmitAction` helper.
+
+## Orchestrator Discovery
+
+`ComputeServiceSession` discovers remote workers via the orchestrator endpoint (`SetOrchestratorEndpoint()`). Two complementary mechanisms are used:
+
+1. **Polling** (`UpdateCoordinatorState`): `GET /orch/agents` on the scheduler thread, throttled to every 5s (500ms when no workers are known yet). Discovers new workers and removes stale/unreachable ones.
+
+2. **WebSocket push** (`OrchestratorWsHandler`): connects to `/orch/ws` on the orchestrator at setup time. When the orchestrator broadcasts a state change, the handler sets `m_OrchestratorQueryForced` and signals the scheduler event, bypassing the polling throttle. Falls back silently to polling if the WS connection fails.
+
+`NotifyOrchestratorChanged()` is the public API to trigger an immediate re-query — useful in tests and for external notification sources.
+
+Use `HttpToWsUrl(Endpoint, Path)` from `zenhttp/httpwsclient.h` to convert HTTP(S) endpoints to WebSocket URLs. This helper is shared across all WS client setup sites in the codebase.
## Action Lifecycle (End to End)
@@ -118,6 +143,8 @@ Actions that fail or are abandoned can be automatically retried or manually resc
**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit.
+**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent.
+
**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure.
## Queue System
@@ -161,8 +188,9 @@ All routes registered in `HttpComputeService` constructor. Prefix is configured
| GET | `jobs/running` | In-flight actions with CPU metrics |
| GET | `jobs/completed` | Actions with results available |
| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire |
+| POST | `jobs/{lsn}/retract` | Retract a pending, submitting, or running action for rescheduling (idempotent) |
| POST | `jobs/{worker}` | Submit action for specific worker |
-| POST | `jobs` | Submit action (worker resolved from descriptor) |
+| POST | `jobs` | Submit action (or batch via `actions` array) |
| GET | `workers` | List worker IDs |
| GET | `workers/all` | All workers with full descriptors |
| GET/POST | `workers/{worker}` | Get/register worker |
@@ -179,8 +207,9 @@ Queue ref is capture(1) in all `queues/{queueref}/...` routes.
| GET | `queues/{queueref}/completed` | Queue's completed results |
| GET | `queues/{queueref}/history` | Queue's action history |
| GET | `queues/{queueref}/running` | Queue's running actions |
-| POST | `queues/{queueref}/jobs` | Submit to queue |
+| POST | `queues/{queueref}/jobs` | Submit to queue (or batch via `actions` array) |
| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule |
+| POST | `queues/{queueref}/jobs/{lsn}/retract` | Retract a pending, submitting, or running action for rescheduling (idempotent) |
| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) |
Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes.
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
index 838d741b6..92901de64 100644
--- a/src/zencompute/computeservice.cpp
+++ b/src/zencompute/computeservice.cpp
@@ -33,6 +33,7 @@
# include <zenutil/workerpools.h>
# include <zentelemetry/stats.h>
# include <zenhttp/httpclient.h>
+# include <zenhttp/httpwsclient.h>
# include <set>
# include <deque>
@@ -42,6 +43,7 @@
# include <unordered_set>
ZEN_THIRD_PARTY_INCLUDES_START
+# include <EASTL/fixed_vector.h>
# include <EASTL/hash_set.h>
ZEN_THIRD_PARTY_INCLUDES_END
@@ -95,6 +97,14 @@ using SessionState = ComputeServiceSession::SessionState;
static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count));
+static ComputeServiceSession::EnqueueResult
+MakeErrorResult(std::string_view Error)
+{
+    // Builds the result returned for a rejected enqueue: a compact-binary
+    // object with a single "error" field holding the message, paired with 0
+    // in the first slot (presumably "no LSN assigned" — confirm against
+    // EnqueueResult's declaration).
+    CbObjectWriter ErrorObject;
+    ErrorObject << "error"sv << Error;
+
+    ComputeServiceSession::EnqueueResult Result{0, ErrorObject.Save()};
+    return Result;
+}
+
//////////////////////////////////////////////////////////////////////////
struct ComputeServiceSession::Impl
@@ -130,14 +140,40 @@ struct ComputeServiceSession::Impl
void SetOrchestratorEndpoint(std::string_view Endpoint);
void SetOrchestratorBasePath(std::filesystem::path BasePath);
+ void NotifyOrchestratorChanged();
std::string m_OrchestratorEndpoint;
std::filesystem::path m_OrchestratorBasePath;
Stopwatch m_OrchestratorQueryTimer;
+ std::atomic<bool> m_OrchestratorQueryForced{false};
std::unordered_set<std::string> m_KnownWorkerUris;
void UpdateCoordinatorState();
+ // WebSocket subscription to orchestrator push notifications
+ struct OrchestratorWsHandler : public IWsClientHandler // Forwards WebSocket events from the orchestrator connection to the owning Impl
+ {
+ Impl& Owner; // non-owning back-reference; every callback below goes through it
+
+ explicit OrchestratorWsHandler(Impl& InOwner) : Owner(InOwner) {}
+
+ void OnWsOpen() override // on connect: log, then ask the owner for an immediate orchestrator re-query
+ {
+ ZEN_LOG_INFO(Owner.m_Log, "orchestrator WebSocket connected");
+ Owner.NotifyOrchestratorChanged();
+ }
+
+ void OnWsMessage(const WebSocketMessage&) override { Owner.NotifyOrchestratorChanged(); } // payload is ignored; any message triggers a re-query
+
+ void OnWsClose(uint16_t Code, std::string_view Reason) override // only logs the close; nothing else happens in this handler
+ {
+ ZEN_LOG_WARN(Owner.m_Log, "orchestrator WebSocket closed (code {}: {})", Code, Reason);
+ }
+ };
+
+ std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler;
+ std::unique_ptr<HttpWsClient> m_OrchestratorWsClient;
+
// Worker registration and discovery
struct FunctionDefinition
@@ -157,6 +193,8 @@ struct ComputeServiceSession::Impl
std::atomic<int32_t> m_ActionsCounter = 0; // sequence number
metrics::Meter m_ArrivalRate;
+ std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr};
+
RwLock m_PendingLock;
std::map<int, Ref<RunnerAction>> m_PendingActions;
@@ -267,6 +305,8 @@ struct ComputeServiceSession::Impl
void DrainQueue(int QueueId);
ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority);
ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority);
+ ComputeServiceSession::EnqueueResult ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue);
+ void ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn);
void GetQueueCompleted(int QueueId, CbWriter& Cbo);
void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState);
void ExpireCompletedQueues();
@@ -292,11 +332,13 @@ struct ComputeServiceSession::Impl
void HandleActionUpdates();
void PostUpdate(RunnerAction* Action);
+ void RemoveActionFromActiveMaps(int ActionLsn);
static constexpr int kDefaultMaxRetries = 3;
int GetMaxRetriesForQueue(int QueueId);
ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn);
+ ComputeServiceSession::RescheduleResult RetractAction(int ActionLsn);
ActionCounts GetActionCounts()
{
@@ -449,6 +491,28 @@ void
ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint)
{
    // Records the orchestrator endpoint and (re)establishes the WebSocket
    // subscription used for push-based worker discovery. Failure to set up
    // the subscription is not fatal: it is logged and polling remains in use.
    m_OrchestratorEndpoint = Endpoint;
+
+    // Drop any subscription created by an earlier call before building a new
+    // one. The client is constructed with a reference to the handler, so the
+    // client has to be closed and destroyed first (the same order Shutdown()
+    // uses). Replacing the handler while the previous client is still alive
+    // would leave that client holding a reference to a destroyed object.
+    if (m_OrchestratorWsClient)
+    {
+        m_OrchestratorWsClient->Close();
+        m_OrchestratorWsClient.reset();
+    }
+    m_OrchestratorWsHandler.reset();
+
+    // Subscribe to orchestrator WebSocket push so we discover worker changes
+    // immediately instead of waiting for the next polling cycle.
+    try
+    {
+        std::string WsUrl = HttpToWsUrl(Endpoint, "/orch/ws");
+
+        m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this);
+
+        HttpWsClientSettings WsSettings;
+        WsSettings.LogCategory    = "orch_disc_ws";
+        WsSettings.ConnectTimeout = std::chrono::milliseconds{3000};
+
+        m_OrchestratorWsClient = std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, WsSettings);
+        m_OrchestratorWsClient->Connect();
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_WARN("failed to connect orchestrator WebSocket, falling back to polling: {}", Ex.what());
+        // Client first, then handler: the client refers to the handler.
+        m_OrchestratorWsClient.reset();
+        m_OrchestratorWsHandler.reset();
+    }
}
void
@@ -458,6 +522,13 @@ ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BaseP
}
void
+ComputeServiceSession::Impl::NotifyOrchestratorChanged() // Requests an orchestrator re-query that bypasses the polling throttle
+{
+ m_OrchestratorQueryForced.store(true, std::memory_order_relaxed); // consumed via exchange(false) in UpdateCoordinatorState(), which then skips the poll-interval check
+ m_SchedulingThreadEvent.Set(); // flag is set first, then the scheduling thread's event is signalled
+}
+
+void
ComputeServiceSession::Impl::UpdateCoordinatorState()
{
ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState");
@@ -467,10 +538,14 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
}
// Poll faster when we have no discovered workers yet so remote runners come online quickly
- const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000;
- if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs)
+ const bool Forced = m_OrchestratorQueryForced.exchange(false, std::memory_order_relaxed);
+ if (!Forced)
{
- return;
+ const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000;
+ if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs)
+ {
+ return;
+ }
}
m_OrchestratorQueryTimer.Reset();
@@ -520,7 +595,24 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
continue;
}
- ZEN_INFO("discovered new worker at {}", UriStr);
+ std::string_view Hostname = Worker["hostname"sv].AsString();
+ std::string_view Platform = Worker["platform"sv].AsString();
+ int Cpus = Worker["cpus"sv].AsInt32();
+ uint64_t MemTotal = Worker["memory_total"sv].AsUInt64();
+
+ if (!Hostname.empty())
+ {
+ ZEN_INFO("discovered new worker at {} ({}, {}, {} cpus, {:.1f} GB)",
+ UriStr,
+ Hostname,
+ Platform,
+ Cpus,
+ static_cast<double>(MemTotal) / (1024.0 * 1024.0 * 1024.0));
+ }
+ else
+ {
+ ZEN_INFO("discovered new worker at {}", UriStr);
+ }
m_KnownWorkerUris.insert(UriStr);
@@ -598,6 +690,15 @@ ComputeServiceSession::Impl::Shutdown()
{
RequestStateTransition(SessionState::Sunset);
+ // Close orchestrator WebSocket before stopping the scheduler thread
+ // to prevent callbacks into a shutting-down scheduler.
+ if (m_OrchestratorWsClient)
+ {
+ m_OrchestratorWsClient->Close();
+ m_OrchestratorWsClient.reset();
+ }
+ m_OrchestratorWsHandler.reset();
+
m_SchedulingThreadEnabled = false;
m_SchedulingThreadEvent.Set();
if (m_SchedulingThread.joinable())
@@ -720,8 +821,14 @@ ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker)
// different descriptor. Thus we only need to call this the first time, when the
// worker is added
- m_LocalRunnerGroup.RegisterWorker(Worker);
- m_RemoteRunnerGroup.RegisterWorker(Worker);
+ if (!m_LocalRunnerGroup.RegisterWorker(Worker))
+ {
+ ZEN_WARN("failed to register worker {} on one or more local runners", WorkerId);
+ }
+ if (!m_RemoteRunnerGroup.RegisterWorker(Worker))
+ {
+ ZEN_WARN("failed to register worker {} on one or more remote runners", WorkerId);
+ }
if (m_Recorder)
{
@@ -767,7 +874,10 @@ ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner)
for (const CbPackage& Worker : Workers)
{
- Runner.RegisterWorker(Worker);
+ if (!Runner.RegisterWorker(Worker))
+ {
+ ZEN_WARN("failed to sync worker {} to runner", Worker.GetObjectHash());
+ }
}
}
@@ -868,9 +978,7 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke
if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready)
{
- CbObjectWriter Writer;
- Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()));
- return {0, Writer.Save()};
+ return MakeErrorResult(fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())));
}
const int ActionLsn = ++m_ActionsCounter;
@@ -1258,42 +1366,51 @@ ComputeServiceSession::Impl::DrainQueue(int QueueId)
}
ComputeServiceSession::EnqueueResult
-ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
+ComputeServiceSession::Impl::ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue) // Looks up QueueId into OutQueue; yields an error result when the queue is missing, cancelled or draining.
{
- Ref<QueueEntry> Queue = FindQueue(QueueId);
+ OutQueue = FindQueue(QueueId); // the lookup result is handed back to the caller through OutQueue
- if (!Queue)
+ if (!OutQueue)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue not found"sv;
- return {0, Writer.Save()};
+ return MakeErrorResult("queue not found"sv);
}
- QueueState QState = Queue->State.load();
+ QueueState QState = OutQueue->State.load(); // single snapshot of the state, used for both checks below
if (QState == QueueState::Cancelled)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is cancelled"sv;
- return {0, Writer.Save()};
+ return MakeErrorResult("queue is cancelled"sv);
}
if (QState == QueueState::Draining)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is draining"sv;
- return {0, Writer.Save()};
+ return MakeErrorResult("queue is draining"sv);
+ }
+
+ return {}; // default-constructed result means "no error"; the enqueue callers test Error.ResponseMessage
+}
+
+void // Records Lsn as an active action of Queue: adds it to ActiveLsns, increments ActiveCount and stores 0 into IdleSince.
+ComputeServiceSession::Impl::ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn)
+{
+ Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Lsn); }); // set insertion happens under the queue's exclusive lock
+ Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); // counter is updated outside the lock, after the insert
+ Queue->IdleSince.store(0, std::memory_order_relaxed); // NOTE(review): 0 presumably means "not idle" — confirm against the readers of IdleSince
+}
+
+ComputeServiceSession::EnqueueResult
+ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority)
+{
+ Ref<QueueEntry> Queue;
+ if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage)
+ {
+ return Error;
}
EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority);
if (Result.Lsn != 0)
{
- Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
- Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
- Queue->IdleSince.store(0, std::memory_order_relaxed);
+ ActivateActionInQueue(Queue, Result.Lsn);
}
return Result;
@@ -1302,40 +1419,17 @@ ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionOb
ComputeServiceSession::EnqueueResult
ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority)
{
- Ref<QueueEntry> Queue = FindQueue(QueueId);
-
- if (!Queue)
- {
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue not found"sv;
- return {0, Writer.Save()};
- }
-
- QueueState QState = Queue->State.load();
- if (QState == QueueState::Cancelled)
+ Ref<QueueEntry> Queue;
+ if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage)
{
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is cancelled"sv;
- return {0, Writer.Save()};
- }
-
- if (QState == QueueState::Draining)
- {
- CbObjectWriter Writer;
- Writer << "error"sv
- << "queue is draining"sv;
- return {0, Writer.Save()};
+ return Error;
}
EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority);
if (Result.Lsn != 0)
{
- Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); });
- Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed);
- Queue->IdleSince.store(0, std::memory_order_relaxed);
+ ActivateActionInQueue(Queue, Result.Lsn);
}
return Result;
@@ -1770,6 +1864,68 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn)
return {.Success = true, .RetryCount = NewRetryCount};
}
+ComputeServiceSession::RescheduleResult // Requests retraction of the action with the given LSN; fails when it is in neither active map or when the action refuses.
+ComputeServiceSession::Impl::RetractAction(int ActionLsn)
+{
+ Ref<RunnerAction> Action;
+ bool WasRunning = false; // set only when the action was found in m_RunningMap
+
+ // Look for the action in pending or running maps
+ m_RunningLock.WithSharedLock([&] {
+ if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end())
+ {
+ Action = It->second;
+ WasRunning = true;
+ }
+ });
+
+ if (!Action) // not in the running map: fall back to the pending map
+ {
+ m_PendingLock.WithSharedLock([&] {
+ if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end())
+ {
+ Action = It->second;
+ }
+ });
+ }
+
+ if (!Action)
+ {
+ return {.Success = false, .Error = "Action not found in pending or running maps"};
+ }
+
+ if (!Action->RetractAction()) // the action itself decides whether its current state allows retraction
+ {
+ return {.Success = false, .Error = "Action cannot be retracted from its current state"};
+ }
+
+ // If the action was running, send a cancellation signal to the runner
+ if (WasRunning)
+ {
+ m_LocalRunnerGroup.CancelAction(ActionLsn); // NOTE(review): only the local runner group is signalled — confirm whether actions on remote runners need an equivalent cancel
+ }
+
+ ZEN_INFO("action {} ({}) retract requested", Action->ActionId, ActionLsn);
+ return {.Success = true, .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; // RetryCount is only read here; this function never modifies it
+}
+
+void // Erases ActionLsn from m_RunningMap when present there, otherwise erases it from m_PendingActions.
+ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn)
+{
+ m_RunningLock.WithExclusiveLock([&] { // lock order: m_RunningLock first, then m_PendingLock
+ m_PendingLock.WithExclusiveLock([&] {
+ if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
+ {
+ m_PendingActions.erase(ActionLsn); // erase by key; presumably a no-op when the LSN is in neither map — confirm container type
+ }
+ else
+ {
+ m_RunningMap.erase(FindIt);
+ }
+ });
+ });
+}
+
void
ComputeServiceSession::Impl::HandleActionUpdates()
{
@@ -1781,6 +1937,10 @@ ComputeServiceSession::Impl::HandleActionUpdates()
std::unordered_set<int> SeenLsn;
+ // Collect terminal action notifications for the completion observer.
+ // Inline capacity of 64 avoids heap allocation in the common case.
+ eastl::fixed_vector<IComputeCompletionObserver::CompletedActionNotification, 64> TerminalBatch;
+
// Process each action's latest state, deduplicating by LSN.
//
// This is safe because state transitions are monotonically increasing by enum
@@ -1798,7 +1958,23 @@ ComputeServiceSession::Impl::HandleActionUpdates()
{
// Newly enqueued — add to pending map for scheduling
case RunnerAction::State::Pending:
- m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ // Guard against a race where the session is abandoned between
+ // EnqueueAction (which calls PostUpdate) and this scheduler
+ // tick. AbandonAllActions() only scans m_PendingActions, so it
+ // misses actions still in m_UpdatedActions at the time the
+ // session transitions. Detect that here and immediately abandon
+ // rather than inserting into the pending map, where they would
+ // otherwise be stuck indefinitely.
+ if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Abandoned)
+ {
+ Action->SetActionState(RunnerAction::State::Abandoned);
+ // SetActionState calls PostUpdate; the Abandoned action
+ // will be processed as a terminal on the next scheduler pass.
+ }
+ else
+ {
+ m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); });
+ }
break;
// Async submission in progress — remains in pending map
@@ -1816,6 +1992,15 @@ ComputeServiceSession::Impl::HandleActionUpdates()
ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn);
break;
+ // Retracted — pull back for rescheduling without counting against retry limit
+ case RunnerAction::State::Retracted:
+ {
+ RemoveActionFromActiveMaps(ActionLsn);
+ Action->ResetActionStateToPending();
+ ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn);
+ break;
+ }
+
// Terminal states — move to results, record history, notify queue
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
@@ -1834,19 +2019,7 @@ ComputeServiceSession::Impl::HandleActionUpdates()
if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries)
{
- // Remove from whichever active map the action is in before resetting
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ RemoveActionFromActiveMaps(ActionLsn);
// Reset triggers PostUpdate() which re-enters the action as Pending
Action->ResetActionStateToPending();
@@ -1861,19 +2034,14 @@ ComputeServiceSession::Impl::HandleActionUpdates()
}
}
- // Remove from whichever active map the action is in
- m_RunningLock.WithExclusiveLock([&] {
- m_PendingLock.WithExclusiveLock([&] {
- if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end())
- {
- m_PendingActions.erase(ActionLsn);
- }
- else
- {
- m_RunningMap.erase(FindIt);
- }
- });
- });
+ RemoveActionFromActiveMaps(ActionLsn);
+
+ // Update queue counters BEFORE publishing the result into
+ // m_ResultsMap. GetActionResult erases from m_ResultsMap
+ // under m_ResultsLock, so if we updated counters after
+ // releasing that lock, a caller could observe ActiveCount
+ // still at 1 immediately after GetActionResult returned OK.
+ NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
m_ResultsLock.WithExclusiveLock([&] {
m_ResultsMap[ActionLsn] = Action;
@@ -1902,16 +2070,46 @@ ComputeServiceSession::Impl::HandleActionUpdates()
});
m_RetiredCount.fetch_add(1);
m_ResultRate.Mark(1);
+ {
+ using ObserverState = IComputeCompletionObserver::ActionState;
+ ObserverState NotifyState{};
+ switch (TerminalState)
+ {
+ case RunnerAction::State::Completed:
+ NotifyState = ObserverState::Completed;
+ break;
+ case RunnerAction::State::Failed:
+ NotifyState = ObserverState::Failed;
+ break;
+ case RunnerAction::State::Abandoned:
+ NotifyState = ObserverState::Abandoned;
+ break;
+ case RunnerAction::State::Cancelled:
+ NotifyState = ObserverState::Cancelled;
+ break;
+ default:
+ break;
+ }
+ TerminalBatch.push_back({ActionLsn, NotifyState});
+ }
ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}",
Action->ActionId,
ActionLsn,
TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE");
- NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState);
break;
}
}
}
}
+
+ // Notify the completion observer, if any, about all terminal actions in this batch.
+ if (!TerminalBatch.empty())
+ {
+ if (IComputeCompletionObserver* Observer = m_CompletionObserver.load(std::memory_order_acquire))
+ {
+ Observer->OnActionsCompleted({TerminalBatch.data(), TerminalBatch.size()});
+ }
+ }
}
size_t
@@ -2014,6 +2212,12 @@ ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath)
}
void
+ComputeServiceSession::NotifyOrchestratorChanged() // Public entry point; forwards unchanged to the Impl method of the same name.
+{
+ m_Impl->NotifyOrchestratorChanged();
+}
+
+void
ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
{
m_Impl->StartRecording(InResolver, RecordingPath);
@@ -2194,6 +2398,12 @@ ComputeServiceSession::RescheduleAction(int ActionLsn)
return m_Impl->RescheduleAction(ActionLsn);
}
+ComputeServiceSession::RescheduleResult // Public entry point; returns whatever Impl::RetractAction reports for ActionLsn.
+ComputeServiceSession::RetractAction(int ActionLsn)
+{
+ return m_Impl->RetractAction(ActionLsn);
+}
+
std::vector<ComputeServiceSession::RunningActionInfo>
ComputeServiceSession::GetRunningActions()
{
@@ -2219,6 +2429,12 @@ ComputeServiceSession::GetCompleted(CbWriter& Cbo)
}
void
+ComputeServiceSession::SetCompletionObserver(IComputeCompletionObserver* Observer) // Installs the observer pointer; callers in this file pass nullptr to clear it.
+{
+ m_Impl->m_CompletionObserver.store(Observer, std::memory_order_release); // release store pairs with the acquire load in HandleActionUpdates
+}
+
+void
ComputeServiceSession::PostUpdate(RunnerAction* Action)
{
m_Impl->PostUpdate(Action);
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
index e82a40781..bdfd9d197 100644
--- a/src/zencompute/httpcomputeservice.cpp
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -16,6 +16,7 @@
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
# include <zencore/logging.h>
+# include <zencore/string.h>
# include <zencore/system.h>
# include <zencore/thread.h>
# include <zencore/trace.h>
@@ -23,8 +24,10 @@
# include <zenstore/cidstore.h>
# include <zentelemetry/stats.h>
+# include <algorithm>
# include <span>
# include <unordered_map>
+# include <vector>
using namespace std::literals;
@@ -50,6 +53,11 @@ struct HttpComputeService::Impl
ComputeServiceSession m_ComputeService;
SystemMetricsTracker m_MetricsTracker;
+ // WebSocket connections (completion push)
+
+ RwLock m_WsConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_WsConnections;
+
// Metrics
metrics::OperationTiming m_HttpRequests;
@@ -91,6 +99,12 @@ struct HttpComputeService::Impl
void HandleWorkersAllGet(HttpServerRequest& HttpReq);
void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status);
void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId);
+ void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker);
+
+ // WebSocket / observer
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection);
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code);
+ void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions);
void RegisterRoutes();
@@ -110,6 +124,7 @@ struct HttpComputeService::Impl
m_ComputeService.WaitUntilReady();
m_StatsService.RegisterHandler("compute", *m_Self);
RegisterRoutes();
+ m_ComputeService.SetCompletionObserver(m_Self);
}
};
@@ -149,7 +164,7 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned);
+ bool Success = m_ComputeService.Abandon();
if (Success)
{
@@ -325,6 +340,29 @@ HttpComputeService::Impl::RegisterRoutes()
HttpVerb::kGet | HttpVerb::kPost);
m_Router.RegisterRoute(
+ "jobs/{lsn}/retract", // POST endpoint: asks the session to retract the action identified by {lsn}
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); // an unparsable LSN becomes 0 and is passed on unchanged
+
+ auto Result = m_ComputeService.RetractAction(ActionLsn);
+
+ CbObjectWriter Cbo;
+ if (Result.Success)
+ {
+ Cbo << "success"sv << true;
+ Cbo << "lsn"sv << ActionLsn;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ else
+ {
+ Cbo << "error"sv << Result.Error;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); // every failure, including an unknown LSN, is reported as Conflict
+ }
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
"jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the
// one which uses the scheduled action lsn for lookups
[this](HttpRouterRequest& Req) {
@@ -373,127 +411,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- switch (HttpReq.RequestVerb())
- {
- case HttpVerb::kGet:
- // TODO: return status of all pending or executing jobs
- break;
-
- case HttpVerb::kPost:
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- // This operation takes the proposed job spec and identifies which
- // chunks are not present on this server. This list is then returned in
- // the "need" list in the response
-
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- ActionObj.IterateAttachments([&](CbFieldView Field) {
- const IoHash FileHash = Field.AsHash();
-
- if (!m_CidStore.ContainsChunk(FileHash))
- {
- NeedList.push_back(FileHash);
- }
- });
-
- if (NeedList.empty())
- {
- // We already have everything, enqueue the action for execution
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn);
-
- HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
-
- return;
- }
-
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
- CbObject Response = Cbo.Save();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
- }
- break;
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- std::span<const CbAttachment> Attachments = Action.GetAttachments();
-
- int AttachmentCount = 0;
- int NewAttachmentCount = 0;
- uint64_t TotalAttachmentBytes = 0;
- uint64_t TotalNewBytes = 0;
-
- for (const CbAttachment& Attachment : Attachments)
- {
- ZEN_ASSERT(Attachment.IsCompressedBinary());
-
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
-
- const uint64_t CompressedSize = DataView.GetCompressedSize();
-
- TotalAttachmentBytes += CompressedSize;
- ++AttachmentCount;
-
- const CidStore::InsertResult InsertResult =
- m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
-
- if (InsertResult.New)
- {
- TotalNewBytes += CompressedSize;
- ++NewAttachmentCount;
- }
- }
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
- ActionObj.GetHash(),
- Result.Lsn,
- zen::NiceBytes(TotalAttachmentBytes),
- AttachmentCount,
- zen::NiceBytes(TotalNewBytes),
- NewAttachmentCount);
-
- HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
-
- return;
- }
- break;
-
- default:
- break;
- }
- break;
-
- default:
- break;
- }
+ HandleSubmitAction(HttpReq, 0, RequestPriority, &Worker);
},
HttpVerb::kPost);
@@ -511,118 +429,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- // Resolve worker
-
- //
-
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- // This operation takes the proposed job spec and identifies which
- // chunks are not present on this server. This list is then returned in
- // the "need" list in the response
-
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- ActionObj.IterateAttachments([&](CbFieldView Field) {
- const IoHash FileHash = Field.AsHash();
-
- if (!m_CidStore.ContainsChunk(FileHash))
- {
- NeedList.push_back(FileHash);
- }
- });
-
- if (NeedList.empty())
- {
- // We already have everything, enqueue the action for execution
-
- if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority))
- {
- ZEN_DEBUG("action accepted (lsn {})", Result.Lsn);
-
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- // Could not resolve?
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
- CbObject Response = Cbo.Save();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
- }
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- std::span<const CbAttachment> Attachments = Action.GetAttachments();
-
- int AttachmentCount = 0;
- int NewAttachmentCount = 0;
- uint64_t TotalAttachmentBytes = 0;
- uint64_t TotalNewBytes = 0;
-
- for (const CbAttachment& Attachment : Attachments)
- {
- ZEN_ASSERT(Attachment.IsCompressedBinary());
-
- const IoHash DataHash = Attachment.GetHash();
- CompressedBuffer DataView = Attachment.AsCompressedBinary();
-
- ZEN_UNUSED(DataHash);
-
- const uint64_t CompressedSize = DataView.GetCompressedSize();
-
- TotalAttachmentBytes += CompressedSize;
- ++AttachmentCount;
-
- const CidStore::InsertResult InsertResult =
- m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
-
- if (InsertResult.New)
- {
- TotalNewBytes += CompressedSize;
- ++NewAttachmentCount;
- }
- }
-
- if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority))
- {
- ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)",
- Result.Lsn,
- zen::NiceBytes(TotalAttachmentBytes),
- AttachmentCount,
- zen::NiceBytes(TotalNewBytes),
- NewAttachmentCount);
-
- HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- // Could not resolve?
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
- return;
- }
+ HandleSubmitAction(HttpReq, 0, RequestPriority, nullptr);
},
HttpVerb::kPost);
@@ -1090,72 +897,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- if (!CheckAttachments(ActionObj, NeedList))
- {
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
- }
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn);
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- IngestStats Stats = IngestPackageAttachments(Action);
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
- QueueId,
- ActionObj.GetHash(),
- Result.Lsn,
- zen::NiceBytes(Stats.Bytes),
- Stats.Count,
- zen::NiceBytes(Stats.NewBytes),
- Stats.NewCount);
-
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- default:
- break;
- }
+ HandleSubmitAction(HttpReq, QueueId, RequestPriority, &Worker);
},
HttpVerb::kPost);
@@ -1178,71 +920,7 @@ HttpComputeService::Impl::RegisterRoutes()
RequestPriority = ParseInt<int>(PriorityParam).value_or(-1);
}
- switch (HttpReq.RequestContentType())
- {
- case HttpContentType::kCbObject:
- {
- IoBuffer Payload = HttpReq.ReadPayload();
- CbObject ActionObj = LoadCompactBinaryObject(Payload);
-
- std::vector<IoHash> NeedList;
-
- if (!CheckAttachments(ActionObj, NeedList))
- {
- CbObjectWriter Cbo;
- Cbo.BeginArray("need");
-
- for (const IoHash& Hash : NeedList)
- {
- Cbo << Hash;
- }
-
- Cbo.EndArray();
-
- return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
- }
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn);
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- case HttpContentType::kCbPackage:
- {
- CbPackage Action = HttpReq.ReadPayloadPackage();
- CbObject ActionObj = Action.GetObject();
-
- IngestStats Stats = IngestPackageAttachments(Action);
-
- if (ComputeServiceSession::EnqueueResult Result =
- m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority))
- {
- ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)",
- QueueId,
- Result.Lsn,
- zen::NiceBytes(Stats.Bytes),
- Stats.Count,
- zen::NiceBytes(Stats.NewBytes),
- Stats.NewCount);
-
- return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
- }
- else
- {
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage);
- }
- }
-
- default:
- break;
- }
+ HandleSubmitAction(HttpReq, QueueId, RequestPriority, nullptr);
},
HttpVerb::kPost);
@@ -1306,6 +984,45 @@ HttpComputeService::Impl::RegisterRoutes()
}
},
HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "queues/{queueref}/jobs/{lsn}/retract", // queue-scoped variant of the retract endpoint
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1));
+ const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); // an unparsable LSN becomes 0 and is passed on unchanged
+
+ if (QueueId == 0) // NOTE(review): nothing is written here — presumably ResolveQueueRef already sent the error response; confirm
+ {
+ return;
+ }
+
+ ZEN_UNUSED(QueueId); // NOTE(review): QueueId only gates the request; the LSN is not checked for membership in this queue
+
+ auto Result = m_ComputeService.RetractAction(ActionLsn);
+
+ CbObjectWriter Cbo;
+ if (Result.Success)
+ {
+ Cbo << "success"sv << true;
+ Cbo << "lsn"sv << ActionLsn;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ else
+ {
+ Cbo << "error"sv << Result.Error;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); // every failure, including an unknown LSN, is reported as Conflict
+ }
+ },
+ HttpVerb::kPost);
+
+ // WebSocket upgrade endpoint — the handler logic lives in
+ // HttpComputeService::OnWebSocket* methods; this route merely
+ // satisfies the router so the upgrade request isn't rejected.
+ m_Router.RegisterRoute(
+ "ws",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, // the handler itself only writes a bare OK response
+ HttpVerb::kGet);
}
//////////////////////////////////////////////////////////////////////////
@@ -1320,12 +1037,17 @@ HttpComputeService::HttpComputeService(CidStore& InCidStore,
HttpComputeService::~HttpComputeService()
{
+ m_Impl->m_ComputeService.SetCompletionObserver(nullptr); // clear the observer pointer that Impl setup pointed at m_Self
m_Impl->m_StatsService.UnregisterHandler("compute", *this);
}
void
HttpComputeService::Shutdown()
{
+ // Null out observer before shutting down the compute session to prevent
+ // callbacks into a partially-torn-down service.
+ m_Impl->m_ComputeService.SetCompletionObserver(nullptr);
+ m_Impl->m_WsConnectionsLock.WithExclusiveLock([&] { m_Impl->m_WsConnections.clear(); }); // drop every retained WebSocket connection reference
m_Impl->m_ComputeService.Shutdown(); // session shutdown runs last, after the observer and connections are gone
}
@@ -1492,6 +1214,184 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto
}
void
+HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker) // Shared submit handler: one action or an "actions" batch; Worker may be null.
+{
+ // QueueId > 0: queue-scoped enqueue; QueueId == 0: implicit queue (global routes)
+ auto Enqueue = [&](CbObject ActionObj) -> ComputeServiceSession::EnqueueResult { // picks one of four session entry points from QueueId and Worker
+ if (QueueId > 0)
+ {
+ if (Worker)
+ {
+ return m_ComputeService.EnqueueResolvedActionToQueue(QueueId, *Worker, ActionObj, Priority);
+ }
+ return m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, Priority);
+ }
+ else
+ {
+ if (Worker)
+ {
+ return m_ComputeService.EnqueueResolvedAction(*Worker, ActionObj, Priority);
+ }
+ return m_ComputeService.EnqueueAction(ActionObj, Priority);
+ }
+ };
+
+ // Read payload upfront and handle attachments based on content type
+ CbObject Body;
+ IngestStats Stats = {}; // only filled in for kCbPackage payloads; stays zeroed otherwise
+
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ IoBuffer Payload = HttpReq.ReadPayload();
+ Body = LoadCompactBinaryObject(Payload);
+ break;
+ }
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Package = HttpReq.ReadPayloadPackage();
+ Body = Package.GetObject();
+ Stats = IngestPackageAttachments(Package);
+ break;
+ }
+
+ default:
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest); // any other content type is rejected before anything is enqueued
+ }
+
+ // Check for "actions" array to determine batch vs single-action path
+ CbArray Actions = Body.Find("actions"sv).AsArray(); // NOTE(review): presumably an empty array when the field is absent or not an array — confirm
+
+ if (Actions.Num() > 0) // NOTE(review): an explicitly empty "actions" array falls through to the single-action path
+ {
+ // --- Batch path ---
+
+ // For CbObject payloads, check all attachments upfront before enqueuing anything
+ if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ {
+ std::vector<IoHash> NeedList;
+
+ for (CbField ActionField : Actions)
+ {
+ CbObject ActionObj = ActionField.AsObject();
+ CheckAttachments(ActionObj, NeedList); // return value ignored; NeedList is shared across actions (presumably appended to — confirm)
+ }
+
+ if (!NeedList.empty())
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
+ }
+ }
+
+ // Enqueue all actions and collect results
+ CbObjectWriter Cbo;
+ int Accepted = 0;
+
+ Cbo.BeginArray("results");
+
+ for (CbField ActionField : Actions)
+ {
+ CbObject ActionObj = ActionField.AsObject();
+
+ ComputeServiceSession::EnqueueResult Result = Enqueue(ActionObj);
+
+ Cbo.BeginObject(); // one result object per submitted action, in submission order
+
+ if (Result)
+ {
+ Cbo << "lsn"sv << Result.Lsn;
+ ++Accepted;
+ }
+ else
+ {
+ Cbo << "error"sv << Result.ResponseMessage;
+ }
+
+ Cbo.EndObject();
+ }
+
+ Cbo.EndArray();
+
+ if (Stats.Count > 0)
+ {
+ ZEN_DEBUG("queue {}: batch accepted {}/{} actions: {} in {} attachments. {} new ({} attachments)",
+ QueueId,
+ Accepted,
+ Actions.Num(),
+ zen::NiceBytes(Stats.Bytes),
+ Stats.Count,
+ zen::NiceBytes(Stats.NewBytes),
+ Stats.NewCount);
+ }
+ else
+ {
+ ZEN_DEBUG("queue {}: batch accepted {}/{} actions", QueueId, Accepted, Actions.Num());
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); // the batch always answers OK; per-action failures live inside "results"
+ }
+
+ // --- Single-action path: Body is the action itself ---
+
+ if (HttpReq.RequestContentType() == HttpContentType::kCbObject)
+ {
+ std::vector<IoHash> NeedList;
+
+ if (!CheckAttachments(Body, NeedList))
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save());
+ }
+ }
+
+ if (ComputeServiceSession::EnqueueResult Result = Enqueue(Body))
+ {
+ if (Stats.Count > 0)
+ {
+ ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)",
+ QueueId,
+ Body.GetHash(),
+ Result.Lsn,
+ zen::NiceBytes(Stats.Bytes),
+ Stats.Count,
+ zen::NiceBytes(Stats.NewBytes),
+ Stats.NewCount);
+ }
+ else
+ {
+ ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, Body.GetHash(), Result.Lsn);
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage);
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); // a rejected enqueue is reported as FailedDependency with the session's message
+ }
+}
+
+void
HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq)
{
CbObjectWriter Cbo;
@@ -1632,6 +1532,136 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const
}
//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+// IWebSocketHandler entry point: hands the newly opened connection (and its
+// reference) over to the Impl object.
+void
+HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ m_Impl->OnWebSocketOpen(std::move(Connection));
+}
+
+// IWebSocketHandler entry point for inbound messages. Intentionally a no-op:
+// both parameters are unused.
+void
+HttpComputeService::OnWebSocketMessage([[maybe_unused]] WebSocketConnection& Conn, [[maybe_unused]] const WebSocketMessage& Msg)
+{
+ // Clients are receive-only; ignore any inbound messages.
+}
+
+// IWebSocketHandler entry point: forwards the closed connection and its close
+// code to the Impl object. The textual close reason is not used.
+void
+HttpComputeService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ m_Impl->OnWebSocketClose(Conn, Code);
+}
+
+// IComputeCompletionObserver entry point: forwards the batch of completion
+// notifications unchanged to the Impl object.
+void
+HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotification> Actions)
+{
+ m_Impl->OnActionsCompleted(Actions);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Impl — WebSocket / observer
+//
+
+// Records a newly connected WebSocket client in m_WsConnections.
+void
+HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ ZEN_INFO("compute WebSocket client connected");
+ // The connection list is mutated only while holding the exclusive lock.
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
+}
+
+// Removes a disconnected client from m_WsConnections. Entries are matched by
+// the address of the underlying connection object.
+void
+HttpComputeService::Impl::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code)
+{
+    ZEN_INFO("compute WebSocket client disconnected (code {})", Code);
+
+    m_WsConnectionsLock.WithExclusiveLock([&] {
+        const auto RefersToClosedConnection = [&Conn](const Ref<WebSocketConnection>& Candidate) { return Candidate.Get() == &Conn; };
+
+        // Erase-remove: compact the survivors, then drop the tail in one call.
+        m_WsConnections.erase(std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), RefersToClosedConnection),
+                              m_WsConnections.end());
+    });
+}
+
+// Broadcasts a completion notification to every connected WebSocket client.
+//
+// The payload is a CompactBinary object holding one array per state, keyed by
+// the state name and containing the LSNs reported in that state:
+//   {"Completed": [lsn, ...], "Failed": [lsn, ...], ...}
+//
+// An empty batch sends nothing. Connections found closed while broadcasting
+// are pruned from m_WsConnections afterwards.
+void
+HttpComputeService::Impl::OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions)
+{
+    using ActionState                 = IComputeCompletionObserver::ActionState;
+    using CompletedActionNotification = IComputeCompletionObserver::CompletedActionNotification;
+
+    // Nothing to report. Sending an empty object would still wake up every
+    // receiver, because receivers treat any message as a wake-up signal.
+    if (Actions.empty())
+    {
+        return;
+    }
+
+    // Snapshot connections under shared lock
+    eastl::fixed_vector<Ref<WebSocketConnection>, 16> Connections;
+    m_WsConnectionsLock.WithSharedLock([&] { Connections = {begin(m_WsConnections), end(m_WsConnections)}; });
+
+    if (Connections.empty())
+    {
+        return;
+    }
+
+    CbObjectWriter Cbo;
+
+    // Sort by state so we can emit one array per state in a single pass.
+    // Copy into a local vector since the span is const. The sort is stable so
+    // LSNs within a state keep the order in which they were delivered.
+    eastl::fixed_vector<CompletedActionNotification, 16> Sorted(Actions.begin(), Actions.end());
+    std::stable_sort(Sorted.begin(), Sorted.end(), [](const auto& A, const auto& B) { return A.State < B.State; });
+
+    ActionState CurrentState{};
+    bool        ArrayOpen = false;
+
+    for (const CompletedActionNotification& Action : Sorted)
+    {
+        // Start a new array whenever the state changes (or for the very first entry)
+        if (!ArrayOpen || Action.State != CurrentState)
+        {
+            if (ArrayOpen)
+            {
+                Cbo.EndArray();
+            }
+            CurrentState = Action.State;
+            Cbo.BeginArray(IComputeCompletionObserver::ActionStateToString(CurrentState));
+            ArrayOpen = true;
+        }
+        Cbo.AddInteger(Action.Lsn);
+    }
+
+    if (ArrayOpen)
+    {
+        Cbo.EndArray();
+    }
+
+    // Msg owns the serialized bytes; MsgView must not outlive it
+    CbObject   Msg     = Cbo.Save();
+    MemoryView MsgView = Msg.GetView();
+
+    // Broadcast to all connected clients, prune closed ones
+    bool HadClosedConnections = false;
+    for (auto& Conn : Connections)
+    {
+        if (Conn->IsOpen())
+        {
+            Conn->SendBinary(MsgView);
+        }
+        else
+        {
+            HadClosedConnections = true;
+        }
+    }
+
+    if (HadClosedConnections)
+    {
+        // Re-check under the exclusive lock: the live list may differ from the snapshot
+        m_WsConnectionsLock.WithExclusiveLock([&] {
+            auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) {
+                return !C->IsOpen();
+            });
+            m_WsConnections.erase(It, m_WsConnections.end());
+        });
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
void
httpcomputeservice_forcelink()
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
index 65ec5f9ee..1ca78738a 100644
--- a/src/zencompute/include/zencompute/computeservice.h
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -13,6 +13,7 @@
# include <zenhttp/httpcommon.h>
# include <filesystem>
+# include <span>
namespace zen {
class ChunkResolver;
@@ -29,6 +30,53 @@ class RemoteHttpRunner;
struct RunnerAction;
struct SubmitResult;
+/**
+ * Observer interface for action completion notifications.
+ *
+ * Implementors receive a batch of notifications whenever actions reach a
+ * terminal state (Completed, Failed, Abandoned, Cancelled). The callback
+ * fires on the scheduler thread *after* the action result has been placed
+ * in m_ResultsMap, so GET /jobs/{lsn} will succeed by the time the client
+ * reacts to the notification.
+ */
+class IComputeCompletionObserver
+{
+public:
+ virtual ~IComputeCompletionObserver() = default;
+
+ // States that can be reported through a notification.
+ enum class ActionState
+ {
+ Completed,
+ Failed,
+ Abandoned,
+ Cancelled,
+ };
+
+ // One reported action: its LSN and the state it reached.
+ struct CompletedActionNotification
+ {
+ int Lsn;
+ ActionState State;
+ };
+
+ // Returns the name of State as text, or "Unknown" for a value that is not
+ // one of the enumerators above.
+ static constexpr std::string_view ActionStateToString(ActionState State)
+ {
+ switch (State)
+ {
+ case ActionState::Completed:
+ return "Completed";
+ case ActionState::Failed:
+ return "Failed";
+ case ActionState::Abandoned:
+ return "Abandoned";
+ case ActionState::Cancelled:
+ return "Cancelled";
+ }
+ return "Unknown";
+ }
+
+ // Called with a batch of notifications; the span is read-only for the callee.
+ virtual void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) = 0;
+};
+
struct WorkerDesc
{
CbPackage Descriptor;
@@ -91,11 +139,25 @@ public:
// Sunset can be reached from any non-Sunset state.
bool RequestStateTransition(SessionState NewState);
+ // Convenience helpers for common state transitions.
+ bool Ready() { return RequestStateTransition(SessionState::Ready); }
+ bool Abandon() { return RequestStateTransition(SessionState::Abandoned); }
+
// Orchestration
void SetOrchestratorEndpoint(std::string_view Endpoint);
void SetOrchestratorBasePath(std::filesystem::path BasePath);
+ // Convenience helper: sets the orchestrator endpoint and then the base path in one call.
+ void SetOrchestrator(std::string_view Endpoint, std::filesystem::path BasePath)
+ {
+ SetOrchestratorEndpoint(Endpoint);
+ SetOrchestratorBasePath(std::move(BasePath));
+ }
+
+ /// Immediately wake the scheduler to re-poll the orchestrator for worker changes.
+ /// Resets the polling throttle so the next scheduler tick calls UpdateCoordinatorState().
+ void NotifyOrchestratorChanged();
+
// Worker registration and discovery
void RegisterWorker(CbPackage Worker);
@@ -182,6 +244,7 @@ public:
};
[[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn);
+ [[nodiscard]] RescheduleResult RetractAction(int ActionLsn);
void GetCompleted(CbWriter&);
@@ -215,7 +278,7 @@ public:
// sized to match RunnerAction::State::_Count but we can't use the enum here
// for dependency reasons, so just use a fixed size array and static assert in
// the implementation file
- uint64_t Timestamps[8] = {};
+ uint64_t Timestamps[9] = {};
};
[[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
@@ -235,6 +298,10 @@ public:
void EmitStats(CbObjectWriter& Cbo);
+ // Completion observer (used by HttpComputeService for WebSocket push)
+
+ void SetCompletionObserver(IComputeCompletionObserver* Observer);
+
// Recording
void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
index ee1cd2614..b58e73a0d 100644
--- a/src/zencompute/include/zencompute/httpcomputeservice.h
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -9,6 +9,7 @@
# include "zencompute/computeservice.h"
# include <zenhttp/httpserver.h>
+# include <zenhttp/websocket.h>
# include <filesystem>
# include <memory>
@@ -22,7 +23,7 @@ namespace zen::compute {
/**
* HTTP interface for compute service
*/
-class HttpComputeService : public HttpService, public IHttpStatsProvider
+class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver
{
public:
HttpComputeService(CidStore& InCidStore,
@@ -42,6 +43,16 @@ public:
void HandleStatsRequest(HttpServerRequest& Request) override;
+ // IWebSocketHandler
+
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
+ // IComputeCompletionObserver
+
+ void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) override;
+
private:
struct Impl;
std::unique_ptr<Impl> m_Impl;
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
index 768cdf1e1..4f116e7d8 100644
--- a/src/zencompute/runners/functionrunner.cpp
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -215,15 +215,22 @@ BaseRunnerGroup::GetSubmittedActionCount()
return TotalCount;
}
-void
+bool
BaseRunnerGroup::RegisterWorker(CbPackage Worker)
{
RwLock::SharedLockScope _(m_RunnersLock);
+ bool AllSucceeded = true;
+
for (auto& Runner : m_Runners)
{
- Runner->RegisterWorker(Worker);
+ if (!Runner->RegisterWorker(Worker))
+ {
+ AllSucceeded = false;
+ }
}
+
+ return AllSucceeded;
}
void
@@ -276,12 +283,34 @@ RunnerAction::~RunnerAction()
}
+// Atomically moves the action into the Retracted state and posts an update to
+// the owning session.
+//
+// Succeeds from any state up to and including Running. From a later state the
+// call changes nothing and returns true only if the action is already
+// Retracted, which makes repeated calls idempotent.
bool
+RunnerAction::RetractAction()
+{
+    for (State Observed = m_ActionState.load();;)
+    {
+        // Past Running there is nothing left to retract
+        if (Observed > State::Running)
+        {
+            return Observed == State::Retracted;
+        }
+
+        // A failed exchange refreshes Observed with the current state, so
+        // simply go around again and re-evaluate it
+        if (!m_ActionState.compare_exchange_strong(Observed, State::Retracted))
+        {
+            continue;
+        }
+
+        this->Timestamps[static_cast<int>(State::Retracted)] = DateTime::Now().GetTicks();
+        m_OwnerSession->PostUpdate(this);
+        return true;
+    }
+}
+
+bool
RunnerAction::ResetActionStateToPending()
{
- // Only allow reset from Failed or Abandoned states
+ // Only allow reset from Failed, Abandoned, or Retracted states
State CurrentState = m_ActionState.load();
- if (CurrentState != State::Failed && CurrentState != State::Abandoned)
+ if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted)
{
return false;
}
@@ -305,8 +334,11 @@ RunnerAction::ResetActionStateToPending()
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed);
CpuSeconds.store(0.0f, std::memory_order_relaxed);
- // Increment retry count
- RetryCount.fetch_add(1, std::memory_order_relaxed);
+ // Increment retry count (skip for Retracted — nothing failed)
+ if (CurrentState != State::Retracted)
+ {
+ RetryCount.fetch_add(1, std::memory_order_relaxed);
+ }
// Re-enter the scheduler pipeline
m_OwnerSession->PostUpdate(this);
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
index f67414dbb..56c3f3af0 100644
--- a/src/zencompute/runners/functionrunner.h
+++ b/src/zencompute/runners/functionrunner.h
@@ -29,8 +29,8 @@ public:
FunctionRunner(std::filesystem::path BasePath);
virtual ~FunctionRunner() = 0;
- virtual void Shutdown() = 0;
- virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0;
+ virtual void Shutdown() = 0;
+ [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) = 0;
[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0;
[[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;
@@ -63,7 +63,7 @@ public:
SubmitResult SubmitAction(Ref<RunnerAction> Action);
std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
size_t GetSubmittedActionCount();
- void RegisterWorker(CbPackage Worker);
+ [[nodiscard]] bool RegisterWorker(CbPackage Worker);
void Shutdown();
bool CancelAction(int ActionLsn);
void CancelRemoteQueue(int QueueId);
@@ -114,6 +114,30 @@ struct RunnerGroup : public BaseRunnerGroup
/**
* This represents an action going through different stages of scheduling and execution.
+ *
+ * State machine
+ * =============
+ *
+ * Normal forward flow (enforced by SetActionState rejecting backward transitions):
+ *
+ * New -> Pending -> Submitting -> Running -> Completed
+ * -> Failed
+ * -> Abandoned
+ * -> Cancelled
+ *
+ * Rescheduling (via ResetActionStateToPending):
+ *
+ * Failed ---> Pending (increments RetryCount, subject to retry limit)
+ * Abandoned ---> Pending (increments RetryCount, subject to retry limit)
+ * Retracted ---> Pending (does NOT increment RetryCount)
+ *
+ * Retraction (via RetractAction, idempotent):
+ *
+ * Pending/Submitting/Running -> Retracted -> Pending (rescheduled)
+ *
+ * Retracted is placed after Cancelled in enum order so that once set,
+ * no runner-side transition (Completed/Failed) can override it via
+ * SetActionState's forward-only rule.
*/
struct RunnerAction : public RefCounted
{
@@ -137,16 +161,20 @@ struct RunnerAction : public RefCounted
enum class State
{
- New,
- Pending,
- Submitting,
- Running,
- Completed,
- Failed,
- Abandoned,
- Cancelled,
+ New, // Initial state at construction, before entering the scheduler
+ Pending, // Queued and waiting for a runner slot
+ Submitting, // Being handed off to a runner (async submission in progress)
+ Running, // Executing on a runner process
+ Completed, // Finished successfully with results available
+ Failed, // Execution failed (transient error, eligible for retry)
+ Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon)
+ Cancelled, // Intentional user cancellation (never retried)
+ Retracted, // Pulled back for rescheduling on a different runner (no retry cost)
_Count
};
+ static_assert(State::Retracted > State::Completed && State::Retracted > State::Failed && State::Retracted > State::Abandoned &&
+ State::Retracted > State::Cancelled,
+ "Retracted must be the highest terminal ordinal so runner-side transitions cannot override it");
static const char* ToString(State _)
{
@@ -168,6 +196,8 @@ struct RunnerAction : public RefCounted
return "Abandoned";
case State::Cancelled:
return "Cancelled";
+ case State::Retracted:
+ return "Retracted";
default:
return "Unknown";
}
@@ -191,6 +221,7 @@ struct RunnerAction : public RefCounted
void SetActionState(State NewState);
bool IsSuccess() const { return ActionState() == State::Completed; }
+ bool RetractAction();
bool ResetActionStateToPending();
bool IsCompleted() const
{
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
index 7aaefb06e..b61e0a46f 100644
--- a/src/zencompute/runners/localrunner.cpp
+++ b/src/zencompute/runners/localrunner.cpp
@@ -7,14 +7,16 @@
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
+# include <zencore/compactbinaryfile.h>
# include <zencore/compress.h>
# include <zencore/except_fmt.h>
# include <zencore/filesystem.h>
# include <zencore/fmtutils.h>
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
-# include <zencore/system.h>
# include <zencore/scopeguard.h>
+# include <zencore/stream.h>
+# include <zencore/system.h>
# include <zencore/timer.h>
# include <zencore/trace.h>
# include <zenstore/cidstore.h>
@@ -152,7 +154,7 @@ LocalProcessRunner::CreateNewSandbox()
return Path;
}
-void
+bool
LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage)
{
ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker");
@@ -173,6 +175,8 @@ LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage)
ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path);
}
+
+ return true;
}
size_t
@@ -301,7 +305,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action)
// Write out action
- zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer());
+ WriteCompactBinaryObject(SandboxPath / "build.action", ActionObj);
// Manifest inputs in sandbox
diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h
index 7493e980b..b8cff6826 100644
--- a/src/zencompute/runners/localrunner.h
+++ b/src/zencompute/runners/localrunner.h
@@ -51,7 +51,7 @@ public:
~LocalProcessRunner();
virtual void Shutdown() override;
- virtual void RegisterWorker(const CbPackage& WorkerPackage) override;
+ [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override;
[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
[[nodiscard]] virtual bool IsHealthy() override { return true; }
[[nodiscard]] virtual size_t GetSubmittedActionCount() override;
diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp
index 672636d06..ce6a81173 100644
--- a/src/zencompute/runners/remotehttprunner.cpp
+++ b/src/zencompute/runners/remotehttprunner.cpp
@@ -42,6 +42,20 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
, m_Http(m_BaseUrl)
, m_InstanceId(Oid::NewOid())
{
+ // Attempt to connect a WebSocket for push-based completion notifications.
+ // If the remote doesn't support WS, OnWsClose fires and we fall back to polling.
+ {
+ std::string WsUrl = HttpToWsUrl(HostName, "/compute/ws");
+
+ HttpWsClientSettings WsSettings;
+ WsSettings.LogCategory = "http_exec_ws";
+ WsSettings.ConnectTimeout = std::chrono::milliseconds{3000};
+
+ IWsClientHandler& Handler = *this;
+ m_WsClient = std::make_unique<HttpWsClient>(WsUrl, Handler, WsSettings);
+ m_WsClient->Connect();
+ }
+
m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this};
}
@@ -53,7 +67,29 @@ RemoteHttpRunner::~RemoteHttpRunner()
void
RemoteHttpRunner::Shutdown()
{
- // TODO: should cleanly drain/cancel pending work
+ m_AcceptNewActions = false;
+
+ // Close the WebSocket client first, so no more wakeup signals arrive.
+ if (m_WsClient)
+ {
+ m_WsClient->Close();
+ }
+
+ // Cancel all known remote queues so the remote side stops scheduling new
+ // work and cancels in-flight actions belonging to those queues.
+
+ {
+ std::vector<std::pair<int, Oid>> Queues;
+
+ m_QueueTokenLock.WithSharedLock([&] { Queues.assign(m_RemoteQueueTokens.begin(), m_RemoteQueueTokens.end()); });
+
+ for (const auto& [QueueId, Token] : Queues)
+ {
+ CancelRemoteQueue(QueueId);
+ }
+ }
+
+ // Stop the monitor thread so it no longer polls the remote.
m_MonitorThreadEnabled = false;
m_MonitorThreadEvent.Set();
@@ -61,9 +97,22 @@ RemoteHttpRunner::Shutdown()
{
m_MonitorThread.join();
}
+
+ // Drain the running map and mark all remaining actions as Failed so the
+ // scheduler can reschedule or finalize them.
+
+ std::unordered_map<int, HttpRunningAction> Remaining;
+
+ m_RunningLock.WithExclusiveLock([&] { Remaining.swap(m_RemoteRunningMap); });
+
+ for (auto& [RemoteLsn, HttpAction] : Remaining)
+ {
+ ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn);
+ HttpAction.Action->SetActionState(RunnerAction::State::Failed);
+ }
}
-void
+bool
RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
{
ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker");
@@ -125,15 +174,13 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
if (!IsHttpSuccessCode(PayloadResponse.StatusCode))
{
ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
-
- // TODO: propagate error
+ return false;
}
}
else if (!IsHttpSuccessCode(DescResponse.StatusCode))
{
ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
-
- // TODO: propagate error
+ return false;
}
else
{
@@ -152,14 +199,20 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
WorkerUrl,
(int)WorkerResponse.StatusCode,
ToString(WorkerResponse.StatusCode));
-
- // TODO: propagate error
+ return false;
}
+
+ return true;
}
size_t
RemoteHttpRunner::QueryCapacity()
{
+ if (!m_AcceptNewActions)
+ {
+ return 0;
+ }
+
// Estimate how much more work we're ready to accept
RwLock::SharedLockScope _{m_RunningLock};
@@ -191,24 +244,68 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
return Results;
}
- // For larger batches, submit HTTP requests in parallel via the shared worker pool
+ // Collect distinct QueueIds and ensure remote queues exist once per queue
- std::vector<std::future<SubmitResult>> Futures;
- Futures.reserve(Actions.size());
+ std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero)
for (const Ref<RunnerAction>& Action : Actions)
{
- std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); });
+ const int QueueId = Action->QueueId;
+ if (QueueId != 0 && QueueTokens.find(QueueId) == QueueTokens.end())
+ {
+ CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId);
+ CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId);
+ QueueTokens[QueueId] = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig);
+ }
+ }
- Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog));
+ // Group actions by QueueId
+
+ struct QueueGroup
+ {
+ std::vector<Ref<RunnerAction>> Actions;
+ std::vector<size_t> OriginalIndices;
+ };
+
+ std::unordered_map<int, QueueGroup> Groups;
+
+ for (size_t i = 0; i < Actions.size(); ++i)
+ {
+ auto& Group = Groups[Actions[i]->QueueId];
+ Group.Actions.push_back(Actions[i]);
+ Group.OriginalIndices.push_back(i);
}
- std::vector<SubmitResult> Results;
- Results.reserve(Futures.size());
+ // Submit each group as a batch and map results back to original indices
- for (auto& Future : Futures)
+ std::vector<SubmitResult> Results(Actions.size());
+
+ for (auto& [QueueId, Group] : Groups)
{
- Results.push_back(Future.get());
+ std::string SubmitUrl = "/jobs";
+ if (QueueId != 0)
+ {
+ if (Oid Token = QueueTokens[QueueId]; Token != Oid::Zero)
+ {
+ SubmitUrl = fmt::format("/queues/{}/jobs", Token);
+ }
+ }
+
+ const size_t BatchLimit = size_t(m_MaxBatchSize);
+
+ for (size_t Offset = 0; Offset < Group.Actions.size(); Offset += BatchLimit)
+ {
+ size_t End = zen::Min(Offset + BatchLimit, Group.Actions.size());
+
+ std::vector<Ref<RunnerAction>> Chunk(Group.Actions.begin() + Offset, Group.Actions.begin() + End);
+
+ std::vector<SubmitResult> ChunkResults = SubmitActionBatch(SubmitUrl, Chunk);
+
+ for (size_t j = 0; j < ChunkResults.size(); ++j)
+ {
+ Results[Group.OriginalIndices[Offset + j]] = std::move(ChunkResults[j]);
+ }
+ }
}
return Results;
@@ -221,6 +318,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
// Verify whether we can accept more work
+ if (!m_AcceptNewActions)
+ {
+ return SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"};
+ }
+
{
RwLock::SharedLockScope _{m_RunningLock};
if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions))
@@ -275,7 +377,7 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
m_Http.GetBaseUri(),
ActionId);
- RegisterWorker(Action->Worker.Descriptor);
+ (void)RegisterWorker(Action->Worker.Descriptor);
}
else
{
@@ -384,6 +486,194 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
return {};
}
+// Submits a group of actions to SubmitUrl as a single batch request.
+//
+// The request body is {"actions": [<action object>, ...]}. Handling by response:
+//   - OK:        per-action results are produced by ParseBatchResponse
+//   - NotFound:  the response lists required attachment hashes under "need";
+//                they are resolved through the chunk resolver and the batch is
+//                re-posted as a CbPackage carrying those attachments
+//   - otherwise (including a failed retry or an unresolvable attachment):
+//                the actions are submitted one by one via FallbackToIndividualSubmit
+//
+// Returns one SubmitResult per entry in Actions, in the same order.
+//
+// NOTE(review): the capacity check only rejects when the running map is already
+// full, so an accepted batch can grow the map past m_MaxRunningActions — confirm
+// that this is intended.
+std::vector<SubmitResult>
+RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions)
+{
+    ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActionBatch");
+
+    if (!m_AcceptNewActions)
+    {
+        return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"});
+    }
+
+    // Capacity check
+
+    {
+        RwLock::SharedLockScope _{m_RunningLock};
+        if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions))
+        {
+            std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false});
+            return Results;
+        }
+    }
+
+    // Per-action setup and build batch body
+
+    CbObjectWriter Body;
+    Body.BeginArray("actions"sv);
+
+    for (const Ref<RunnerAction>& Action : Actions)
+    {
+        Action->ExecutionLocation = m_HostName;
+        MaybeDumpAction(Action->ActionLsn, Action->ActionObj);
+        Body.AddObject(Action->ActionObj);
+    }
+
+    Body.EndArray();
+
+    // Serialize once; the same object is reused by the attachment retry below
+    // instead of saving the writer a second time.
+    CbObject BatchObject = Body.Save();
+
+    // POST the batch
+
+    HttpClient::Response Response = m_Http.Post(SubmitUrl, BatchObject);
+
+    if (Response.StatusCode == HttpResponseCode::OK)
+    {
+        return ParseBatchResponse(Response, Actions);
+    }
+
+    if (Response.StatusCode == HttpResponseCode::NotFound)
+    {
+        // Server needs attachments — resolve them and retry with a CbPackage
+
+        CbObject NeedObj = Response.AsObject();
+
+        CbPackage Pkg;
+        Pkg.SetObject(BatchObject);
+
+        for (auto& Item : NeedObj["need"sv])
+        {
+            const IoHash NeedHash = Item.AsHash();
+
+            if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+            {
+                uint64_t         DataRawSize = 0;
+                IoHash           DataRawHash;
+                CompressedBuffer Compressed =
+                    CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+                // The stored chunk must decode to the hash the server asked for
+                ZEN_ASSERT(DataRawHash == NeedHash);
+
+                Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+            }
+            else
+            {
+                ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash);
+                return FallbackToIndividualSubmit(Actions);
+            }
+        }
+
+        HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+
+        if (RetryResponse.StatusCode == HttpResponseCode::OK)
+        {
+            return ParseBatchResponse(RetryResponse, Actions);
+        }
+
+        ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit",
+                 (int)RetryResponse.StatusCode,
+                 ToString(RetryResponse.StatusCode));
+        return FallbackToIndividualSubmit(Actions);
+    }
+
+    // Unexpected status or connection error — fall back to individual submission
+
+    if (Response)
+    {
+        ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit",
+                 m_Http.GetBaseUri(),
+                 SubmitUrl,
+                 (int)Response.StatusCode,
+                 ToString(Response.StatusCode));
+    }
+    else
+    {
+        ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
+    }
+
+    return FallbackToIndividualSubmit(Actions);
+}
+
+// Translates the "results" array of a batch response into one SubmitResult per
+// submitted action.
+//
+// Result entries are matched to Actions by position. An entry with a positive
+// "lsn" is treated as accepted: the action is recorded in m_RemoteRunningMap
+// under that remote LSN and moved to the Running state. Any other entry is
+// reported as not accepted, with its "error" string as the reason.
+//
+// The returned vector always has Actions.size() entries.
+std::vector<SubmitResult>
+RemoteHttpRunner::ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions)
+{
+ std::vector<SubmitResult> Results;
+ Results.reserve(Actions.size());
+
+ CbObject ResponseObj = Response.AsObject();
+ CbArrayView ResultArray = ResponseObj["results"sv].AsArrayView();
+
+ size_t Index = 0;
+ for (CbFieldView Field : ResultArray)
+ {
+ // Ignore any surplus entries beyond the number of submitted actions
+ if (Index >= Actions.size())
+ {
+ break;
+ }
+
+ CbObjectView Entry = Field.AsObjectView();
+ // Defaults to 0 when "lsn" is absent, which takes the rejection branch below
+ const int32_t LsnField = Entry["lsn"sv].AsInt32(0);
+
+ if (LsnField > 0)
+ {
+ HttpRunningAction NewAction;
+ NewAction.Action = Actions[Index];
+ NewAction.RemoteActionLsn = LsnField;
+
+ // NOTE(review): the map entry is added before the state change below,
+ // presumably so the action is already tracked once it is Running — confirm
+ {
+ RwLock::ExclusiveLockScope _(m_RunningLock);
+ m_RemoteRunningMap[LsnField] = std::move(NewAction);
+ }
+
+ ZEN_DEBUG("batch: scheduled action {} with remote LSN {} (local LSN {})",
+ Actions[Index]->ActionObj.GetHash(),
+ LsnField,
+ Actions[Index]->ActionLsn);
+
+ Actions[Index]->SetActionState(RunnerAction::State::Running);
+
+ Results.push_back(SubmitResult{.IsAccepted = true});
+ }
+ else
+ {
+ std::string_view ErrorMsg = Entry["error"sv].AsString();
+ Results.push_back(SubmitResult{.IsAccepted = false, .Reason = std::string(ErrorMsg)});
+ }
+
+ ++Index;
+ }
+
+ // If the server returned fewer results than actions, mark the rest as not accepted
+ while (Results.size() < Actions.size())
+ {
+ Results.push_back(SubmitResult{.IsAccepted = false, .Reason = "no result from server"});
+ }
+
+ return Results;
+}
+
+// Submits every action on its own through SubmitAction, running the
+// submissions as tasks on the worker pool and waiting for all of them.
+//
+// Returns one SubmitResult per entry in Actions, in the same order.
+std::vector<SubmitResult>
+RemoteHttpRunner::FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions)
+{
+    const size_t ActionCount = Actions.size();
+
+    // Phase 1: queue one task per action
+    std::vector<std::future<SubmitResult>> Pending;
+    Pending.reserve(ActionCount);
+
+    for (const Ref<RunnerAction>& Item : Actions)
+    {
+        // The task captures its own reference so the action stays alive until it has run
+        std::packaged_task<SubmitResult()> SubmitOne([this, Item]() { return SubmitAction(Item); });
+
+        Pending.push_back(m_WorkerPool.EnqueueTask(std::move(SubmitOne), WorkerThreadPool::EMode::EnableBacklog));
+    }
+
+    // Phase 2: collect the outcomes in submission order (blocks on each future)
+    std::vector<SubmitResult> Outcomes;
+    Outcomes.reserve(Pending.size());
+
+    for (std::future<SubmitResult>& Outcome : Pending)
+    {
+        Outcomes.push_back(Outcome.get());
+    }
+
+    return Outcomes;
+}
+
Oid
RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config)
{
@@ -481,6 +771,35 @@ RemoteHttpRunner::GetSubmittedActionCount()
return m_RemoteRunningMap.size();
}
+//////////////////////////////////////////////////////////////////////////
+//
+// IWsClientHandler
+//
+
+// IWsClientHandler callback: logs the connection and marks the WebSocket as
+// connected by setting m_WsConnected.
+void
+RemoteHttpRunner::OnWsOpen()
+{
+ ZEN_INFO("WebSocket connected to {}", m_HostName);
+ m_WsConnected.store(true, std::memory_order_release);
+}
+
+// IWsClientHandler callback for an incoming message. The payload is not
+// inspected; receiving any message sets the monitor thread event.
+void
+RemoteHttpRunner::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg)
+{
+ // The message content is a wakeup signal; no parsing needed.
+ // Signal the monitor thread to sweep completed actions immediately.
+ m_MonitorThreadEvent.Set();
+}
+
+// IWsClientHandler callback: logs the disconnect with its close code and
+// clears m_WsConnected. The textual reason is not used.
+void
+RemoteHttpRunner::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ ZEN_WARN("WebSocket disconnected from {} (code {})", m_HostName, Code);
+ m_WsConnected.store(false, std::memory_order_release);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
void
RemoteHttpRunner::MonitorThreadFunction()
{
@@ -489,28 +808,40 @@ RemoteHttpRunner::MonitorThreadFunction()
do
{
const int NormalWaitingTime = 200;
- int WaitTimeMs = NormalWaitingTime;
- auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); };
- auto SweepOnce = [&] {
+ const int WsWaitingTime = 2000; // Safety-net interval when WS is connected
+
+ int WaitTimeMs = m_WsConnected.load(std::memory_order_relaxed) ? WsWaitingTime : NormalWaitingTime;
+ auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); };
+ auto SweepOnce = [&] {
const size_t RetiredCount = SweepRunningActions();
- m_RunningLock.WithSharedLock([&] {
- if (m_RemoteRunningMap.size() > 16)
- {
- WaitTimeMs = NormalWaitingTime / 4;
- }
- else
- {
- if (RetiredCount)
+ if (m_WsConnected.load(std::memory_order_relaxed))
+ {
+ // WS connected: use long safety-net interval; the WS message
+ // will wake us immediately for the real work.
+ WaitTimeMs = WsWaitingTime;
+ }
+ else
+ {
+ // No WS: adaptive polling as before
+ m_RunningLock.WithSharedLock([&] {
+ if (m_RemoteRunningMap.size() > 16)
{
- WaitTimeMs = NormalWaitingTime / 2;
+ WaitTimeMs = NormalWaitingTime / 4;
}
else
{
- WaitTimeMs = NormalWaitingTime;
+ if (RetiredCount)
+ {
+ WaitTimeMs = NormalWaitingTime / 2;
+ }
+ else
+ {
+ WaitTimeMs = NormalWaitingTime;
+ }
}
- }
- });
+ });
+ }
};
while (!WaitOnce())
@@ -518,7 +849,7 @@ RemoteHttpRunner::MonitorThreadFunction()
SweepOnce();
}
- // Signal received - this may mean we should quit
+ // Signal received — may be a WS wakeup or a quit signal
SweepOnce();
} while (m_MonitorThreadEnabled);
diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h
index 9119992a9..c17d0cf2a 100644
--- a/src/zencompute/runners/remotehttprunner.h
+++ b/src/zencompute/runners/remotehttprunner.h
@@ -14,9 +14,11 @@
# include <zencore/workthreadpool.h>
# include <zencore/zencore.h>
# include <zenhttp/httpclient.h>
+# include <zenhttp/httpwsclient.h>
# include <atomic>
# include <filesystem>
+# include <memory>
# include <thread>
# include <unordered_map>
@@ -32,7 +34,7 @@ namespace zen::compute {
*/
-class RemoteHttpRunner : public FunctionRunner
+class RemoteHttpRunner : public FunctionRunner, private IWsClientHandler
{
RemoteHttpRunner(RemoteHttpRunner&&) = delete;
RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete;
@@ -45,7 +47,7 @@ public:
~RemoteHttpRunner();
virtual void Shutdown() override;
- virtual void RegisterWorker(const CbPackage& WorkerPackage) override;
+ [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override;
[[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override;
[[nodiscard]] virtual bool IsHealthy() override;
[[nodiscard]] virtual size_t GetSubmittedActionCount() override;
@@ -66,7 +68,9 @@ private:
std::string m_BaseUrl;
HttpClient m_Http;
- int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ std::atomic<bool> m_AcceptNewActions{true};
+ int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ int32_t m_MaxBatchSize = 50;
struct HttpRunningAction
{
@@ -92,7 +96,20 @@ private:
// creating remote queues. Generated once at construction and never changes.
Oid m_InstanceId;
+ // WebSocket completion notification client
+ std::unique_ptr<HttpWsClient> m_WsClient;
+ std::atomic<bool> m_WsConnected{false};
+
+ // IWsClientHandler
+ void OnWsOpen() override;
+ void OnWsMessage(const WebSocketMessage& Msg) override;
+ void OnWsClose(uint16_t Code, std::string_view Reason) override;
+
Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config);
+
+ std::vector<SubmitResult> SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions);
+ std::vector<SubmitResult> ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions);
+ std::vector<SubmitResult> FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions);
};
} // namespace zen::compute
diff --git a/src/zencore/compactbinaryfile.cpp b/src/zencore/compactbinaryfile.cpp
index ec2fc3cd5..9ddafbe15 100644
--- a/src/zencore/compactbinaryfile.cpp
+++ b/src/zencore/compactbinaryfile.cpp
@@ -30,4 +30,15 @@ LoadCompactBinaryObject(const std::filesystem::path& FilePath)
return {.Hash = IoHash::Zero};
}
+/// Serializes Object and writes the resulting bytes to the file at Path.
+void
+WriteCompactBinaryObject(const std::filesystem::path& Path, const CbObject& Object)
+{
+    // Re-serialize through a CbWriter rather than using CbObject::GetView():
+    // that view may be incomplete because the type byte can be omitted in arrays.
+    CbWriter ObjectWriter;
+    ObjectWriter.AddObject(Object);
+    CbFieldIterator SavedFields = ObjectWriter.Save();
+
+    // SavedFields stays alive until the write below has completed.
+    const auto SerializedBytes = SavedFields.GetRangeView();
+    zen::WriteFile(Path, IoBufferBuilder::MakeFromMemory(SerializedBytes));
+}
+
} // namespace zen
diff --git a/src/zencore/include/zencore/compactbinaryfile.h b/src/zencore/include/zencore/compactbinaryfile.h
index 33f3e7bea..a06524549 100644
--- a/src/zencore/include/zencore/compactbinaryfile.h
+++ b/src/zencore/include/zencore/compactbinaryfile.h
@@ -15,5 +15,6 @@ struct CbObjectFromFile
};
CbObjectFromFile LoadCompactBinaryObject(const std::filesystem::path& FilePath);
+void WriteCompactBinaryObject(const std::filesystem::path& Path, const CbObject& Object);
} // namespace zen
diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h
index 5060a8393..939c70d12 100644
--- a/src/zencore/include/zencore/logging/ansicolorsink.h
+++ b/src/zencore/include/zencore/logging/ansicolorsink.h
@@ -15,6 +15,9 @@ enum class ColorMode
Auto
};
+bool IsColorTerminal();
+bool ResolveColorMode(ColorMode Mode);
+
class AnsiColorStdoutSink : public Sink
{
public:
diff --git a/src/zencore/include/zencore/logging/formatter.h b/src/zencore/include/zencore/logging/formatter.h
index 11904d71d..e605b22b8 100644
--- a/src/zencore/include/zencore/logging/formatter.h
+++ b/src/zencore/include/zencore/logging/formatter.h
@@ -15,6 +15,12 @@ public:
virtual ~Formatter() = default;
virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) = 0;
virtual std::unique_ptr<Formatter> Clone() const = 0;
+
+ void SetColorEnabled(bool Enabled) { m_UseColor = Enabled; }
+ bool IsColorEnabled() const { return m_UseColor; }
+
+private:
+ bool m_UseColor = false;
};
} // namespace zen::logging
diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h
index ce021e1a5..765aa59e3 100644
--- a/src/zencore/include/zencore/logging/helpers.h
+++ b/src/zencore/include/zencore/logging/helpers.h
@@ -119,4 +119,81 @@ LevelToShortString(LogLevel Level)
return ToStringView(Level);
}
+/// Maps a log level to the ANSI escape sequence used to render it.
+/// Levels without a dedicated entry map to the plain reset sequence.
+inline std::string_view
+AnsiColorForLevel(LogLevel Level)
+{
+    using namespace std::string_view_literals;
+
+    std::string_view Escape = "\033[m"sv;  // fallback: reset
+    switch (Level)
+    {
+        case Trace:
+            Escape = "\033[37m"sv;  // white
+            break;
+        case Debug:
+            Escape = "\033[36m"sv;  // cyan
+            break;
+        case Info:
+            Escape = "\033[32m"sv;  // green
+            break;
+        case Warn:
+            Escape = "\033[33m\033[1m"sv;  // bold yellow
+            break;
+        case Err:
+            Escape = "\033[31m\033[1m"sv;  // bold red
+            break;
+        case Critical:
+            Escape = "\033[1m\033[41m"sv;  // bold on red background
+            break;
+        default:
+            break;
+    }
+    return Escape;
+}
+
+inline constexpr std::string_view kAnsiReset = "\033[m";
+
+/// Appends the escape sequence returned by AnsiColorForLevel(Level) to Dest.
+inline void
+AppendAnsiColor(LogLevel Level, MemoryBuffer& Dest)
+{
+    const std::string_view Escape = AnsiColorForLevel(Level);
+    const char* const  First  = Escape.data();
+    Dest.append(First, First + Escape.size());
+}
+
+/// Appends the ANSI reset sequence (kAnsiReset) to Dest.
+inline void
+AppendAnsiReset(MemoryBuffer& Dest)
+{
+    const char* const First = kAnsiReset.data();
+    Dest.append(First, First + kAnsiReset.size());
+}
+
+// Strip ANSI SGR escape sequences (\033[<params>m) from the buffer in-place.
+//
+// A sequence is removed only when everything between "\033[" and the final 'm'
+// is an SGR parameter byte (decimal digit, ';' or ':'). Anything else is copied
+// through untouched, which covers:
+//   - other CSI sequences (cursor movement, erase, etc.), so message text that
+//     follows them is never swallowed while searching for an 'm'
+//   - an unterminated "\033[" at the end of the buffer
+//
+// The buffer is compacted in a single forward pass and shrunk to the new length.
+inline void
+StripAnsiSgrSequences(MemoryBuffer& Buf)
+{
+    const char* Src = Buf.data();
+    const char* End = Src + Buf.size();
+    char*       Dst = Buf.data();
+
+    while (Src < End)
+    {
+        if (Src[0] == '\033' && (End - Src) >= 2 && Src[1] == '[')
+        {
+            // Consume SGR parameter bytes only; stop at the first byte that
+            // cannot be part of an SGR parameter list.
+            const char* Seq = Src + 2;
+            while (Seq < End && ((*Seq >= '0' && *Seq <= '9') || *Seq == ';' || *Seq == ':'))
+            {
+                ++Seq;
+            }
+
+            if (Seq < End && *Seq == 'm')
+            {
+                Src = Seq + 1;  // drop the whole sequence including the final 'm'
+                continue;
+            }
+            // Not an SGR sequence: fall through and copy the ESC byte verbatim.
+        }
+
+        if (Dst != Src)
+        {
+            *Dst = *Src;
+        }
+        ++Dst;
+        ++Src;
+    }
+
+    Buf.resize(static_cast<size_t>(Dst - Buf.data()));
+}
+
} // namespace zen::logging::helpers
diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h
index 1d8b6b1b7..a1acb503b 100644
--- a/src/zencore/include/zencore/logging/logmsg.h
+++ b/src/zencore/include/zencore/logging/logmsg.h
@@ -40,9 +40,6 @@ struct LogMessage
void SetTime(LogClock::time_point InTime) { m_Time = InTime; }
void SetSource(const SourceLocation& InSource) { m_Source = InSource; }
- mutable size_t ColorRangeStart = 0;
- mutable size_t ColorRangeEnd = 0;
-
private:
static constexpr LogPoint s_DefaultPoints[LogLevelCount] = {
{{}, Trace, {}},
diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h
index 8410216c4..01356fa00 100644
--- a/src/zencore/include/zencore/testing.h
+++ b/src/zencore/include/zencore/testing.h
@@ -43,9 +43,8 @@ public:
TestRunner();
~TestRunner();
- void SetDefaultSuiteFilter(const char* Pattern);
- int ApplyCommandLine(int Argc, char const* const* Argv);
- int Run();
+ int ApplyCommandLine(int Argc, char const* const* Argv, const char* DefaultSuiteFilter = nullptr);
+ int Run();
private:
struct Impl;
diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp
index 540d22359..03aae068a 100644
--- a/src/zencore/logging/ansicolorsink.cpp
+++ b/src/zencore/logging/ansicolorsink.cpp
@@ -4,12 +4,14 @@
#include <zencore/logging/helpers.h>
#include <zencore/logging/messageonlyformatter.h>
+#include <zencore/thread.h>
+
#include <cstdio>
#include <cstdlib>
-#include <mutex>
#if defined(_WIN32)
# include <io.h>
+# include <zencore/windows.h>
# define ZEN_ISATTY _isatty
# define ZEN_FILENO _fileno
#else
@@ -62,188 +64,225 @@ public:
Dest.push_back(' ');
}
- // level (colored range)
+ // level
Dest.push_back('[');
- Msg.ColorRangeStart = Dest.size();
+ if (IsColorEnabled())
+ {
+ helpers::AppendAnsiColor(Msg.GetLevel(), Dest);
+ }
helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest);
- Msg.ColorRangeEnd = Dest.size();
+ if (IsColorEnabled())
+ {
+ helpers::AppendAnsiReset(Dest);
+ }
Dest.push_back(']');
Dest.push_back(' ');
- // message
- helpers::AppendStringView(Msg.GetPayload(), Dest);
+ // message (align continuation lines with the first line)
+ size_t AnsiBytes = IsColorEnabled() ? (helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0;
+ size_t LinePrefixCount = Dest.size() - AnsiBytes;
+
+ auto MsgPayload = Msg.GetPayload();
+ auto ItLineBegin = MsgPayload.begin();
+ auto ItMessageEnd = MsgPayload.end();
+ bool IsFirstLine = true;
+
+ auto ItLineEnd = ItLineBegin;
+
+ auto EmitLine = [&] {
+ if (IsFirstLine)
+ {
+ IsFirstLine = false;
+ }
+ else
+ {
+ for (size_t i = 0; i < LinePrefixCount; ++i)
+ {
+ Dest.push_back(' ');
+ }
+ }
+ helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), Dest);
+ };
+
+ while (ItLineEnd != ItMessageEnd)
+ {
+ if (*ItLineEnd++ == '\n')
+ {
+ EmitLine();
+ ItLineBegin = ItLineEnd;
+ }
+ }
+
+ if (ItLineBegin != ItMessageEnd)
+ {
+ EmitLine();
+ }
Dest.push_back('\n');
}
- std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultConsoleFormatter>(); }
+ // Creates a fresh DefaultConsoleFormatter that carries over this
+ // formatter's color setting.
+ std::unique_ptr<Formatter> Clone() const override
+ {
+     std::unique_ptr<DefaultConsoleFormatter> Cloned = std::make_unique<DefaultConsoleFormatter>();
+     Cloned->SetColorEnabled(IsColorEnabled());
+     return Cloned;
+ }
private:
std::chrono::seconds m_LastLogSecs{0};
std::tm m_CachedLocalTm{};
};
-static constexpr std::string_view s_Reset = "\033[m";
-
-static std::string_view
-GetColorForLevel(LogLevel InLevel)
+bool
+IsColorTerminal()
{
- using namespace std::string_view_literals;
- switch (InLevel)
+ // If stdout is not a TTY, no color
+ if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0)
{
- case Trace:
- return "\033[37m"sv; // white
- case Debug:
- return "\033[36m"sv; // cyan
- case Info:
- return "\033[32m"sv; // green
- case Warn:
- return "\033[33m\033[1m"sv; // bold yellow
- case Err:
- return "\033[31m\033[1m"sv; // bold red
- case Critical:
- return "\033[1m\033[41m"sv; // bold on red background
- default:
- return s_Reset;
+ return false;
}
-}
-struct AnsiColorStdoutSink::Impl
-{
- explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) {}
+ // NO_COLOR convention (https://no-color.org/)
+ if (std::getenv("NO_COLOR") != nullptr)
+ {
+ return false;
+ }
- static bool IsColorTerminal()
+ // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit")
+ if (std::getenv("COLORTERM") != nullptr)
{
- // If stdout is not a TTY, no color
- if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0)
- {
- return false;
- }
+ return true;
+ }
- // NO_COLOR convention (https://no-color.org/)
- if (std::getenv("NO_COLOR") != nullptr)
+ // Check TERM for known color-capable values
+ const char* Term = std::getenv("TERM");
+ if (Term != nullptr)
+ {
+ std::string_view TermView(Term);
+ // "dumb" terminals do not support color
+ if (TermView == "dumb")
{
return false;
}
-
- // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit")
- if (std::getenv("COLORTERM") != nullptr)
+ // Match against known color-capable terminal types.
+ // TERM often includes suffixes like "-256color", so we use substring matching.
+ constexpr std::string_view ColorTerms[] = {
+ "alacritty",
+ "ansi",
+ "color",
+ "console",
+ "cygwin",
+ "gnome",
+ "konsole",
+ "kterm",
+ "linux",
+ "msys",
+ "putty",
+ "rxvt",
+ "screen",
+ "tmux",
+ "vt100",
+ "vt102",
+ "xterm",
+ };
+ for (std::string_view Candidate : ColorTerms)
{
- return true;
- }
-
- // Check TERM for known color-capable values
- const char* Term = std::getenv("TERM");
- if (Term != nullptr)
- {
- std::string_view TermView(Term);
- // "dumb" terminals do not support color
- if (TermView == "dumb")
+ if (TermView.find(Candidate) != std::string_view::npos)
{
- return false;
- }
- // Match against known color-capable terminal types.
- // TERM often includes suffixes like "-256color", so we use substring matching.
- constexpr std::string_view ColorTerms[] = {
- "alacritty",
- "ansi",
- "color",
- "console",
- "cygwin",
- "gnome",
- "konsole",
- "kterm",
- "linux",
- "msys",
- "putty",
- "rxvt",
- "screen",
- "tmux",
- "vt100",
- "vt102",
- "xterm",
- };
- for (std::string_view Candidate : ColorTerms)
- {
- if (TermView.find(Candidate) != std::string_view::npos)
- {
- return true;
- }
+ return true;
}
}
+ }
#if defined(_WIN32)
- // Windows console supports ANSI color by default in modern versions
- return true;
+ // Windows console supports ANSI color by default in modern versions
+ return true;
#else
- // Unknown terminal — be conservative
- return false;
+ // Unknown terminal — be conservative
+ return false;
#endif
- }
+}
- static bool ResolveColorMode(ColorMode Mode)
+bool
+ResolveColorMode(ColorMode Mode)
+{
+ switch (Mode)
{
- switch (Mode)
- {
- case ColorMode::On:
- return true;
- case ColorMode::Off:
- return false;
- case ColorMode::Auto:
- default:
- return IsColorTerminal();
- }
+ case ColorMode::On:
+ return true;
+ case ColorMode::Off:
+ return false;
+ case ColorMode::Auto:
+ default:
+ return IsColorTerminal();
}
+}
- void Log(const LogMessage& Msg)
+struct AnsiColorStdoutSink::Impl
+{
+ explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode))
{
- std::lock_guard<std::mutex> Lock(m_Mutex);
-
- MemoryBuffer Formatted;
- m_Formatter->Format(Msg, Formatted);
+ m_Formatter->SetColorEnabled(m_UseColor);
+ }
- if (m_UseColor && Msg.ColorRangeEnd > Msg.ColorRangeStart)
- {
- // Print pre-color range
- fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File);
+ // Writes the already-formatted buffer with a single write call while
+ // holding the sink lock, then marks the sink dirty. Flush() returns early
+ // when the dirty flag is clear, so the flag must be raised here or the
+ // written output would never be flushed explicitly.
+ void WriteOutput(const MemoryBuffer& Buf)
+ {
+ RwLock::ExclusiveLockScope Lock(m_Lock);
- // Print color
- std::string_view Color = GetColorForLevel(Msg.GetLevel());
- fwrite(Color.data(), 1, Color.size(), m_File);
+#if defined(_WIN32)
+ DWORD Written;
+ WriteFile(m_Handle, Buf.data(), static_cast<DWORD>(Buf.size()), &Written, nullptr);
+#else
+ fwrite(Buf.data(), 1, Buf.size(), m_File);
+#endif
- // Print colored range
- fwrite(Formatted.data() + Msg.ColorRangeStart, 1, Msg.ColorRangeEnd - Msg.ColorRangeStart, m_File);
+ // Unflushed output is now pending; Flush() clears the flag again.
+ m_Dirty.store(true, std::memory_order_relaxed);
+ }
- // Reset color
- fwrite(s_Reset.data(), 1, s_Reset.size(), m_File);
+ void Log(const LogMessage& Msg)
+ {
+ MemoryBuffer Formatted;
+ m_Formatter->Format(Msg, Formatted);
- // Print remainder
- fwrite(Formatted.data() + Msg.ColorRangeEnd, 1, Formatted.size() - Msg.ColorRangeEnd, m_File);
- }
- else
+ if (!m_UseColor)
{
- fwrite(Formatted.data(), 1, Formatted.size(), m_File);
+ helpers::StripAnsiSgrSequences(Formatted);
}
- fflush(m_File);
+ WriteOutput(Formatted);
}
void Flush()
{
- std::lock_guard<std::mutex> Lock(m_Mutex);
+ if (!m_Dirty.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+ m_Dirty.store(false, std::memory_order_relaxed);
+#if defined(_WIN32)
+ FlushFileBuffers(m_Handle);
+#else
fflush(m_File);
+#endif
}
void SetFormatter(std::unique_ptr<Formatter> InFormatter)
{
- std::lock_guard<std::mutex> Lock(m_Mutex);
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+ InFormatter->SetColorEnabled(m_UseColor);
m_Formatter = std::move(InFormatter);
}
private:
- std::mutex m_Mutex;
+ RwLock m_Lock;
std::unique_ptr<Formatter> m_Formatter;
- FILE* m_File = stdout;
- bool m_UseColor = true;
+#if defined(_WIN32)
+ HANDLE m_Handle = GetStdHandle(STD_OUTPUT_HANDLE);
+#else
+ FILE* m_File = stdout;
+#endif
+ bool m_UseColor = true;
+ std::atomic<bool> m_Dirty = false;
};
AnsiColorStdoutSink::AnsiColorStdoutSink(ColorMode Mode) : m_Impl(std::make_unique<Impl>(Mode))
diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp
index f5bc723b1..c6ee5ee6b 100644
--- a/src/zencore/testing.cpp
+++ b/src/zencore/testing.cpp
@@ -181,6 +181,15 @@ struct TestListener : public doctest::IReporter
void test_case_start(const doctest::TestCaseData& in) override
{
Current = &in;
+
+ if (in.m_test_suite && in.m_test_suite != CurrentSuite)
+ {
+ CurrentSuite = in.m_test_suite;
+ ZEN_CONSOLE("{}==============================================================================={}", ColorYellow, ColorNone);
+ ZEN_CONSOLE("{} TEST_SUITE: {}{}", ColorYellow, CurrentSuite, ColorNone);
+ ZEN_CONSOLE("{}==============================================================================={}", ColorYellow, ColorNone);
+ }
+
ZEN_CONSOLE("{}======== TEST_CASE: {:<50} ========{}", ColorYellow, Current->m_name, ColorNone);
}
@@ -217,8 +226,9 @@ struct TestListener : public doctest::IReporter
void test_case_skipped(const doctest::TestCaseData& /*in*/) override {}
- const doctest::TestCaseData* Current = nullptr;
- std::chrono::steady_clock::time_point RunStart = {};
+ const doctest::TestCaseData* Current = nullptr;
+ std::string_view CurrentSuite = {};
+ std::chrono::steady_clock::time_point RunStart = {};
struct FailedTestInfo
{
@@ -244,15 +254,29 @@ TestRunner::~TestRunner()
{
}
-void
-TestRunner::SetDefaultSuiteFilter(const char* Pattern)
-{
- m_Impl->Session.setOption("test-suite", Pattern);
-}
-
int
-TestRunner::ApplyCommandLine(int Argc, char const* const* Argv)
+TestRunner::ApplyCommandLine(int Argc, char const* const* Argv, const char* DefaultSuiteFilter)
{
+ // Apply the default suite filter only when the command line doesn't provide
+ // an explicit --test-suite / --ts override.
+ if (DefaultSuiteFilter)
+ {
+ bool HasExplicitSuiteFilter = false;
+ for (int i = 1; i < Argc; ++i)
+ {
+ std::string_view Arg = Argv[i];
+ if (Arg.starts_with("--test-suite=") || Arg.starts_with("--ts=") || Arg.starts_with("-test-suite=") || Arg.starts_with("-ts="))
+ {
+ HasExplicitSuiteFilter = true;
+ break;
+ }
+ }
+ if (!HasExplicitSuiteFilter)
+ {
+ m_Impl->Session.setOption("test-suite", DefaultSuiteFilter);
+ }
+ }
+
m_Impl->Session.applyCommandLine(Argc, Argv);
for (int i = 1; i < Argc; ++i)
@@ -316,6 +340,7 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink
TestRunner Runner;
// Derive default suite filter from ExecutableName: "zencore-test" -> "core.*"
+ std::string DefaultSuiteFilter;
if (ExecutableName)
{
std::string_view Name = ExecutableName;
@@ -329,13 +354,12 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink
}
if (!Name.empty())
{
- std::string Filter(Name);
- Filter += ".*";
- Runner.SetDefaultSuiteFilter(Filter.c_str());
+ DefaultSuiteFilter = Name;
+ DefaultSuiteFilter += ".*";
}
}
- Runner.ApplyCommandLine(Argc, Argv);
+ Runner.ApplyCommandLine(Argc, Argv, DefaultSuiteFilter.empty() ? nullptr : DefaultSuiteFilter.c_str());
return Runner.Run();
}
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index fbae9f5fe..770213738 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -638,4 +638,34 @@ HttpWsClient::IsOpen() const
return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
}
+/// Builds a WebSocket URL from an endpoint and a path.
+///
+/// Scheme handling:
+///   - "http://"  becomes "ws://", "https://" becomes "wss://"
+///   - an endpoint that already starts with "ws://" or "wss://" keeps its scheme
+///     (it is not prefixed a second time)
+///   - anything else is treated as plain host[:port] and gets "ws://"
+///
+/// Joining: trailing slashes on the endpoint are dropped, and a '/' is inserted
+/// when a non-empty Path does not start with one, so the path is never glued
+/// directly onto the host. An empty Path yields just scheme + host.
+std::string
+HttpToWsUrl(std::string_view Endpoint, std::string_view Path)
+{
+    struct SchemeMapping
+    {
+        std::string_view From;
+        std::string_view To;
+    };
+    static constexpr SchemeMapping kSchemes[] = {
+        {"http://", "ws://"},
+        {"https://", "wss://"},
+        {"ws://", "ws://"},
+        {"wss://", "wss://"},
+    };
+
+    // Default: no recognised scheme, treat the whole endpoint as host[:port].
+    std::string_view Scheme = "ws://";
+    std::string_view Host   = Endpoint;
+
+    for (const SchemeMapping& Mapping : kSchemes)
+    {
+        if (Endpoint.substr(0, Mapping.From.size()) == Mapping.From)
+        {
+            Scheme = Mapping.To;
+            Host   = Endpoint.substr(Mapping.From.size());
+            break;
+        }
+    }
+
+    // Strip trailing slashes from host to avoid double-slash when Path starts with '/'
+    while (!Host.empty() && Host.back() == '/')
+    {
+        Host.remove_suffix(1);
+    }
+
+    const bool NeedsSeparator = !Path.empty() && Path.front() != '/';
+
+    std::string Url;
+    Url.reserve(Scheme.size() + Host.size() + Path.size() + (NeedsSeparator ? 1 : 0));
+    Url.append(Scheme);
+    Url.append(Host);
+    if (NeedsSeparator)
+    {
+        Url.push_back('/');
+    }
+    Url.append(Path);
+    return Url;
+}
+
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
index 2ca9b7ab1..9c3b909a2 100644
--- a/src/zenhttp/include/zenhttp/httpwsclient.h
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -80,4 +80,14 @@ private:
std::unique_ptr<Impl> m_Impl;
};
+/// Convert an HTTP(S) endpoint to a WebSocket URL by replacing the scheme
+/// and appending the given path. If the endpoint has no recognised scheme,
+/// it is treated as a plain host:port and gets the ws:// prefix.
+///
+/// Examples:
+/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws"
+/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo"
+/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar"
+std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path);
+
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
index bc3293282..710579faa 100644
--- a/src/zenhttp/include/zenhttp/websocket.h
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -43,6 +43,8 @@ public:
virtual void SendBinary(std::span<const uint8_t> Data) = 0;
virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0;
virtual bool IsOpen() const = 0;
+
+ // Convenience overload: forwards the bytes referenced by Data to the
+ // span-based SendBinary().
+ void SendBinary(MemoryView Data)
+ {
+     const uint8_t* const Bytes = static_cast<const uint8_t*>(Data.GetData());
+     SendBinary(std::span<const uint8_t>(Bytes, Data.GetSize()));
+ }
};
/**
diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp
index c90ac5d8b..021052a3b 100644
--- a/src/zenserver-test/compute-tests.cpp
+++ b/src/zenserver-test/compute-tests.cpp
@@ -19,6 +19,7 @@
# include <zencore/timer.h>
# include <zenhttp/httpclient.h>
# include <zenhttp/httpserver.h>
+# include <zenhttp/websocket.h>
# include <zencompute/computeservice.h>
# include <zenstore/zenstore.h>
# include <zenutil/zenserverprocess.h>
@@ -291,7 +292,9 @@ GetRot13Output(const CbPackage& ResultPackage)
}
// Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response.
-class MockOrchestratorService : public HttpService
+// Also implements IWebSocketHandler so the compute session's WS subscription receives
+// push notifications when the worker list changes.
+class MockOrchestratorService : public HttpService, public IWebSocketHandler
{
public:
MockOrchestratorService()
@@ -318,13 +321,48 @@ public:
void SetWorkerList(CbObject WorkerList)
{
- RwLock::ExclusiveLockScope Lock(m_Lock);
- m_WorkerList = std::move(WorkerList);
+ {
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+ m_WorkerList = std::move(WorkerList);
+ }
+
+ // Broadcast a poke to all connected WebSocket clients so they
+ // immediately re-query the orchestrator instead of waiting for the poll.
+ std::vector<Ref<WebSocketConnection>> Snapshot;
+ m_WsLock.WithSharedLock([&] { Snapshot = m_WsConnections; });
+ for (auto& Conn : Snapshot)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText("updated"sv);
+ }
+ }
+ }
+
+ // IWebSocketHandler
+ // Remembers the newly opened connection so SetWorkerList() can send it
+ // an "updated" notification later.
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ {
+     m_WsLock.WithExclusiveLock([this, &Connection] { m_WsConnections.emplace_back(std::move(Connection)); });
+ }
+
+ void OnWebSocketMessage(WebSocketConnection& /*Connection*/, const WebSocketMessage& /*Message*/) override
+ {
+     // Intentionally empty: inbound messages are ignored by this mock.
+ }
+
+ // Removes every tracked reference to the connection that just closed.
+ // The close code and reason are not used.
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t, std::string_view) override
+ {
+     const WebSocketConnection* const Closed = &Conn;
+     m_WsLock.WithExclusiveLock([this, Closed] {
+         const auto IsClosed = [Closed](const Ref<WebSocketConnection>& Tracked) { return Tracked.Get() == Closed; };
+         m_WsConnections.erase(std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), IsClosed), m_WsConnections.end());
+     });
+ }
private:
RwLock m_Lock;
CbObject m_WorkerList;
+
+ RwLock m_WsLock;
+ std::vector<Ref<WebSocketConnection>> m_WsConnections;
};
// Manages in-process ASIO HTTP server lifecycle for mock orchestrator.
@@ -1089,9 +1127,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
- Session.SetOrchestratorBasePath(SessionBaseDir.Path());
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
// Register worker on session (stored locally, no runners yet)
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
@@ -1100,8 +1137,9 @@ TEST_CASE("function.remote.worker_sync_on_discovery")
// Update mock orchestrator to advertise the real server
MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}}));
- // Wait for scheduler to discover the runner (~5s throttle + margin)
- Sleep(7'000);
+ // Trigger immediate orchestrator re-query and wait for runner setup
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Submit Rot13 action via session
CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver);
@@ -1153,15 +1191,15 @@ TEST_CASE("function.remote.late_runner_discovery")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
- Session.SetOrchestratorBasePath(SessionBaseDir.Path());
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
// Wait for W1 discovery
- Sleep(7'000);
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Baseline: submit Rot13 action and verify it completes on W1
{
@@ -1202,7 +1240,8 @@ TEST_CASE("function.remote.late_runner_discovery")
MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}}));
// Wait for W2 discovery
- Sleep(7'000);
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Verify W2 received the worker by querying its /compute/workers endpoint directly
{
@@ -1274,16 +1313,16 @@ TEST_CASE("function.remote.queue_association")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
- Session.SetOrchestratorBasePath(SessionBaseDir.Path());
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
// Register worker on session
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
// Wait for scheduler to discover the runner
- Sleep(7'000);
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Create a local queue and submit action to it
auto QueueResult = Session.CreateQueue();
@@ -1353,16 +1392,16 @@ TEST_CASE("function.remote.queue_cancel_propagation")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
- Session.SetOrchestratorBasePath(SessionBaseDir.Path());
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
// Register worker on session
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
// Wait for scheduler to discover the runner
- Sleep(7'000);
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Create a local queue and submit a long-running Sleep action
auto QueueResult = Session.CreateQueue();
@@ -1496,7 +1535,7 @@ TEST_CASE("function.session.abandon_pending")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.Ready();
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
@@ -1515,19 +1554,29 @@ TEST_CASE("function.session.abandon_pending")
REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3");
// Transition to Abandoned — should mark all pending actions as Abandoned
- bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned);
+ bool Transitioned = Session.Abandon();
CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned);
CHECK(!Session.IsHealthy());
- // Give the scheduler thread time to process the state changes
- Sleep(2'000);
-
- // All three actions should now be in the results map as abandoned
+ // Poll until the scheduler thread has processed all abandoned actions into
+ // the results map. The abandon is asynchronous: AbandonAllActions() sets
+ // action state and posts updates, but HandleActionUpdates() on the
+ // scheduler thread must run before results are queryable.
+ Stopwatch Timer;
for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn})
{
CbPackage Result;
- HttpResponseCode Code = Session.GetActionResult(Lsn, Result);
+ HttpResponseCode Code = HttpResponseCode::Accepted;
+ while (Timer.GetElapsedTimeMs() < 10'000)
+ {
+ Code = Session.GetActionResult(Lsn, Result);
+ if (Code == HttpResponseCode::OK)
+ {
+ break;
+ }
+ Sleep(100);
+ }
CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code)));
}
@@ -1561,15 +1610,15 @@ TEST_CASE("function.session.abandon_running")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
- Session.SetOrchestratorBasePath(SessionBaseDir.Path());
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
// Wait for scheduler to discover the runner
- Sleep(7'000);
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Create a queue and submit a long-running Sleep action
auto QueueResult = Session.CreateQueue();
@@ -1585,7 +1634,7 @@ TEST_CASE("function.session.abandon_running")
Sleep(2'000);
// Transition to Abandoned — should abandon the running action
- bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned);
+ bool Transitioned = Session.Abandon();
CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
CHECK(!Session.IsHealthy());
@@ -1631,16 +1680,16 @@ TEST_CASE("function.remote.abandon_propagation")
InMemoryChunkResolver Resolver;
ScopedTemporaryDirectory SessionBaseDir;
zen::compute::ComputeServiceSession Session(Resolver);
- Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint());
- Session.SetOrchestratorBasePath(SessionBaseDir.Path());
- Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
// Register worker on session
CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
Session.RegisterWorker(WorkerPackage);
// Wait for scheduler to discover the runner
- Sleep(7'000);
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
// Create a local queue and submit a long-running Sleep action
auto QueueResult = Session.CreateQueue();
@@ -1656,7 +1705,7 @@ TEST_CASE("function.remote.abandon_propagation")
Sleep(2'000);
// Transition to Abandoned — should abandon the running action and propagate
- bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned);
+ bool Transitioned = Session.Abandon();
CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned");
// Poll for the action to complete
@@ -1693,6 +1742,278 @@ TEST_CASE("function.remote.abandon_propagation")
Session.Shutdown();
}
+// Integration test: spawns a compute-mode server, points a ComputeServiceSession at it
+// through a mock orchestrator, starts a long-running action on a session queue, and then
+// checks (via the server's /compute/queues listing) that Session.Shutdown() leaves the
+// remote's non-implicit queue in the "cancelled" state.
+TEST_CASE("function.remote.shutdown_cancels_queues")
+{
+ // Verify that Session.Shutdown() cancels remote queues on the compute node.
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput());
+
+ // The mock orchestrator advertises the spawned server as its only worker ("worker-1")
+ MockOrchestratorFixture MockOrch;
+ MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}}));
+
+ InMemoryChunkResolver Resolver;
+ ScopedTemporaryDirectory SessionBaseDir;
+ zen::compute::ComputeServiceSession Session(Resolver);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
+
+ CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+ Session.RegisterWorker(WorkerPackage);
+
+ // Nudge the session to re-read the orchestrator state, then wait a fixed interval
+ // (presumably long enough for the remote runner to be discovered — timing based)
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
+
+ // Create a queue and submit a long-running action so the remote queue is established
+ auto QueueResult = Session.CreateQueue();
+ REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue");
+ const int QueueId = QueueResult.QueueId;
+
+ // NOTE(review): 30'000 is presumably a duration in milliseconds — confirm against BuildSleepActionForSession
+ CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver);
+
+ auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0);
+ REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed");
+
+ // Wait for the action to start running on the remote
+ Sleep(2'000);
+
+ // Verify the remote has a non-implicit queue before shutdown
+ HttpClient RemoteClient(Instance.GetBaseUri() + "/compute");
+ {
+ HttpClient::Response QueuesResp = RemoteClient.Get("/queues"sv);
+ REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server before shutdown");
+
+ // Any entry whose "implicit" flag is false counts as the remote queue
+ bool RemoteQueueFound = false;
+ for (auto& Item : QueuesResp.AsObject()["queues"sv])
+ {
+ if (!Item.AsObjectView()["implicit"sv].AsBool())
+ {
+ RemoteQueueFound = true;
+ break;
+ }
+ }
+ REQUIRE_MESSAGE(RemoteQueueFound, "Expected remote queue to exist before shutdown");
+ }
+
+ // Shut down the session — this should cancel all remote queues
+ Session.Shutdown();
+
+ // Verify the remote queue is now cancelled
+ {
+ HttpClient::Response QueuesResp = RemoteClient.Get("/queues"sv);
+ REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server after shutdown");
+
+ // Only the first non-implicit queue in the listing is inspected; if none is
+ // listed the flag stays false and the CHECK below fails
+ bool RemoteQueueCancelled = false;
+ for (auto& Item : QueuesResp.AsObject()["queues"sv])
+ {
+ if (!Item.AsObjectView()["implicit"sv].AsBool())
+ {
+ RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled";
+ break;
+ }
+ }
+ CHECK_MESSAGE(RemoteQueueCancelled, "Expected remote queue to be cancelled after session shutdown");
+ }
+}
+
+// Integration test: runs one Rot13 action to completion through a session connected to a
+// spawned compute server (baseline), shuts the session down, and then checks that a further
+// Session.EnqueueAction() call is rejected.
+TEST_CASE("function.remote.shutdown_rejects_new_work")
+{
+ // Verify that after Shutdown() the remote runner rejects new submissions.
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput());
+
+ // The mock orchestrator advertises the spawned server as its only worker ("worker-1")
+ MockOrchestratorFixture MockOrch;
+ MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}}));
+
+ InMemoryChunkResolver Resolver;
+ ScopedTemporaryDirectory SessionBaseDir;
+ zen::compute::ComputeServiceSession Session(Resolver);
+ Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path());
+ Session.Ready();
+
+ CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+ Session.RegisterWorker(WorkerPackage);
+
+ // Wait for runner discovery
+ Session.NotifyOrchestratorChanged();
+ Sleep(2'000);
+
+ // Baseline: submit an action and verify it completes
+ {
+ CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver);
+
+ auto EnqueueRes = Session.EnqueueAction(ActionObj, 0);
+ REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed");
+
+ CbPackage ResultPackage;
+ HttpResponseCode ResultCode = HttpResponseCode::Accepted;
+ Stopwatch Timer;
+
+ // Poll every 200 ms, for at most 30 s, until GetActionResult reports OK
+ while (Timer.GetElapsedTimeMs() < 30'000)
+ {
+ ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage);
+ if (ResultCode == HttpResponseCode::OK)
+ {
+ break;
+ }
+ Sleep(200);
+ }
+
+ REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK,
+ fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance.GetLogOutput()));
+ // "Uryyb Jbeyq" is "Hello World" with ROT13 applied
+ CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+ }
+
+ // Shut down — the remote runner should now reject new work
+ Session.Shutdown();
+
+ // Attempting to enqueue after shutdown should fail (session is in Sunset state)
+ // Note: the rejection is observed through the session's own EnqueueAction result
+ CbObject ActionObj = BuildRot13ActionForSession("rejected"sv, Resolver);
+ auto Rejected = Session.EnqueueAction(ActionObj, 0);
+ CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected after shutdown");
+}
+
+// Session-level test: enqueues an action on a session that has no runners configured,
+// retracts it while it is still pending, and checks that the retract succeeds with a
+// retry count of 0, that the queue still reports one active action, and that no result
+// is available for the action.
+TEST_CASE("function.session.retract_pending")
+{
+ // Create a session with no runners so actions stay pending
+ InMemoryChunkResolver Resolver;
+ // NOTE(review): SessionBaseDir is not referenced anywhere else in this test
+ ScopedTemporaryDirectory SessionBaseDir;
+ zen::compute::ComputeServiceSession Session(Resolver);
+ Session.Ready();
+
+ CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+ Session.RegisterWorker(WorkerPackage);
+
+ auto QueueResult = Session.CreateQueue();
+ REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue");
+
+ CbObject ActionObj = BuildRot13ActionForSession("retract-test"sv, Resolver);
+
+ auto Enqueued = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0);
+ REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action");
+
+ // Let the scheduler process the pending action
+ Sleep(500);
+
+ // Retract the pending action
+ auto Result = Session.RetractAction(Enqueued.Lsn);
+ CHECK_MESSAGE(Result.Success, fmt::format("RetractAction failed: {}", Result.Error));
+ CHECK_EQ(Result.RetryCount, 0); // Retract should NOT increment retry count
+
+ // The action should be re-enqueued as pending (still no runners, so stays pending).
+ // Let the scheduler process the retracted action back to pending.
+ Sleep(500);
+
+ // Queue should still show 1 active (the action was rescheduled, not completed)
+ auto Status = Session.GetQueueStatus(QueueResult.QueueId);
+ CHECK_EQ(Status.ActiveCount, 1);
+ CHECK_EQ(Status.CompletedCount, 0);
+ CHECK_EQ(Status.AbandonedCount, 0);
+ CHECK_EQ(Status.CancelledCount, 0);
+
+ // The action result should NOT be in the results map (it's pending again)
+ // Any code other than OK is accepted here
+ CbPackage ResultPackage;
+ HttpResponseCode Code = Session.GetActionResult(Enqueued.Lsn, ResultPackage);
+ CHECK(Code != HttpResponseCode::OK);
+
+ Session.Shutdown();
+}
+
+// Session-level test: runs a Rot13 action to completion on a session with a local runner
+// and then checks that RetractAction() on the finished action reports failure.
+// Despite the "not_terminal" suffix in the name, what is asserted is that an action which
+// has already completed cannot be retracted.
+TEST_CASE("function.session.retract_not_terminal")
+{
+ // Verify that a completed action cannot be retracted
+ InMemoryChunkResolver Resolver;
+ ScopedTemporaryDirectory SessionBaseDir;
+ zen::compute::ComputeServiceSession Session(Resolver);
+ Session.SetOrchestratorBasePath(SessionBaseDir.Path());
+ Session.AddLocalRunner(Resolver, SessionBaseDir.Path());
+ Session.Ready();
+
+ CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver);
+ Session.RegisterWorker(WorkerPackage);
+
+ CbObject ActionObj = BuildRot13ActionForSession("retract-completed"sv, Resolver);
+
+ auto Enqueued = Session.EnqueueAction(ActionObj, 0);
+ REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action");
+
+ // Wait for the action to complete
+ CbPackage ResultPackage;
+ HttpResponseCode Code = HttpResponseCode::Accepted;
+ Stopwatch Timer;
+
+ // Poll every 200 ms, for at most 30 s, until GetActionResult reports OK
+ while (Timer.GetElapsedTimeMs() < 30'000)
+ {
+ Code = Session.GetActionResult(Enqueued.Lsn, ResultPackage);
+ if (Code == HttpResponseCode::OK)
+ {
+ break;
+ }
+ Sleep(200);
+ }
+
+ REQUIRE_MESSAGE(Code == HttpResponseCode::OK, "Action did not complete within timeout");
+
+ // Retract should fail — action already completed (no longer in pending/running maps)
+ auto RetractResult = Session.RetractAction(Enqueued.Lsn);
+ CHECK(!RetractResult.Success);
+
+ Session.Shutdown();
+}
+
+// HTTP-level test for POST /jobs/{lsn}/retract: starts a compute server with
+// --max-actions=1, submits a long Sleep job followed by a Rot13 job, and retracts the
+// second job twice, expecting HTTP 200 OK both times.
+TEST_CASE("function.retract_http")
+{
+ // Use max-actions=1 with a long sleep to hold the slot, then submit a second
+ // action that will stay pending and can be retracted via the HTTP endpoint.
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+ Instance.SetDataDir(TestEnv.CreateNewTestDir());
+ const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1");
+ REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+ const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+ HttpClient Client(ComputeBaseUri);
+
+ const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+ // Submit a long-running Sleep action to occupy the single execution slot
+ const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
+ HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000));
+ REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode)));
+
+ // Submit a second action — it will stay pending because the slot is occupied
+ // (BlockerUrl is reused: both jobs are posted to the same /jobs/{worker-id} URL)
+ HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv));
+ REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode)));
+
+ // The LSN of the second (pending) job is the one that gets retracted below
+ const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+ REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
+
+ // Wait for the scheduler to process the pending action into m_PendingActions
+ Sleep(1'000);
+
+ // Retract the pending action via POST /jobs/{lsn}/retract
+ const std::string RetractUrl = fmt::format("/jobs/{}/retract", Lsn);
+ HttpClient::Response RetractResp = Client.Post(RetractUrl);
+ CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK,
+ fmt::format("Retract failed: status={}, body={}\nServer log:\n{}",
+ int(RetractResp.StatusCode),
+ RetractResp.ToText(),
+ Instance.GetLogOutput()));
+
+ // The response body is only inspected when the retract itself returned OK
+ if (RetractResp.StatusCode == HttpResponseCode::OK)
+ {
+ CbObject RespObj = RetractResp.AsObject();
+ CHECK(RespObj["success"sv].AsBool());
+ CHECK_EQ(RespObj["lsn"sv].AsInt32(), Lsn);
+ }
+
+ // A second retract should also succeed (action is back to pending)
+ Sleep(500);
+ HttpClient::Response RetractResp2 = Client.Post(RetractUrl);
+ CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK,
+ fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText()));
+}
+
TEST_SUITE_END();
} // namespace zen::tests::compute
diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp
index e89812c1f..42632682b 100644
--- a/src/zenserver-test/zenserver-test.cpp
+++ b/src/zenserver-test/zenserver-test.cpp
@@ -15,6 +15,7 @@
# include <zencore/testutils.h>
# include <zencore/thread.h>
# include <zencore/timer.h>
+# include <zencore/trace.h>
# include <zenhttp/httpclient.h>
# include <zenhttp/packageformat.h>
# include <zenutil/config/commandlineoptions.h>
@@ -134,8 +135,7 @@ main(int argc, char** argv)
ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir);
zen::testing::TestRunner Runner;
- Runner.SetDefaultSuiteFilter("server.*");
- Runner.ApplyCommandLine(argc, argv);
+ Runner.ApplyCommandLine(argc, argv, "server.*");
return Runner.Run();
}
diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp
index 724ef9ad2..d1875f41a 100644
--- a/src/zenserver/compute/computeserver.cpp
+++ b/src/zenserver/compute/computeserver.cpp
@@ -721,21 +721,7 @@ ZenComputeServer::InitializeOrchestratorWebSocket()
return;
}
- // Convert http://host:port → ws://host:port/orch/ws
- std::string WsUrl = m_CoordinatorEndpoint;
- if (WsUrl.starts_with("http://"))
- {
- WsUrl = "ws://" + WsUrl.substr(7);
- }
- else if (WsUrl.starts_with("https://"))
- {
- WsUrl = "wss://" + WsUrl.substr(8);
- }
- if (!WsUrl.empty() && WsUrl.back() != '/')
- {
- WsUrl += '/';
- }
- WsUrl += "orch/ws";
+ std::string WsUrl = HttpToWsUrl(m_CoordinatorEndpoint, "/orch/ws");
ZEN_INFO("establishing WebSocket link to orchestrator at {}", WsUrl);
diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html
index 66c20175f..c07bbb692 100644
--- a/src/zenserver/frontend/html/compute/compute.html
+++ b/src/zenserver/frontend/html/compute/compute.html
@@ -6,6 +6,7 @@
<title>Zen Compute Dashboard</title>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script>
<link rel="stylesheet" type="text/css" href="../zen.css" />
+ <script src="../util/sanitize.js"></script>
<script src="../theme.js"></script>
<script src="../banner.js" defer></script>
<script src="../nav.js" defer></script>
@@ -456,11 +457,6 @@
});
// Helper functions
- function escapeHtml(text) {
- var div = document.createElement('div');
- div.textContent = text;
- return div.innerHTML;
- }
function formatBytes(bytes) {
if (bytes === 0) return '0 B';
diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html
index 32e1b05db..620349a2b 100644
--- a/src/zenserver/frontend/html/compute/hub.html
+++ b/src/zenserver/frontend/html/compute/hub.html
@@ -4,6 +4,7 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" type="text/css" href="../zen.css" />
+ <script src="../util/sanitize.js"></script>
<script src="../theme.js"></script>
<script src="../banner.js" defer></script>
<script src="../nav.js" defer></script>
@@ -62,12 +63,6 @@
const BASE_URL = window.location.origin;
const REFRESH_INTERVAL = 2000;
- function escapeHtml(text) {
- var div = document.createElement('div');
- div.textContent = text;
- return div.innerHTML;
- }
-
function showError(message) {
document.getElementById('error-container').innerHTML =
'<div class="error">Error: ' + escapeHtml(message) + '</div>';
diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html
index a519dee18..d1a2bb015 100644
--- a/src/zenserver/frontend/html/compute/orchestrator.html
+++ b/src/zenserver/frontend/html/compute/orchestrator.html
@@ -4,6 +4,7 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" type="text/css" href="../zen.css" />
+ <script src="../util/sanitize.js"></script>
<script src="../theme.js"></script>
<script src="../banner.js" defer></script>
<script src="../nav.js" defer></script>
@@ -128,12 +129,6 @@
const BASE_URL = window.location.origin;
const REFRESH_INTERVAL = 2000;
- function escapeHtml(text) {
- var div = document.createElement('div');
- div.textContent = text;
- return div.innerHTML;
- }
-
function showError(message) {
document.getElementById('error-container').innerHTML =
'<div class="error">Error: ' + escapeHtml(message) + '</div>';
diff --git a/src/zenserver/frontend/html/util/sanitize.js b/src/zenserver/frontend/html/util/sanitize.js
new file mode 100644
index 000000000..1b0f32e38
--- /dev/null
+++ b/src/zenserver/frontend/html/util/sanitize.js
@@ -0,0 +1,9 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+// Shared utility functions for compute dashboard pages.
+
+function escapeHtml(text) {
+ // Store the value as the plain-text content of a detached <div>, then read back the
+ // element's HTML serialization of that text.
+ // NOTE(review): escaping is whatever the browser's serializer applies to text nodes;
+ // confirm quotes are handled before using the result inside attribute values.
+ const scratch = document.createElement('div');
+ scratch.textContent = text;
+ return scratch.innerHTML;
+}
diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp
index 26ae85ae1..9d786c209 100644
--- a/src/zenserver/main.cpp
+++ b/src/zenserver/main.cpp
@@ -319,6 +319,11 @@ main(int argc, char* argv[])
{
ServerMode = kTest;
}
+ else if (argv[1][0] != '-')
+ {
+ fprintf(stderr, "unknown mode '%s'. Available modes: hub, store, compute, proxy, test\n", argv[1]);
+ return 1;
+ }
}
switch (ServerMode)
diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua
index fe279ebb2..b619c5548 100644
--- a/src/zenserver/xmake.lua
+++ b/src/zenserver/xmake.lua
@@ -19,7 +19,7 @@ target("zenserver")
add_headerfiles("**.h")
add_rules("utils.bin2c", {extensions = {".zip"}})
add_files("**.cpp")
- add_files("frontend/html.zip")
+ add_files("$(buildir)/frontend/html.zip")
add_files("zenserver.cpp", {unity_ignored = true })
if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then
@@ -84,7 +84,8 @@ target("zenserver")
on_load(function(target)
local html_dir = path.join(os.projectdir(), "src/zenserver/frontend/html")
- local zip_path = path.join(os.projectdir(), "src/zenserver/frontend/html.zip")
+ local zip_dir = path.join(os.projectdir(), get_config("buildir") or "build", "frontend")
+ local zip_path = path.join(zip_dir, "html.zip")
-- Check if zip needs regeneration
local need_update = not os.isfile(zip_path)
@@ -100,18 +101,19 @@ target("zenserver")
if need_update then
print("Regenerating frontend zip...")
+ os.mkdir(zip_dir)
os.tryrm(zip_path)
import("detect.tools.find_7z")
local cmd_7z = find_7z()
if cmd_7z then
- os.execv(cmd_7z, {"a", "-mx0", zip_path, path.join(html_dir, ".")})
+ os.execv(cmd_7z, {"a", "-mx0", "-bso0", zip_path, path.join(html_dir, ".")})
else
import("detect.tools.find_zip")
local zip_cmd = find_zip()
if zip_cmd then
local oldir = os.cd(html_dir)
- os.execv(zip_cmd, {"-r", "-0", zip_path, "."})
+ os.execv(zip_cmd, {"-r", "-0", "-q", zip_path, "."})
os.cd(oldir)
else
raise("Unable to find a suitable zip tool (need 7z or zip)")
diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp
index 4410d463d..124132aed 100644
--- a/src/zenutil/consoletui.cpp
+++ b/src/zenutil/consoletui.cpp
@@ -480,4 +480,69 @@ TuiPollQuit()
#endif
}
+// Writes the set-scroll-region sequence ESC[<Top>;<Bottom>r to stdout.
+// Output is not flushed here.
+void
+TuiSetScrollRegion(uint32_t Top, uint32_t Bottom)
+{
+ fprintf(stdout, "\033[%u;%ur", Top, Bottom);
+}
+
+// Writes ESC[r (scroll-region sequence with no parameters) to stdout.
+// Output is not flushed here.
+void
+TuiResetScrollRegion()
+{
+ fputs("\033[r", stdout);
+}
+
+// Writes the cursor-position sequence ESC[<Row>;<Col>H to stdout.
+// Output is not flushed here.
+void
+TuiMoveCursor(uint32_t Row, uint32_t Col)
+{
+ fprintf(stdout, "\033[%u;%uH", Row, Col);
+}
+
+// Writes the two bytes ESC '7' (save cursor) to stdout. Output is not flushed here.
+void
+TuiSaveCursor()
+{
+ // Kept as two adjacent literals so the trailing digit reads as its own byte
+ // rather than as part of the escape.
+ fputs("\033" "7", stdout);
+}
+
+// Writes the two bytes ESC '8' (restore cursor) to stdout. Output is not flushed here.
+void
+TuiRestoreCursor()
+{
+ // Kept as two adjacent literals so the trailing digit reads as its own byte
+ // rather than as part of the escape.
+ fputs("\033" "8", stdout);
+}
+
+// Writes the erase-line sequence ESC[2K to stdout. Output is not flushed here.
+void
+TuiEraseLine()
+{
+ fputs("\033[2K", stdout);
+}
+
+// Writes the bytes of Text to stdout exactly as given; nothing is appended and the
+// stream is not flushed.
+void
+TuiWrite(std::string_view Text)
+{
+ const char* const Bytes = Text.data();
+ const size_t ByteCount = Text.size();
+ fwrite(Bytes, 1, ByteCount, stdout);
+}
+
+// Flushes any buffered stdout output; call after a batch of Tui* writes.
+void
+TuiFlush()
+{
+ (void)fflush(stdout);
+}
+
+// Writes ESC[?25h when Show is true, or ESC[?25l when it is false, to stdout.
+// Output is not flushed here.
+void
+TuiShowCursor(bool Show)
+{
+ const char* const Sequence = Show ? "\033[?25h" : "\033[?25l";
+ fputs(Sequence, stdout);
+}
+
} // namespace zen
diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h
index 5f74fa82b..22737589b 100644
--- a/src/zenutil/include/zenutil/consoletui.h
+++ b/src/zenutil/include/zenutil/consoletui.h
@@ -57,4 +57,32 @@ uint32_t TuiConsoleRows(uint32_t Default = 40);
// Should only be called while in alternate screen mode.
bool TuiPollQuit();
+// Set the scrollable region of the terminal to rows [Top, Bottom] (1-based, inclusive).
+// Emits DECSTBM: ESC[<top>;<bottom>r
+void TuiSetScrollRegion(uint32_t Top, uint32_t Bottom);
+
+// Reset the scroll region to the full terminal. Emits ESC[r
+void TuiResetScrollRegion();
+
+// Move cursor to the given 1-based row and column. Emits ESC[<row>;<col>H
+void TuiMoveCursor(uint32_t Row, uint32_t Col);
+
+// Save cursor position. Emits ESC 7
+void TuiSaveCursor();
+
+// Restore cursor position. Emits ESC 8
+void TuiRestoreCursor();
+
+// Erase the entire current line. Emits ESC[2K
+void TuiEraseLine();
+
+// Write raw text to stdout without any formatting or newline.
+void TuiWrite(std::string_view Text);
+
+// Flush stdout.
+void TuiFlush();
+
+// Show or hide the cursor. Emits ESC[?25h or ESC[?25l
+void TuiShowCursor(bool Show);
+
} // namespace zen
diff --git a/src/zenutil/include/zenutil/logging/fullformatter.h b/src/zenutil/include/zenutil/logging/fullformatter.h
index 33cb94dae..0d026ed72 100644
--- a/src/zenutil/include/zenutil/logging/fullformatter.h
+++ b/src/zenutil/include/zenutil/logging/fullformatter.h
@@ -3,10 +3,8 @@
#pragma once
#include <zencore/logging/formatter.h>
-#include <zencore/logging/helpers.h>
-#include <zencore/memory/llm.h>
-#include <zencore/zencore.h>
+#include <memory>
#include <string_view>
namespace zen::logging {
@@ -14,195 +12,16 @@ namespace zen::logging {
class FullFormatter final : public Formatter
{
public:
- FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch)
- : m_Epoch(Epoch)
- , m_LogId(LogId)
- , m_LinePrefix(128, ' ')
- , m_UseFullDate(false)
- {
- }
+ FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch);
+ explicit FullFormatter(std::string_view LogId);
+ ~FullFormatter() override;
- FullFormatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {}
-
- virtual std::unique_ptr<Formatter> Clone() const override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
- if (m_UseFullDate)
- {
- return std::make_unique<FullFormatter>(m_LogId);
- }
- return std::make_unique<FullFormatter>(m_LogId, m_Epoch);
- }
-
- virtual void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- // Note that the sink is responsible for ensuring there is only ever a
- // single caller in here
-
- using namespace std::literals;
-
- std::chrono::seconds TimestampSeconds;
-
- std::chrono::milliseconds Millis;
-
- if (m_UseFullDate)
- {
- TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch());
- if (TimestampSeconds != m_LastLogSecs)
- {
- RwLock::ExclusiveLockScope _(m_TimestampLock);
- m_LastLogSecs = TimestampSeconds;
-
- m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
- m_CachedDatetime.clear();
- m_CachedDatetime.push_back('[');
- helpers::Pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime);
- m_CachedDatetime.push_back('-');
- helpers::Pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime);
- m_CachedDatetime.push_back('-');
- helpers::Pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime);
- m_CachedDatetime.push_back(' ');
- helpers::Pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
- helpers::Pad2(m_CachedLocalTm.tm_min, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
- helpers::Pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime);
- m_CachedDatetime.push_back('.');
- }
-
- Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime());
- }
- else
- {
- auto ElapsedTime = Msg.GetTime() - m_Epoch;
- TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime);
-
- if (m_CacheTimestamp.load() != TimestampSeconds)
- {
- RwLock::ExclusiveLockScope _(m_TimestampLock);
-
- m_CacheTimestamp = TimestampSeconds;
-
- int Count = int(TimestampSeconds.count());
- const int LogSecs = Count % 60;
- Count /= 60;
- const int LogMins = Count % 60;
- Count /= 60;
- const int LogHours = Count;
-
- m_CachedDatetime.clear();
- m_CachedDatetime.push_back('[');
- helpers::Pad2(LogHours, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
- helpers::Pad2(LogMins, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
- helpers::Pad2(LogSecs, m_CachedDatetime);
- m_CachedDatetime.push_back('.');
- }
-
- Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds);
- }
-
- {
- RwLock::SharedLockScope _(m_TimestampLock);
- OutBuffer.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
- }
-
- helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer);
- OutBuffer.push_back(']');
- OutBuffer.push_back(' ');
-
- if (!m_LogId.empty())
- {
- OutBuffer.push_back('[');
- helpers::AppendStringView(m_LogId, OutBuffer);
- OutBuffer.push_back(']');
- OutBuffer.push_back(' ');
- }
-
- // append logger name if exists
- if (Msg.GetLoggerName().size() > 0)
- {
- OutBuffer.push_back('[');
- helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer);
- OutBuffer.push_back(']');
- OutBuffer.push_back(' ');
- }
-
- OutBuffer.push_back('[');
- // wrap the level name with color
- Msg.ColorRangeStart = OutBuffer.size();
- helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer);
- Msg.ColorRangeEnd = OutBuffer.size();
- OutBuffer.push_back(']');
- OutBuffer.push_back(' ');
-
- // add source location if present
- if (Msg.GetSource())
- {
- OutBuffer.push_back('[');
- const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename);
- helpers::AppendStringView(Filename, OutBuffer);
- OutBuffer.push_back(':');
- helpers::AppendInt(Msg.GetSource().Line, OutBuffer);
- OutBuffer.push_back(']');
- OutBuffer.push_back(' ');
- }
-
- // Handle newlines in single log call by prefixing each additional line to make
- // subsequent lines align with the first line in the message
-
- const size_t LinePrefixCount = Min<size_t>(OutBuffer.size(), m_LinePrefix.size());
-
- auto MsgPayload = Msg.GetPayload();
- auto ItLineBegin = MsgPayload.begin();
- auto ItMessageEnd = MsgPayload.end();
- bool IsFirstline = true;
-
- {
- auto ItLineEnd = ItLineBegin;
-
- auto EmitLine = [&] {
- if (IsFirstline)
- {
- IsFirstline = false;
- }
- else
- {
- helpers::AppendStringView(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer);
- }
- helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer);
- };
-
- while (ItLineEnd != ItMessageEnd)
- {
- if (*ItLineEnd++ == '\n')
- {
- EmitLine();
- ItLineBegin = ItLineEnd;
- }
- }
-
- if (ItLineBegin != ItMessageEnd)
- {
- EmitLine();
- helpers::AppendStringView("\n"sv, OutBuffer);
- }
- }
- }
+ std::unique_ptr<Formatter> Clone() const override;
+ void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override;
private:
- std::chrono::time_point<std::chrono::system_clock> m_Epoch;
- std::tm m_CachedLocalTm;
- std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)};
- std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)};
- MemoryBuffer m_CachedDatetime;
- std::string m_LogId;
- std::string m_LinePrefix;
- bool m_UseFullDate = true;
- RwLock m_TimestampLock;
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
};
} // namespace zen::logging
diff --git a/src/zenutil/include/zenutil/logging/jsonformatter.h b/src/zenutil/include/zenutil/logging/jsonformatter.h
index 216b1b5e5..fb3193f3e 100644
--- a/src/zenutil/include/zenutil/logging/jsonformatter.h
+++ b/src/zenutil/include/zenutil/logging/jsonformatter.h
@@ -3,158 +3,24 @@
#pragma once
#include <zencore/logging/formatter.h>
-#include <zencore/logging/helpers.h>
-#include <zencore/memory/llm.h>
-#include <zencore/zencore.h>
+#include <memory>
#include <string_view>
-#include <unordered_map>
namespace zen::logging {
-using namespace std::literals;
-
class JsonFormatter final : public Formatter
{
public:
- JsonFormatter(std::string_view LogId) : m_LogId(LogId) {}
-
- virtual std::unique_ptr<Formatter> Clone() const override { return std::make_unique<JsonFormatter>(m_LogId); }
-
- virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- using std::chrono::duration_cast;
- using std::chrono::milliseconds;
- using std::chrono::seconds;
-
- auto Secs = std::chrono::duration_cast<seconds>(Msg.GetTime().time_since_epoch());
- if (Secs != m_LastLogSecs)
- {
- RwLock::ExclusiveLockScope _(m_TimestampLock);
- m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
- m_LastLogSecs = Secs;
-
- // cache the date/time part for the next second.
- m_CachedDatetime.clear();
-
- helpers::AppendInt(m_CachedTm.tm_year + 1900, m_CachedDatetime);
- m_CachedDatetime.push_back('-');
-
- helpers::Pad2(m_CachedTm.tm_mon + 1, m_CachedDatetime);
- m_CachedDatetime.push_back('-');
-
- helpers::Pad2(m_CachedTm.tm_mday, m_CachedDatetime);
- m_CachedDatetime.push_back(' ');
-
- helpers::Pad2(m_CachedTm.tm_hour, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
-
- helpers::Pad2(m_CachedTm.tm_min, m_CachedDatetime);
- m_CachedDatetime.push_back(':');
-
- helpers::Pad2(m_CachedTm.tm_sec, m_CachedDatetime);
+ explicit JsonFormatter(std::string_view LogId);
+ ~JsonFormatter() override;
- m_CachedDatetime.push_back('.');
- }
- helpers::AppendStringView("{"sv, Dest);
- helpers::AppendStringView("\"time\": \""sv, Dest);
- {
- RwLock::SharedLockScope _(m_TimestampLock);
- Dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
- }
- auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime());
- helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest);
- helpers::AppendStringView("\", "sv, Dest);
-
- helpers::AppendStringView("\"status\": \""sv, Dest);
- helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest);
- helpers::AppendStringView("\", "sv, Dest);
-
- helpers::AppendStringView("\"source\": \""sv, Dest);
- helpers::AppendStringView("zenserver"sv, Dest);
- helpers::AppendStringView("\", "sv, Dest);
-
- helpers::AppendStringView("\"service\": \""sv, Dest);
- helpers::AppendStringView("zencache"sv, Dest);
- helpers::AppendStringView("\", "sv, Dest);
-
- if (!m_LogId.empty())
- {
- helpers::AppendStringView("\"id\": \""sv, Dest);
- helpers::AppendStringView(m_LogId, Dest);
- helpers::AppendStringView("\", "sv, Dest);
- }
-
- if (Msg.GetLoggerName().size() > 0)
- {
- helpers::AppendStringView("\"logger.name\": \""sv, Dest);
- helpers::AppendStringView(Msg.GetLoggerName(), Dest);
- helpers::AppendStringView("\", "sv, Dest);
- }
-
- if (Msg.GetThreadId() != 0)
- {
- helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest);
- helpers::PadUint(Msg.GetThreadId(), 0, Dest);
- helpers::AppendStringView("\", "sv, Dest);
- }
-
- if (Msg.GetSource())
- {
- helpers::AppendStringView("\"file\": \""sv, Dest);
- WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename));
- helpers::AppendStringView("\","sv, Dest);
-
- helpers::AppendStringView("\"line\": \""sv, Dest);
- helpers::AppendInt(Msg.GetSource().Line, Dest);
- helpers::AppendStringView("\","sv, Dest);
- }
-
- helpers::AppendStringView("\"message\": \""sv, Dest);
- WriteEscapedString(Dest, Msg.GetPayload());
- helpers::AppendStringView("\""sv, Dest);
-
- helpers::AppendStringView("}\n"sv, Dest);
- }
+ std::unique_ptr<Formatter> Clone() const override;
+ void Format(const LogMessage& Msg, MemoryBuffer& Dest) override;
private:
- static inline const std::unordered_map<char, std::string_view> s_SpecialCharacterMap{{'\b', "\\b"sv},
- {'\f', "\\f"sv},
- {'\n', "\\n"sv},
- {'\r', "\\r"sv},
- {'\t', "\\t"sv},
- {'"', "\\\""sv},
- {'\\', "\\\\"sv}};
-
- static void WriteEscapedString(MemoryBuffer& Dest, const std::string_view& Text)
- {
- const char* RangeStart = Text.data();
- const char* End = Text.data() + Text.size();
- for (const char* It = RangeStart; It != End; ++It)
- {
- if (auto SpecialIt = s_SpecialCharacterMap.find(*It); SpecialIt != s_SpecialCharacterMap.end())
- {
- if (RangeStart != It)
- {
- Dest.append(RangeStart, It);
- }
- helpers::AppendStringView(SpecialIt->second, Dest);
- RangeStart = It + 1;
- }
- }
- if (RangeStart != End)
- {
- Dest.append(RangeStart, End);
- }
- };
-
- std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0};
- std::chrono::seconds m_LastLogSecs{0};
- MemoryBuffer m_CachedDatetime;
- std::string m_LogId;
- RwLock m_TimestampLock;
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
};
} // namespace zen::logging
diff --git a/src/zenutil/include/zenutil/logging/rotatingfilesink.h b/src/zenutil/include/zenutil/logging/rotatingfilesink.h
index cebc5b110..e0ff7eca1 100644
--- a/src/zenutil/include/zenutil/logging/rotatingfilesink.h
+++ b/src/zenutil/include/zenutil/logging/rotatingfilesink.h
@@ -2,14 +2,11 @@
#pragma once
-#include <zencore/basicfile.h>
-#include <zencore/logging/formatter.h>
-#include <zencore/logging/messageonlyformatter.h>
#include <zencore/logging/sink.h>
-#include <zencore/memory/llm.h>
-#include <atomic>
+#include <cstddef>
#include <filesystem>
+#include <memory>
namespace zen::logging {
@@ -19,230 +16,21 @@ namespace zen::logging {
class RotatingFileSink : public Sink
{
public:
- RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false)
- : m_BaseFilename(BaseFilename)
- , m_MaxSize(MaxSize)
- , m_MaxFiles(MaxFiles)
- , m_Formatter(std::make_unique<MessageOnlyFormatter>())
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- std::error_code Ec;
- if (RotateOnOpen)
- {
- RwLock::ExclusiveLockScope RotateLock(m_Lock);
- Rotate(RotateLock, Ec);
- }
- else
- {
- m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, Ec);
- if (!Ec)
- {
- m_CurrentSize = m_CurrentFile.FileSize(Ec);
- }
- if (!Ec)
- {
- if (m_CurrentSize > m_MaxSize)
- {
- RwLock::ExclusiveLockScope RotateLock(m_Lock);
- Rotate(RotateLock, Ec);
- }
- }
- }
-
- if (Ec)
- {
- throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", m_BaseFilename.string()));
- }
- }
-
- virtual ~RotatingFileSink()
- {
- try
- {
- RwLock::ExclusiveLockScope RotateLock(m_Lock);
- m_CurrentFile.Close();
- }
- catch (const std::exception&)
- {
- }
- }
+ RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false);
+ ~RotatingFileSink() override;
RotatingFileSink(const RotatingFileSink&) = delete;
RotatingFileSink(RotatingFileSink&&) = delete;
-
RotatingFileSink& operator=(const RotatingFileSink&) = delete;
RotatingFileSink& operator=(RotatingFileSink&&) = delete;
- virtual void Log(const LogMessage& Msg) override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- try
- {
- MemoryBuffer Formatted;
- if (TrySinkIt(Msg, Formatted))
- {
- return;
- }
-
- // This intentionally has no limit on the number of retries, see
- // comment above.
- for (;;)
- {
- {
- RwLock::ExclusiveLockScope RotateLock(m_Lock);
- // Only rotate if no-one else has rotated before us
- if (m_CurrentSize > m_MaxSize || !m_CurrentFile.IsOpen())
- {
- std::error_code Ec;
- Rotate(RotateLock, Ec);
- if (Ec)
- {
- return;
- }
- }
- }
- if (TrySinkIt(Formatted))
- {
- return;
- }
- }
- }
- catch (const std::exception&)
- {
- // Silently eat errors
- }
- }
- virtual void Flush() override
- {
- if (!m_NeedFlush)
- {
- return;
- }
-
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- try
- {
- RwLock::SharedLockScope Lock(m_Lock);
- if (m_CurrentFile.IsOpen())
- {
- m_CurrentFile.Flush();
- }
- }
- catch (const std::exception&)
- {
- // Silently eat errors
- }
-
- m_NeedFlush = false;
- }
-
- virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) override
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- try
- {
- RwLock::ExclusiveLockScope _(m_Lock);
- m_Formatter = std::move(InFormatter);
- }
- catch (const std::exception&)
- {
- // Silently eat errors
- }
- }
+ void Log(const LogMessage& Msg) override;
+ void Flush() override;
+ void SetFormatter(std::unique_ptr<Formatter> InFormatter) override;
private:
- void Rotate(RwLock::ExclusiveLockScope&, std::error_code& OutEc)
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- m_CurrentFile.Close();
-
- OutEc = RotateFiles(m_BaseFilename, m_MaxFiles);
- if (OutEc)
- {
- return;
- }
-
- m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, OutEc);
- if (OutEc)
- {
- return;
- }
-
- m_CurrentSize = m_CurrentFile.FileSize(OutEc);
- if (OutEc)
- {
- // FileSize failed but we have an open file — reset to 0
- // so we can at least attempt writes from the start
- m_CurrentSize = 0;
- OutEc.clear();
- }
- }
-
- bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted)
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- RwLock::SharedLockScope Lock(m_Lock);
- if (!m_CurrentFile.IsOpen())
- {
- return false;
- }
- m_Formatter->Format(Msg, OutFormatted);
- size_t AddSize = OutFormatted.size();
- size_t WritePos = m_CurrentSize.fetch_add(AddSize);
- if (WritePos + AddSize > m_MaxSize)
- {
- return false;
- }
- std::error_code Ec;
- m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec);
- if (Ec)
- {
- return false;
- }
- m_NeedFlush = true;
- return true;
- }
-
- bool TrySinkIt(const MemoryBuffer& Formatted)
- {
- ZEN_MEMSCOPE(ELLMTag::Logging);
-
- RwLock::SharedLockScope Lock(m_Lock);
- if (!m_CurrentFile.IsOpen())
- {
- return false;
- }
- size_t AddSize = Formatted.size();
- size_t WritePos = m_CurrentSize.fetch_add(AddSize);
- if (WritePos + AddSize > m_MaxSize)
- {
- return false;
- }
-
- std::error_code Ec;
- m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec);
- if (Ec)
- {
- return false;
- }
- m_NeedFlush = true;
- return true;
- }
-
- RwLock m_Lock;
- const std::filesystem::path m_BaseFilename;
- const std::size_t m_MaxSize;
- const std::size_t m_MaxFiles;
- std::unique_ptr<Formatter> m_Formatter;
- std::atomic_size_t m_CurrentSize;
- BasicFile m_CurrentFile;
- std::atomic<bool> m_NeedFlush = false;
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
};
} // namespace zen::logging
diff --git a/src/zenutil/logging/fullformatter.cpp b/src/zenutil/logging/fullformatter.cpp
new file mode 100644
index 000000000..2a4840241
--- /dev/null
+++ b/src/zenutil/logging/fullformatter.cpp
@@ -0,0 +1,235 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/logging/fullformatter.h>
+
+#include <zencore/intmath.h>
+#include <zencore/logging/helpers.h>
+#include <zencore/logging/memorybuffer.h>
+#include <zencore/memory/llm.h>
+#include <zencore/thread.h>
+#include <zencore/zencore.h>
+
+#include <atomic>
+#include <chrono>
+#include <string>
+
+namespace zen::logging {
+
+// Private state of FullFormatter (the header only forward-declares Impl).
+//
+// A formatter works in one of two timestamp modes, fixed at construction:
+//  - full-date mode (m_UseFullDate == true): prefix is "[YY-MM-DD HH:MM:SS."
+//  - elapsed mode   (m_UseFullDate == false): prefix is "[HH:MM:SS." measured from m_Epoch
+// The text up to and including the '.' is cached in m_CachedDatetime and only
+// rebuilt when the whole-second part of a message timestamp changes.
+struct FullFormatter::Impl
+{
+ // Elapsed mode: timestamps are rendered relative to Epoch.
+ Impl(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch)
+ : m_Epoch(Epoch)
+ , m_LogId(LogId)
+ , m_LinePrefix(128, ' ')
+ , m_UseFullDate(false)
+ {
+ }
+
+ // Full-date mode: timestamps are rendered as local date and time; m_Epoch is left default-constructed.
+ explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {}
+
+ // Reference point for elapsed mode.
+ std::chrono::time_point<std::chrono::system_clock> m_Epoch;
+ // Broken-down local time of the second the full-date cache was last built for.
+ std::tm m_CachedLocalTm{};
+ // Second the full-date cache was last built for. The initial value is presumably an
+ // arbitrary sentinel that forces a rebuild on the first message - TODO confirm.
+ std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)};
+ // Same role as m_LastLogSecs, but for elapsed mode.
+ std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)};
+ // Cached timestamp prefix text; written and read under m_TimestampLock.
+ MemoryBuffer m_CachedDatetime;
+ // Optional identifier emitted as "[id] " after the timestamp when non-empty.
+ std::string m_LogId;
+ // 128 spaces; a slice of this is used to indent continuation lines of multi-line messages.
+ std::string m_LinePrefix;
+ bool m_UseFullDate = true;
+ // Guards m_CachedDatetime.
+ RwLock m_TimestampLock;
+};
+
+// Creates a formatter that prints timestamps as time elapsed since Epoch ("[HH:MM:SS.mmm]").
+FullFormatter::FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch)
+: m_Impl(std::make_unique<Impl>(LogId, Epoch))
+{
+}
+
+// Creates a formatter that prints timestamps as local date and time ("[YY-MM-DD HH:MM:SS.mmm]").
+FullFormatter::FullFormatter(std::string_view LogId) : m_Impl(std::make_unique<Impl>(LogId))
+{
+}
+
+// Defined here rather than in the header so that Impl is a complete type
+// when std::unique_ptr<Impl> is destroyed.
+FullFormatter::~FullFormatter() = default;
+
+// Produces an independent formatter with the same log id, timestamp mode and
+// colour setting. The cached timestamp text is not copied; the clone rebuilds
+// it on first use.
+std::unique_ptr<Formatter>
+FullFormatter::Clone() const
+{
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    const Impl& Source = *m_Impl;
+
+    // Go through the constructor that matches the source's timestamp mode:
+    // the single-argument form selects full-date, the two-argument form elapsed time.
+    std::unique_ptr<FullFormatter> Duplicate = Source.m_UseFullDate
+                                                   ? std::make_unique<FullFormatter>(Source.m_LogId)
+                                                   : std::make_unique<FullFormatter>(Source.m_LogId, Source.m_Epoch);
+
+    Duplicate->SetColorEnabled(IsColorEnabled());
+    return Duplicate;
+}
+
+// Renders one log message into OutBuffer as
+//
+//   [timestamp.mmm] [log id] [logger] [level] [file:line] text\n
+//
+// where the bracketed log id, logger name and source location are omitted when
+// absent. Multi-line payloads have every continuation line indented so that it
+// lines up with the start of the message text on the first line. The output is
+// always newline-terminated, including for an empty payload.
+//
+// The timestamp text up to the '.' is cached per second in Impl. Sinks such as
+// RotatingFileSink call Format while holding only a shared lock, so the cache
+// is checked, rebuilt and copied strictly under m_TimestampLock: a shared lock
+// on the common "same second" path, an exclusive lock (with a re-check) when a
+// rebuild is needed. Copying while the lock is still held guarantees the prefix
+// appended here belongs to this message's second.
+void
+FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer)
+{
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    using namespace std::literals;
+
+    Impl&      State       = *m_Impl;
+    const bool UseFullDate = State.m_UseFullDate;
+
+    std::chrono::seconds      TimestampSeconds;
+    std::chrono::milliseconds Millis;
+
+    if (UseFullDate)
+    {
+        TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch());
+        Millis           = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime());
+    }
+    else
+    {
+        const auto ElapsedTime = Msg.GetTime() - State.m_Epoch;
+        TimestampSeconds       = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime);
+        Millis                 = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds);
+    }
+
+    // Both helpers must only be invoked while m_TimestampLock is held.
+    auto IsCacheCurrent = [&] {
+        return UseFullDate ? (State.m_LastLogSecs == TimestampSeconds) : (State.m_CacheTimestamp.load() == TimestampSeconds);
+    };
+    auto AppendCachedPrefix = [&] { OutBuffer.append(State.m_CachedDatetime.begin(), State.m_CachedDatetime.end()); };
+
+    bool PrefixAppended = false;
+    {
+        // Fast path: the cache already holds this second.
+        RwLock::SharedLockScope _(State.m_TimestampLock);
+        if (IsCacheCurrent())
+        {
+            AppendCachedPrefix();
+            PrefixAppended = true;
+        }
+    }
+
+    if (!PrefixAppended)
+    {
+        RwLock::ExclusiveLockScope _(State.m_TimestampLock);
+
+        // Another caller may have rebuilt the cache for this second while we
+        // waited for the exclusive lock.
+        if (!IsCacheCurrent())
+        {
+            State.m_CachedDatetime.clear();
+            State.m_CachedDatetime.push_back('[');
+
+            if (UseFullDate)
+            {
+                State.m_LastLogSecs   = TimestampSeconds;
+                State.m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
+
+                helpers::Pad2(State.m_CachedLocalTm.tm_year % 100, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back('-');
+                helpers::Pad2(State.m_CachedLocalTm.tm_mon + 1, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back('-');
+                helpers::Pad2(State.m_CachedLocalTm.tm_mday, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back(' ');
+                helpers::Pad2(State.m_CachedLocalTm.tm_hour, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back(':');
+                helpers::Pad2(State.m_CachedLocalTm.tm_min, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back(':');
+                helpers::Pad2(State.m_CachedLocalTm.tm_sec, State.m_CachedDatetime);
+            }
+            else
+            {
+                State.m_CacheTimestamp = TimestampSeconds;
+
+                // Split the elapsed seconds into hours / minutes / seconds.
+                int       Count   = int(TimestampSeconds.count());
+                const int LogSecs = Count % 60;
+                Count /= 60;
+                const int LogMins = Count % 60;
+                Count /= 60;
+                const int LogHours = Count;
+
+                helpers::Pad2(LogHours, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back(':');
+                helpers::Pad2(LogMins, State.m_CachedDatetime);
+                State.m_CachedDatetime.push_back(':');
+                helpers::Pad2(LogSecs, State.m_CachedDatetime);
+            }
+
+            State.m_CachedDatetime.push_back('.');
+        }
+
+        AppendCachedPrefix();
+    }
+
+    helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer);
+    OutBuffer.push_back(']');
+    OutBuffer.push_back(' ');
+
+    if (!State.m_LogId.empty())
+    {
+        OutBuffer.push_back('[');
+        helpers::AppendStringView(State.m_LogId, OutBuffer);
+        OutBuffer.push_back(']');
+        OutBuffer.push_back(' ');
+    }
+
+    // append logger name if exists
+    if (Msg.GetLoggerName().size() > 0)
+    {
+        OutBuffer.push_back('[');
+        helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer);
+        OutBuffer.push_back(']');
+        OutBuffer.push_back(' ');
+    }
+
+    // Only the level tag itself is coloured; the brackets stay plain.
+    OutBuffer.push_back('[');
+    if (IsColorEnabled())
+    {
+        helpers::AppendAnsiColor(Msg.GetLevel(), OutBuffer);
+    }
+    helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer);
+    if (IsColorEnabled())
+    {
+        helpers::AppendAnsiReset(OutBuffer);
+    }
+    OutBuffer.push_back(']');
+    OutBuffer.push_back(' ');
+
+    // add source location if present
+    if (Msg.GetSource())
+    {
+        OutBuffer.push_back('[');
+        const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename);
+        helpers::AppendStringView(Filename, OutBuffer);
+        OutBuffer.push_back(':');
+        helpers::AppendInt(Msg.GetSource().Line, OutBuffer);
+        OutBuffer.push_back(']');
+        OutBuffer.push_back(' ');
+    }
+
+    // Handle newlines in single log call by prefixing each additional line to make
+    // subsequent lines align with the first line in the message. The colour escape
+    // bytes take up buffer space but no screen columns, so they are excluded from
+    // the indent width.
+
+    size_t       AnsiBytes = IsColorEnabled() ? (helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0;
+    const size_t LinePrefixCount = Min<size_t>(OutBuffer.size() - AnsiBytes, State.m_LinePrefix.size());
+
+    auto MsgPayload   = Msg.GetPayload();
+    auto ItLineBegin  = MsgPayload.begin();
+    auto ItMessageEnd = MsgPayload.end();
+    bool IsFirstline  = true;
+
+    {
+        auto ItLineEnd = ItLineBegin;
+
+        // Emits [ItLineBegin, ItLineEnd), indenting every line after the first.
+        auto EmitLine = [&] {
+            if (IsFirstline)
+            {
+                IsFirstline = false;
+            }
+            else
+            {
+                helpers::AppendStringView(std::string_view(State.m_LinePrefix.data(), LinePrefixCount), OutBuffer);
+            }
+            helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer);
+        };
+
+        while (ItLineEnd != ItMessageEnd)
+        {
+            // The emitted slice includes the '\n' that terminated it.
+            if (*ItLineEnd++ == '\n')
+            {
+                EmitLine();
+                ItLineBegin = ItLineEnd;
+            }
+        }
+
+        if (ItLineBegin != ItMessageEnd)
+        {
+            // Trailing text without its own newline.
+            EmitLine();
+            helpers::AppendStringView("\n"sv, OutBuffer);
+        }
+        else if (ItLineBegin == MsgPayload.begin())
+        {
+            // Empty payload: nothing was emitted above, but the record still
+            // has to end its line or the next record would be glued onto it.
+            helpers::AppendStringView("\n"sv, OutBuffer);
+        }
+    }
+}
+
+} // namespace zen::logging
diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp
new file mode 100644
index 000000000..673a03c94
--- /dev/null
+++ b/src/zenutil/logging/jsonformatter.cpp
@@ -0,0 +1,198 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/logging/jsonformatter.h>
+
+#include <zencore/logging/helpers.h>
+#include <zencore/memory/llm.h>
+#include <zencore/thread.h>
+#include <zencore/zencore.h>
+
+#include <chrono>
+#include <string>
+#include <unordered_map>
+
+namespace zen::logging {
+
+using namespace std::literals;
+
+// Appends Text to Dest so that it is safe to place between the quotes of a
+// JSON string:
+//  - ANSI SGR colour sequences (ESC '[' parameters 'm') are removed entirely
+//  - '"', '\\' and the control characters with a short form (\b \f \n \r \t)
+//    are written as their two-character escapes
+//  - every other control character below 0x20 is written as \u00XX, as JSON
+//    does not permit raw control characters inside strings
+// Everything else is copied through unchanged, in runs.
+//
+// Only well-formed SGR sequences are stripped. An ESC that does not start one
+// (a different CSI command, or a sequence cut off by the end of the text) is
+// escaped as \u001b and the characters after it are kept, so ordinary message
+// text can never be swallowed while searching for a terminating 'm'.
+static void
+WriteEscapedString(MemoryBuffer& Dest, std::string_view Text)
+{
+    const char* const End        = Text.data() + Text.size();
+    const char*       RangeStart = Text.data();
+
+    // Copies the pending run of ordinary characters [RangeStart, UpTo).
+    auto FlushPending = [&](const char* UpTo) {
+        if (RangeStart != UpTo)
+        {
+            Dest.append(RangeStart, UpTo);
+        }
+    };
+
+    for (const char* It = Text.data(); It != End; ++It)
+    {
+        const char C = *It;
+
+        if (C == '\033' && (It + 1) != End && It[1] == '[')
+        {
+            // SGR parameters are digits separated by ';' (or ':' for sub-parameters).
+            const char* Seq = It + 2;
+            while (Seq != End && ((*Seq >= '0' && *Seq <= '9') || *Seq == ';' || *Seq == ':'))
+            {
+                ++Seq;
+            }
+            if (Seq != End && *Seq == 'm')
+            {
+                FlushPending(It);
+                It         = Seq;  // the loop increment steps past the 'm'
+                RangeStart = Seq + 1;
+                continue;
+            }
+            // Not an SGR sequence: fall through so the ESC byte gets escaped below.
+        }
+
+        std::string_view Replacement;
+        switch (C)
+        {
+            case '\b':
+                Replacement = "\\b"sv;
+                break;
+            case '\f':
+                Replacement = "\\f"sv;
+                break;
+            case '\n':
+                Replacement = "\\n"sv;
+                break;
+            case '\r':
+                Replacement = "\\r"sv;
+                break;
+            case '\t':
+                Replacement = "\\t"sv;
+                break;
+            case '"':
+                Replacement = "\\\""sv;
+                break;
+            case '\\':
+                Replacement = "\\\\"sv;
+                break;
+            default:
+                break;
+        }
+
+        if (!Replacement.empty())
+        {
+            FlushPending(It);
+            helpers::AppendStringView(Replacement, Dest);
+            RangeStart = It + 1;
+        }
+        else if (static_cast<unsigned char>(C) < 0x20)
+        {
+            static constexpr char HexDigits[] = "0123456789abcdef";
+            const unsigned char   Byte        = static_cast<unsigned char>(C);
+            const char            Encoded[6]  = {'\\', 'u', '0', '0', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
+
+            FlushPending(It);
+            Dest.append(Encoded, Encoded + 6);
+            RangeStart = It + 1;
+        }
+    }
+
+    FlushPending(End);
+}
+
+// Private state of JsonFormatter (the header only forward-declares Impl).
+struct JsonFormatter::Impl
+{
+ explicit Impl(std::string_view LogId) : m_LogId(LogId) {}
+
+ // Broken-down local time of the second the cached text was last built for.
+ std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0};
+ // Whole-second timestamp the cached text was last built for.
+ std::chrono::seconds m_LastLogSecs{0};
+ // Cached "YYYY-MM-DD HH:MM:SS." text, rebuilt when the second changes; accessed under m_TimestampLock.
+ MemoryBuffer m_CachedDatetime;
+ // Emitted as the "id" field of every record when non-empty.
+ std::string m_LogId;
+ // Guards m_CachedDatetime.
+ RwLock m_TimestampLock;
+};
+
+// Creates a formatter that writes one JSON object per log message; LogId, when
+// non-empty, is emitted as the "id" field of every record.
+JsonFormatter::JsonFormatter(std::string_view LogId) : m_Impl(std::make_unique<Impl>(LogId))
+{
+}
+
+// Defined here rather than in the header so that Impl is a complete type
+// when std::unique_ptr<Impl> is destroyed.
+JsonFormatter::~JsonFormatter() = default;
+
+// Produces an independent formatter with the same log id and colour setting.
+// The cached timestamp text is not copied; the clone rebuilds it on first use.
+std::unique_ptr<Formatter>
+JsonFormatter::Clone() const
+{
+    // Attribute the allocation to the logging subsystem, as FullFormatter::Clone does.
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    auto Copy = std::make_unique<JsonFormatter>(m_Impl->m_LogId);
+
+    // Format() below never emits colour, but a clone should still carry the
+    // complete formatter state so the flag is not silently lost.
+    Copy->SetColorEnabled(IsColorEnabled());
+    return Copy;
+}
+
+// Renders one log message into Dest as a single-line JSON object followed by
+// '\n'. Fields, in order: "time", "status", "source", "service", then "id",
+// "logger.name", "logger.thread_name", "file" and "line" when available, and
+// finally "message".
+//
+// All free-form text (log id, logger name, file name, payload) goes through
+// WriteEscapedString so it cannot break out of its JSON string.
+//
+// The date/time text up to the '.' is cached per second. The cache is checked,
+// rebuilt and copied strictly under m_TimestampLock (shared lock on the common
+// path, exclusive lock with a re-check for a rebuild), so concurrent callers
+// neither race on the cache key nor pick up text built for a different second.
+void
+JsonFormatter::Format(const LogMessage& Msg, MemoryBuffer& Dest)
+{
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    using std::chrono::duration_cast;
+    using std::chrono::milliseconds;
+    using std::chrono::seconds;
+
+    Impl&      State = *m_Impl;
+    const auto Secs  = duration_cast<seconds>(Msg.GetTime().time_since_epoch());
+
+    helpers::AppendStringView("{"sv, Dest);
+    helpers::AppendStringView("\"time\": \""sv, Dest);
+
+    // Must only be invoked while m_TimestampLock is held.
+    auto AppendCachedDatetime = [&] { Dest.append(State.m_CachedDatetime.begin(), State.m_CachedDatetime.end()); };
+
+    bool DatetimeAppended = false;
+    {
+        // Fast path: the cache already holds this second.
+        RwLock::SharedLockScope _(State.m_TimestampLock);
+        if (Secs == State.m_LastLogSecs)
+        {
+            AppendCachedDatetime();
+            DatetimeAppended = true;
+        }
+    }
+
+    if (!DatetimeAppended)
+    {
+        RwLock::ExclusiveLockScope _(State.m_TimestampLock);
+
+        // Another caller may have rebuilt the cache for this second while we
+        // waited for the exclusive lock.
+        if (Secs != State.m_LastLogSecs)
+        {
+            State.m_CachedTm    = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime()));
+            State.m_LastLogSecs = Secs;
+
+            // cache the date/time part for the next second.
+            State.m_CachedDatetime.clear();
+
+            helpers::AppendInt(State.m_CachedTm.tm_year + 1900, State.m_CachedDatetime);
+            State.m_CachedDatetime.push_back('-');
+
+            helpers::Pad2(State.m_CachedTm.tm_mon + 1, State.m_CachedDatetime);
+            State.m_CachedDatetime.push_back('-');
+
+            helpers::Pad2(State.m_CachedTm.tm_mday, State.m_CachedDatetime);
+            State.m_CachedDatetime.push_back(' ');
+
+            helpers::Pad2(State.m_CachedTm.tm_hour, State.m_CachedDatetime);
+            State.m_CachedDatetime.push_back(':');
+
+            helpers::Pad2(State.m_CachedTm.tm_min, State.m_CachedDatetime);
+            State.m_CachedDatetime.push_back(':');
+
+            helpers::Pad2(State.m_CachedTm.tm_sec, State.m_CachedDatetime);
+
+            State.m_CachedDatetime.push_back('.');
+        }
+
+        AppendCachedDatetime();
+    }
+
+    auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime());
+    helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest);
+    helpers::AppendStringView("\", "sv, Dest);
+
+    helpers::AppendStringView("\"status\": \""sv, Dest);
+    helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest);
+    helpers::AppendStringView("\", "sv, Dest);
+
+    helpers::AppendStringView("\"source\": \""sv, Dest);
+    helpers::AppendStringView("zenserver"sv, Dest);
+    helpers::AppendStringView("\", "sv, Dest);
+
+    helpers::AppendStringView("\"service\": \""sv, Dest);
+    helpers::AppendStringView("zencache"sv, Dest);
+    helpers::AppendStringView("\", "sv, Dest);
+
+    if (!State.m_LogId.empty())
+    {
+        helpers::AppendStringView("\"id\": \""sv, Dest);
+        // The id is caller-supplied text; escape it like any other string value.
+        WriteEscapedString(Dest, State.m_LogId);
+        helpers::AppendStringView("\", "sv, Dest);
+    }
+
+    if (Msg.GetLoggerName().size() > 0)
+    {
+        helpers::AppendStringView("\"logger.name\": \""sv, Dest);
+        WriteEscapedString(Dest, Msg.GetLoggerName());
+        helpers::AppendStringView("\", "sv, Dest);
+    }
+
+    if (Msg.GetThreadId() != 0)
+    {
+        helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest);
+        helpers::PadUint(Msg.GetThreadId(), 0, Dest);
+        helpers::AppendStringView("\", "sv, Dest);
+    }
+
+    if (Msg.GetSource())
+    {
+        helpers::AppendStringView("\"file\": \""sv, Dest);
+        WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename));
+        helpers::AppendStringView("\","sv, Dest);
+
+        // The line number is emitted as a JSON string, not a number.
+        helpers::AppendStringView("\"line\": \""sv, Dest);
+        helpers::AppendInt(Msg.GetSource().Line, Dest);
+        helpers::AppendStringView("\","sv, Dest);
+    }
+
+    helpers::AppendStringView("\"message\": \""sv, Dest);
+    WriteEscapedString(Dest, Msg.GetPayload());
+    helpers::AppendStringView("\""sv, Dest);
+
+    helpers::AppendStringView("}\n"sv, Dest);
+}
+
+} // namespace zen::logging
diff --git a/src/zenutil/logging.cpp b/src/zenutil/logging/logging.cpp
index 1258ca155..ea2448a42 100644
--- a/src/zenutil/logging.cpp
+++ b/src/zenutil/logging/logging.cpp
@@ -238,8 +238,10 @@ FinishInitializeLogging(const LoggingOptions& LogOptions)
const std::string StartLogTime = zen::DateTime::Now().ToIso8601();
- static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"};
- logging::Registry::Instance().ApplyAll([&](auto Logger) { Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); });
+ logging::Registry::Instance().ApplyAll([&](auto Logger) {
+ static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"};
+ Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime));
+ });
}
g_IsLoggingInitialized = true;
diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp
new file mode 100644
index 000000000..23cf60d16
--- /dev/null
+++ b/src/zenutil/logging/rotatingfilesink.cpp
@@ -0,0 +1,249 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/logging/rotatingfilesink.h>
+
+#include <zencore/basicfile.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging/helpers.h>
+#include <zencore/logging/messageonlyformatter.h>
+#include <zencore/memory/llm.h>
+#include <zencore/thread.h>
+
+#include <atomic>
+
+namespace zen::logging {
+
+// Private state and file handling of RotatingFileSink (the header only
+// forward-declares Impl).
+//
+// Locking: m_Lock is taken shared by writers and exclusive for rotation and
+// formatter replacement. Concurrent writers reserve disjoint byte ranges of
+// the current file by atomically advancing m_CurrentSize, so they can write
+// in parallel under the shared lock.
+struct RotatingFileSink::Impl
+{
+    // Opens BaseFilename for writing. With RotateOnOpen the existing files are
+    // rotated first; otherwise rotation only happens if the existing file is
+    // already larger than MaxSize.
+    // Throws std::system_error if the file cannot be opened or rotated.
+    Impl(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen)
+    : m_BaseFilename(BaseFilename)
+    , m_MaxSize(MaxSize)
+    , m_MaxFiles(MaxFiles)
+    , m_Formatter(std::make_unique<MessageOnlyFormatter>())
+    {
+        ZEN_MEMSCOPE(ELLMTag::Logging);
+
+        std::error_code Ec;
+        if (RotateOnOpen)
+        {
+            RwLock::ExclusiveLockScope RotateLock(m_Lock);
+            Rotate(RotateLock, Ec);
+        }
+        else
+        {
+            m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, Ec);
+            if (!Ec)
+            {
+                m_CurrentSize = m_CurrentFile.FileSize(Ec);
+            }
+            if (!Ec)
+            {
+                if (m_CurrentSize > m_MaxSize)
+                {
+                    RwLock::ExclusiveLockScope RotateLock(m_Lock);
+                    Rotate(RotateLock, Ec);
+                }
+            }
+        }
+
+        if (Ec)
+        {
+            throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", m_BaseFilename.string()));
+        }
+    }
+
+    // Closes the current file; never lets an exception escape.
+    ~Impl()
+    {
+        try
+        {
+            RwLock::ExclusiveLockScope RotateLock(m_Lock);
+            m_CurrentFile.Close();
+        }
+        catch (const std::exception&)
+        {
+        }
+    }
+
+    // Closes the current file, rotates the files on disk and opens a new
+    // current file. The unnamed lock parameter documents that the caller must
+    // hold m_Lock exclusively. On failure OutEc is set; the current file may
+    // then be left closed.
+    void Rotate(RwLock::ExclusiveLockScope&, std::error_code& OutEc)
+    {
+        ZEN_MEMSCOPE(ELLMTag::Logging);
+
+        m_CurrentFile.Close();
+
+        OutEc = RotateFiles(m_BaseFilename, m_MaxFiles);
+        if (OutEc)
+        {
+            return;
+        }
+
+        m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, OutEc);
+        if (OutEc)
+        {
+            return;
+        }
+
+        m_CurrentSize = m_CurrentFile.FileSize(OutEc);
+        if (OutEc)
+        {
+            // FileSize failed but we have an open file - reset to 0
+            // so we can at least attempt writes from the start
+            m_CurrentSize = 0;
+            OutEc.clear();
+        }
+    }
+
+    // Formats Msg into OutFormatted (with ANSI SGR sequences stripped) and
+    // tries to append it to the current file. Returns false if the caller
+    // needs to rotate and retry with TrySinkIt(const MemoryBuffer&).
+    //
+    // Formatting happens before the "is the file open" check on purpose: the
+    // caller retries with OutFormatted, so it must hold the formatted text even
+    // when this attempt cannot write. Otherwise the retry would succeed in
+    // writing an empty buffer and the message would be lost.
+    bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted)
+    {
+        ZEN_MEMSCOPE(ELLMTag::Logging);
+
+        // The shared lock also protects m_Formatter against SetFormatter().
+        RwLock::SharedLockScope Lock(m_Lock);
+        m_Formatter->Format(Msg, OutFormatted);
+        helpers::StripAnsiSgrSequences(OutFormatted);
+        return WriteWhileLocked(OutFormatted);
+    }
+
+    // Tries to append already formatted text to the current file. Returns
+    // false if the file is closed, full, or the write failed.
+    bool TrySinkIt(const MemoryBuffer& Formatted)
+    {
+        ZEN_MEMSCOPE(ELLMTag::Logging);
+
+        RwLock::SharedLockScope Lock(m_Lock);
+        return WriteWhileLocked(Formatted);
+    }
+
+    // Shared write path of both TrySinkIt overloads; the caller holds m_Lock
+    // (shared is sufficient).
+    bool WriteWhileLocked(const MemoryBuffer& Formatted)
+    {
+        if (!m_CurrentFile.IsOpen())
+        {
+            return false;
+        }
+
+        // Reserve [WritePos, WritePos + AddSize). The reservation is kept even
+        // when the write is refused: pushing m_CurrentSize past m_MaxSize is
+        // what tells the caller that a rotation is due.
+        const size_t AddSize  = Formatted.size();
+        const size_t WritePos = m_CurrentSize.fetch_add(AddSize);
+
+        // A message that is larger than m_MaxSize on its own can never fit.
+        // Accept it at the start of an empty file instead of refusing it
+        // forever, which would make the caller rotate in an endless loop and
+        // push every older log file out of the rotation.
+        if (WritePos != 0 && WritePos + AddSize > m_MaxSize)
+        {
+            return false;
+        }
+
+        std::error_code Ec;
+        m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec);
+        if (Ec)
+        {
+            return false;
+        }
+        m_NeedFlush = true;
+        return true;
+    }
+
+    RwLock                      m_Lock;
+    const std::filesystem::path m_BaseFilename;
+    const std::size_t           m_MaxSize;
+    const std::size_t           m_MaxFiles;
+    std::unique_ptr<Formatter>  m_Formatter;
+    // Next write offset; equals the file size plus any refused reservations.
+    std::atomic_size_t m_CurrentSize{0};
+    BasicFile          m_CurrentFile;
+    // Set by successful writes, consumed by RotatingFileSink::Flush().
+    std::atomic<bool> m_NeedFlush = false;
+};
+
+// Opens (and, if requested or needed, rotates) the log file via Impl.
+// Throws std::system_error when the file cannot be opened or rotated.
+RotatingFileSink::RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen)
+: m_Impl(std::make_unique<Impl>(BaseFilename, MaxSize, MaxFiles, RotateOnOpen))
+{
+}
+
+// Defined here rather than in the header so that Impl is a complete type
+// when std::unique_ptr<Impl> is destroyed; Impl's destructor closes the file.
+RotatingFileSink::~RotatingFileSink() = default;
+
+// Formats Msg and appends it to the current log file, rotating to a new file
+// whenever the current one is full or closed. Never throws: any failure results
+// in the message being dropped.
+void
+RotatingFileSink::Log(const LogMessage& Msg)
+{
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    try
+    {
+        MemoryBuffer Formatted;
+        bool         Written = m_Impl->TrySinkIt(Msg, Formatted);
+
+        // There is intentionally no cap on the number of attempts: keep
+        // rotating and retrying until the text is written or rotation fails.
+        while (!Written)
+        {
+            {
+                RwLock::ExclusiveLockScope RotateLock(m_Impl->m_Lock);
+
+                // Another thread may already have rotated while we waited for
+                // the lock, in which case the file is open and has room again.
+                const bool NeedsRotation = m_Impl->m_CurrentSize > m_Impl->m_MaxSize || !m_Impl->m_CurrentFile.IsOpen();
+                if (NeedsRotation)
+                {
+                    std::error_code RotateError;
+                    m_Impl->Rotate(RotateLock, RotateError);
+                    if (RotateError)
+                    {
+                        // Cannot rotate: give up on this message.
+                        return;
+                    }
+                }
+            }
+
+            // Retry with the text formatted by the first attempt.
+            Written = m_Impl->TrySinkIt(Formatted);
+        }
+    }
+    catch (const std::exception&)
+    {
+        // Silently eat errors
+    }
+}
+
+// Flushes the current log file if anything was written since the last flush.
+// Never throws.
+void
+RotatingFileSink::Flush()
+{
+    // Take ownership of the pending-flush request before flushing. Clearing
+    // the flag afterwards could wipe out a request raised by a write that
+    // landed between the flush and the reset, leaving that data unflushed
+    // until some later write happens to set the flag again.
+    if (!m_Impl->m_NeedFlush.exchange(false))
+    {
+        return;
+    }
+
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    try
+    {
+        RwLock::SharedLockScope Lock(m_Impl->m_Lock);
+        if (m_Impl->m_CurrentFile.IsOpen())
+        {
+            m_Impl->m_CurrentFile.Flush();
+        }
+    }
+    catch (const std::exception&)
+    {
+        // Silently eat errors
+    }
+}
+
+// Replaces the formatter used for subsequent messages. Never throws.
+void
+RotatingFileSink::SetFormatter(std::unique_ptr<Formatter> InFormatter)
+{
+    ZEN_MEMSCOPE(ELLMTag::Logging);
+
+    try
+    {
+        // Exclusive: writers use the formatter while holding the lock shared.
+        RwLock::ExclusiveLockScope FormatterGuard(m_Impl->m_Lock);
+        m_Impl->m_Formatter = std::move(InFormatter);
+    }
+    catch (const std::exception&)
+    {
+        // Silently eat errors
+    }
+}
+
+} // namespace zen::logging