diff options
| author | Stefan Boberg <[email protected]> | 2026-03-18 11:19:10 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-18 11:19:10 +0100 |
| commit | eba410c4168e23d7908827eb34b7cf0c58a5dc48 (patch) | |
| tree | 3cda8e8f3f81941d3bb5b84a8155350c5bb2068c /src | |
| parent | bugfix release - v5.7.23 (#851) (diff) | |
| download | zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.tar.xz zen-eba410c4168e23d7908827eb34b7cf0c58a5dc48.zip | |
Compute batching (#849)
### Compute Batch Submission
- Consolidate duplicated action submission logic in `httpcomputeservice` into a single `HandleSubmitAction` supporting both single-action and batch (actions array) payloads
- Group actions by queue in `RemoteHttpRunner` and submit as batches with configurable chunk size, falling back to individual submission on failure
- Extract shared helpers: `MakeErrorResult`, `ValidateQueueForEnqueue`, `ActivateActionInQueue`, `RemoveActionFromActiveMaps`
### Retracted Action State
- Add `Retracted` state to `RunnerAction` for retry-free rescheduling — an explicit request to pull an action back and reschedule it on a different runner without incrementing `RetryCount`
- Implement idempotent `RetractAction()` on `RunnerAction` and `ComputeServiceSession`
- Add `POST jobs/{lsn}/retract` and `queues/{queueref}/jobs/{lsn}/retract` HTTP endpoints
- Add state machine documentation and per-state comments to `RunnerAction`
### Compute Race Fixes
- Fix race in `HandleActionUpdates` where actions enqueued between session abandon and scheduler tick were never abandoned, causing `GetActionResult` to return 202 indefinitely
- Fix queue `ActiveCount` race where `NotifyQueueActionComplete` was called after releasing `m_ResultsLock`, allowing callers to observe stale counters immediately after `GetActionResult` returned OK
### Logging Optimization and ANSI improvements
- Improve `AnsiColorStdoutSink` write efficiency — single write call, dirty-flag flush, `RwLock` instead of `std::mutex`
- Move ANSI color emission from sink into formatters via `Formatter::SetColorEnabled()`; remove `ColorRangeStart`/`End` from `LogMessage`
- Extract color helpers (`AnsiColorForLevel`, `StripAnsiSgrSequences`) into `helpers.h`
- Strip upstream ANSI SGR escapes in non-color output mode. This enables color in log messages without polluting log files with ANSI control sequences
- Move `RotatingFileSink`, `JsonFormatter`, and `FullFormatter` from header-only to pimpl with `.cpp` files
### CLI / Exec Refactoring
- Extract `ExecSessionRunner` class from ~920-line `ExecUsingSession` into focused methods and an `ExecSessionConfig` struct
- Replace monolithic `ExecCommand` with subcommand-based architecture (`http`, `inproc`, `beacon`, `dump`, `buildlog`)
- Allow parent options to appear after subcommand name by parsing subcommand args permissively and forwarding unmatched tokens to the parent parser
### Testing Improvements
- Fix `--test-suite` filter being ignored due to accumulation with default wildcard filter
- Add test suite banners to test listener output
- Make the `function.session.abandon_pending` test more robust
### Startup / Reliability Fixes
- Fix silent exit when a second zenserver instance detects a port conflict — use `ZEN_CONSOLE_*` for log calls that precede `InitializeLogging()`
- Fix two potential SIGSEGV paths during early startup: guard `sentry_options_new()` returning nullptr, and throw on `ZenServerState::Register()` returning nullptr instead of dereferencing
- Fail on unrecognized zenserver `--mode` instead of silently defaulting to store
### Other
- Show host details (hostname, platform, CPU count, memory) when discovering new compute workers
- Move frontend `html.zip` from source tree into build directory
- Add format specifications for Compact Binary and Compressed Buffer wire formats
- Add `WriteCompactBinaryObject` to zencore
- Extend `ConsoleTui` with additional functionality
- Add `--vscode` option to `xmake sln` for clangd / `compile_commands.json` support
- Disable compute/horde/nomad in release builds (not yet production-ready)
- Disable unintended `ASIO_HAS_IO_URING` enablement
- Fix crashpad patch missing leading whitespace
- Clean up code triggering gcc false positives
Diffstat (limited to 'src')
47 files changed, 4064 insertions, 2261 deletions
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index cbc153e07..30e860a3f 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -18,6 +18,7 @@ #include <zencore/stream.h> #include <zencore/string.h> #include <zencore/system.h> +#include <zencore/thread.h> #include <zencore/timer.h> #include <zenhttp/httpclient.h> #include <zenhttp/packageformat.h> @@ -41,255 +42,122 @@ struct hash<zen::IoHash> : public zen::IoHash::Hasher namespace zen { -ExecCommand::ExecCommand() -{ - m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>"); - m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>"); - m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>"); - m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>"); - m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>"); - m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>"); - m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>"); - m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>"); - m_Options.add_option("", - "", - "mode", - "Select execution mode (http,inproc,dump,direct,beacon,buildlog)", - cxxopts::value(m_Mode)->default_value("http"), - "<string>"); - m_Options - .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), "<bool>"); - m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>"); - m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", 
cxxopts::value(m_Binary), "<bool>"); - m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>"); - m_Options.parse_positional("mode"); -} - -ExecCommand::~ExecCommand() -{ -} - -void -ExecCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) -{ - // Configure - - if (!ParseOptions(argc, argv)) - { - return; - } - - m_HostName = ResolveTargetHostSpec(m_HostName); - - if (m_RecordingPath.empty()) - { - throw OptionParseException("replay path is required!", m_Options.help()); - } - - m_VerboseLogging = GlobalOptions.IsVerbose; - m_QuietLogging = m_Quiet && !m_VerboseLogging; - - enum ExecMode - { - kHttp, - kDirect, - kInproc, - kDump, - kBeacon, - kBuildLog - } Mode; - - if (m_Mode == "http"sv) - { - Mode = kHttp; - } - else if (m_Mode == "direct"sv) - { - Mode = kDirect; - } - else if (m_Mode == "inproc"sv) - { - Mode = kInproc; - } - else if (m_Mode == "dump"sv) - { - Mode = kDump; - } - else if (m_Mode == "beacon"sv) - { - Mode = kBeacon; - } - else if (m_Mode == "buildlog"sv) - { - Mode = kBuildLog; - } - else - { - throw OptionParseException("invalid mode specified!", m_Options.help()); - } +namespace { - // Gather information from recording path - - std::unique_ptr<zen::compute::RecordingReader> Reader; - std::unique_ptr<zen::compute::UeRecordingReader> UeReader; - - std::filesystem::path RecordingPath{m_RecordingPath}; - - if (!std::filesystem::is_directory(RecordingPath)) - { - throw OptionParseException("replay path should be a directory path!", m_Options.help()); - } - else + static std::string EscapeHtml(std::string_view Input) { - if (std::filesystem::is_directory(RecordingPath / "cid")) + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) { - Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath); - m_WorkerMap = Reader->ReadWorkers(); - m_ChunkResolver = Reader.get(); - m_RecordingReader = Reader.get(); - } - else - { - UeReader = 
std::make_unique<zen::compute::UeRecordingReader>(RecordingPath); - m_WorkerMap = UeReader->ReadWorkers(); - m_ChunkResolver = UeReader.get(); - m_RecordingReader = UeReader.get(); + switch (C) + { + case '&': + Out += "&amp;"; + break; + case '<': + Out += "&lt;"; + break; + case '>': + Out += "&gt;"; + break; + case '"': + Out += "&quot;"; + break; + case '\'': + Out += "&#39;"; + break; + default: + Out += C; + } } + return Out; } - ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount()); - - for (auto& Kv : m_WorkerMap) + static std::string EscapeJson(std::string_view Input) { - CbObject WorkerDesc = Kv.second.GetObject(); - const IoHash& WorkerId = Kv.first; - - RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId); - - if (m_VerboseLogging) + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) { - zen::ExtendableStringBuilder<1024> ObjStr; -# if 0 - zen::CompactBinaryToJson(WorkerDesc, ObjStr); - ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr); -# else - zen::CompactBinaryToYaml(WorkerDesc, ObjStr); - ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr); -# endif + switch (C) + { + case '"': + Out += "\\\""; + break; + case '\\': + Out += "\\\\"; + break; + case '\n': + Out += "\\n"; + break; + case '\r': + Out += "\\r"; + break; + case '\t': + Out += "\\t"; + break; + default: + if (static_cast<unsigned char>(C) < 0x20) + { + Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C))); + } + else + { + Out += C; + } + } } + return Out; } - if (m_VerboseLogging) - { - EmitFunctionList(m_FunctionList); - } - - // Iterate over work items and dispatch or log them - - int ReturnValue = 0; - - Stopwatch ExecTimer; - - switch (Mode) - { - case kHttp: - // Forward requests to HTTP function service - ReturnValue = HttpExecute(); - break; - - case kDirect: - // Not currently supported - ReturnValue = LocalMessagingExecute(); - break; - - case kInproc: - // Handle execution in-core (by spawning 
child processes) - ReturnValue = InProcessExecute(); - break; - - case kDump: - // Dump high level information about actions to console - ReturnValue = DumpWorkItems(); - break; - - case kBeacon: - ReturnValue = BeaconExecute(); - break; - - case kBuildLog: - ReturnValue = BuildActionsLog(); - break; - - default: - ZEN_ERROR("Unknown operating mode! No work submitted"); - - ReturnValue = 1; - } +} // namespace - ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); - - if (!ReturnValue) - { - ZEN_CONSOLE("all work items completed successfully"); - } - else - { - ZEN_CONSOLE("some work items failed (code {})", ReturnValue); - } -} +////////////////////////////////////////////////////////////////////////// +// ExecSessionConfig — read-only configuration for a session run -int -ExecCommand::InProcessExecute() +struct ExecSessionConfig { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; + zen::ChunkResolver& Resolver; + zen::compute::RecordingReaderBase& RecordingReader; + const std::unordered_map<zen::IoHash, zen::CbPackage>& WorkerMap; + std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce + std::string_view OrchestratorUrl; + const std::filesystem::path& OutputPath; + int Offset = 0; + int Stride = 1; + int Limit = 0; + bool Verbose = false; + bool Quiet = false; + bool DumpActions = false; + bool Binary = false; +}; - zen::compute::ComputeServiceSession ComputeSession(Resolver); +////////////////////////////////////////////////////////////////////////// +// ExecSessionRunner — owns per-run state, drives the session lifecycle - std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - ComputeSession.AddLocalRunner(Resolver, TempPath); +class ExecSessionRunner +{ +public: + ExecSessionRunner(zen::compute::ComputeServiceSession& Session, const ExecSessionConfig& Config); + int Run(); - return ExecUsingSession(ComputeSession); -} +private: + // Types -int 
-ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession) -{ struct JobTracker { public: - inline void Insert(int LsnField) - { - RwLock::ExclusiveLockScope _(Lock); - PendingJobs.insert(LsnField); - } - - inline bool IsEmpty() const - { - RwLock::SharedLockScope _(Lock); - return PendingJobs.empty(); - } - - inline void Remove(int CompleteLsn) - { - RwLock::ExclusiveLockScope _(Lock); - PendingJobs.erase(CompleteLsn); - } - - inline size_t GetSize() const - { - RwLock::SharedLockScope _(Lock); - return PendingJobs.size(); - } + void Insert(int LsnField); + bool IsEmpty() const; + void Remove(int CompleteLsn); + size_t GetSize() const; private: mutable RwLock Lock; std::unordered_set<int> PendingJobs; }; - JobTracker PendingJobs; - struct ActionSummaryEntry { int32_t Lsn = 0; @@ -307,664 +175,471 @@ ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSessio std::string ExecutionLocation; }; - std::mutex SummaryLock; - std::unordered_map<int32_t, ActionSummaryEntry> SummaryEntries; + // Methods + + std::string RegisterOrchestratorClient(); + void SendOrchestratorHeartbeat(); + void CompleteOrchestratorClient(); + void DrainCompletedJobs(); + void SaveResultOutput(int32_t CompleteLsn, CbPackage& ResultPackage); + void SaveActionOutput(int32_t Lsn, int RecordingIndex, const IoHash& ActionId, const CbObject& ActionObject); + void WriteSummaryFiles(); + void EmitFunctionListOnce(); + + // State + + zen::compute::ComputeServiceSession& m_Session; + ExecSessionConfig m_Config; + JobTracker m_PendingJobs; + std::mutex m_SummaryLock; + std::unordered_map<int32_t, ActionSummaryEntry> m_SummaryEntries; + int m_QueueId = 0; + std::string m_OrchestratorClientId; + Stopwatch m_OrchestratorHeartbeatTimer; + bool m_FunctionListEmittedOnce = false; + std::atomic<int> m_IsDraining{0}; +}; - ComputeSession.WaitUntilReady(); +////////////////////////////////////////////////////////////////////////// +// ExecSessionRunner::JobTracker - // 
Register as a client with the orchestrator (best-effort) +void +ExecSessionRunner::JobTracker::Insert(int LsnField) +{ + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.insert(LsnField); +} - std::string OrchestratorClientId; +bool +ExecSessionRunner::JobTracker::IsEmpty() const +{ + RwLock::SharedLockScope _(Lock); + return PendingJobs.empty(); +} - if (!m_OrchestratorUrl.empty()) +void +ExecSessionRunner::JobTracker::Remove(int CompleteLsn) +{ + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.erase(CompleteLsn); +} + +size_t +ExecSessionRunner::JobTracker::GetSize() const +{ + RwLock::SharedLockScope _(Lock); + return PendingJobs.size(); +} + +////////////////////////////////////////////////////////////////////////// +// ExecSessionRunner implementation + +ExecSessionRunner::ExecSessionRunner(zen::compute::ComputeServiceSession& Session, const ExecSessionConfig& Config) +: m_Session(Session) +, m_Config(Config) +{ +} + +std::string +ExecSessionRunner::RegisterOrchestratorClient() +{ + if (m_Config.OrchestratorUrl.empty()) { - try - { - HttpClient OrchestratorClient(m_OrchestratorUrl); + return {}; + } - CbObjectWriter Ann; - Ann << "session_id"sv << GetSessionId(); - Ann << "hostname"sv << std::string_view(GetMachineName()); + try + { + HttpClient OrchestratorClient(m_Config.OrchestratorUrl); - CbObjectWriter Meta; - Meta << "source"sv - << "zen-exec"sv; - Ann << "metadata"sv << Meta.Save(); + CbObjectWriter Ann; + Ann << "session_id"sv << GetSessionId(); + Ann << "hostname"sv << std::string_view(GetMachineName()); - auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); - if (Resp.IsSuccess()) - { - OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString()); - ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId); - } - else - { - ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode)); - } + CbObjectWriter Meta; + Meta << "source"sv + << "zen-exec"sv; + Ann << "metadata"sv 
<< Meta.Save(); + + auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); + if (Resp.IsSuccess()) + { + std::string ClientId = std::string(Resp.AsObject()["id"].AsString()); + ZEN_CONSOLE_INFO("registered with orchestrator as {}", ClientId); + return ClientId; } - catch (const std::exception& Ex) + else { - ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode)); } } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + } - Stopwatch OrchestratorHeartbeatTimer; + return {}; +} - auto SendOrchestratorHeartbeat = [&] { - if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) - { - return; - } - OrchestratorHeartbeatTimer.Reset(); +void +ExecSessionRunner::SendOrchestratorHeartbeat() +{ + if (m_OrchestratorClientId.empty() || m_OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) + { + return; + } + m_OrchestratorHeartbeatTimer.Reset(); + try + { + HttpClient OrchestratorClient(m_Config.OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", m_OrchestratorClientId)); + } + catch (...) + { + } +} + +void +ExecSessionRunner::CompleteOrchestratorClient() +{ + if (!m_OrchestratorClientId.empty()) + { try { - HttpClient OrchestratorClient(m_OrchestratorUrl); - std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId)); + HttpClient OrchestratorClient(m_Config.OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", m_OrchestratorClientId)); } catch (...) { } - }; - - auto ClientCleanup = MakeGuard([&] { - if (!OrchestratorClientId.empty()) - { - try - { - HttpClient OrchestratorClient(m_OrchestratorUrl); - std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId)); - } - catch (...) 
- { - } - } - }); - - // Create a queue to group all actions from this exec session - - CbObjectWriter Metadata; - Metadata << "source"sv - << "zen-exec"sv; - - auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save()); - const int QueueId = QueueResult.QueueId; - if (!QueueId) - { - ZEN_ERROR("failed to create compute queue"); - return 1; } +} - auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); }); - - if (!m_OutputPath.empty()) +void +ExecSessionRunner::DrainCompletedJobs() +{ + if (m_IsDraining.exchange(1)) { - zen::CreateDirectories(m_OutputPath); + return; } - std::atomic<int> IsDraining{0}; + auto _ = MakeGuard([&] { m_IsDraining.store(0, std::memory_order_release); }); - auto DrainCompletedJobs = [&] { - if (IsDraining.exchange(1)) - { - return; - } - - auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); }); + CbObjectWriter Cbo; + m_Session.GetQueueCompleted(m_QueueId, Cbo); - CbObjectWriter Cbo; - ComputeSession.GetQueueCompleted(QueueId, Cbo); - - if (CbObject Completed = Cbo.Save()) + if (CbObject Completed = Cbo.Save()) + { + for (auto& It : Completed["completed"sv]) { - for (auto& It : Completed["completed"sv]) - { - int32_t CompleteLsn = It.AsInt32(); + int32_t CompleteLsn = It.AsInt32(); - CbPackage ResultPackage; - HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); + CbPackage ResultPackage; + HttpResponseCode Response = m_Session.GetActionResult(CompleteLsn, /* out */ ResultPackage); - if (Response == HttpResponseCode::OK) + if (Response == HttpResponseCode::OK) + { + if (!m_Config.OutputPath.empty() && ResultPackage) { - if (!m_OutputPath.empty() && ResultPackage) - { - int OutputAttachments = 0; - uint64_t OutputBytes = 0; - - if (!m_Binary) - { - // Write the root object as YAML - ExtendableStringBuilder<4096> YamlStr; - CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); - - std::string_view Yaml = YamlStr; - 
zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn), - IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - - // Write decompressed attachments - auto Attachments = ResultPackage.GetAttachments(); - - if (!Attachments.empty()) - { - std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn); - zen::CreateDirectories(AttDir); - - for (const CbAttachment& Att : Attachments) - { - ++OutputAttachments; - - IoHash AttHash = Att.GetHash(); - - if (Att.IsCompressedBinary()) - { - SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress(); - OutputBytes += Decompressed.GetSize(); - zen::WriteFile(AttDir / AttHash.ToHexString(), - IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); - } - else - { - SharedBuffer Binary = Att.AsBinary(); - OutputBytes += Binary.GetSize(); - zen::WriteFile(AttDir / AttHash.ToHexString(), - IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); - } - } - } - - if (!m_QuietLogging) - { - ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", - m_OutputPath.string(), - CompleteLsn, - OutputAttachments); - } - } - else - { - CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); - zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); - - for (const CbAttachment& Att : ResultPackage.GetAttachments()) - { - ++OutputAttachments; - OutputBytes += Att.AsBinary().GetSize(); - } - - if (!m_QuietLogging) - { - ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn); - } - } - - std::lock_guard Lock(SummaryLock); - if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end()) - { - It2->second.OutputAttachments = OutputAttachments; - It2->second.OutputBytes = OutputBytes; - } - } + SaveResultOutput(CompleteLsn, ResultPackage); + } - PendingJobs.Remove(CompleteLsn); + m_PendingJobs.Remove(CompleteLsn); - ZEN_CONSOLE("completed: LSN {} ({} still 
pending)", CompleteLsn, PendingJobs.GetSize()); - } + ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize()); } } - }; - - // Describe workers + } +} - ZEN_CONSOLE("describing {} workers", m_WorkerMap.size()); +void +ExecSessionRunner::SaveResultOutput(int32_t CompleteLsn, CbPackage& ResultPackage) +{ + int OutputAttachments = 0; + uint64_t OutputBytes = 0; - for (auto Kv : m_WorkerMap) + if (!m_Config.Binary) { - CbPackage WorkerDesc = Kv.second; + // Write the root object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); - ComputeSession.RegisterWorker(WorkerDesc); - } + std::string_view Yaml = YamlStr; + zen::WriteFile(m_Config.OutputPath / fmt::format("{}.result.yaml", CompleteLsn), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - // Then submit work items + // Write decompressed attachments + auto Attachments = ResultPackage.GetAttachments(); - int FailedWorkCounter = 0; - size_t RemainingWorkItems = m_RecordingReader->GetActionCount(); - int SubmittedWorkItems = 0; - - ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + if (!Attachments.empty()) + { + std::filesystem::path AttDir = m_Config.OutputPath / fmt::format("{}.result.attachments", CompleteLsn); + zen::CreateDirectories(AttDir); - int OffsetCounter = m_Offset; - int StrideCounter = m_Stride; + for (const CbAttachment& Att : Attachments) + { + ++OutputAttachments; - auto ShouldSchedule = [&]() -> bool { - if (m_Limit && SubmittedWorkItems >= m_Limit) - { - // Limit reached, ignore + IoHash AttHash = Att.GetHash(); - return false; + if (Att.IsCompressedBinary()) + { + SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress(); + OutputBytes += Decompressed.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + else + { + SharedBuffer Binary = Att.AsBinary(); + OutputBytes += Binary.GetSize(); + 
zen::WriteFile(AttDir / AttHash.ToHexString(), IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); + } + } } - if (OffsetCounter && OffsetCounter--) + if (!m_Config.Quiet) { - // Still in offset, ignore - - return false; + ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", m_Config.OutputPath.string(), CompleteLsn, OutputAttachments); } + } + else + { + CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); + zen::WriteFile(m_Config.OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); - if (--StrideCounter == 0) + for (const CbAttachment& Att : ResultPackage.GetAttachments()) { - StrideCounter = m_Stride; - - return true; + ++OutputAttachments; + OutputBytes += Att.AsBinary().GetSize(); } - return false; - }; - - int TargetParallelism = 8; + if (!m_Config.Quiet) + { + ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_Config.OutputPath.string(), CompleteLsn); + } + } - if (OffsetCounter || StrideCounter || m_Limit) + std::lock_guard Lock(m_SummaryLock); + if (auto It2 = m_SummaryEntries.find(CompleteLsn); It2 != m_SummaryEntries.end()) { - TargetParallelism = 1; + It2->second.OutputAttachments = OutputAttachments; + It2->second.OutputBytes = OutputBytes; } +} - std::atomic<int> RecordingIndex{0}; +void +ExecSessionRunner::SaveActionOutput(int32_t Lsn, int RecordingIndex, const IoHash& ActionId, const CbObject& ActionObject) +{ + ActionSummaryEntry Entry; + Entry.Lsn = Lsn; + Entry.RecordingIndex = RecordingIndex; + Entry.ActionId = ActionId; + Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); - m_RecordingReader->IterateActions( - [&](CbObject ActionObject, const IoHash& ActionId) { - // Enqueue job + if (!m_Config.Binary) + { + // Write action object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ActionObject, YamlStr); - const int CurrentRecordingIndex = RecordingIndex++; + std::string_view Yaml = YamlStr; + zen::WriteFile(m_Config.OutputPath / 
fmt::format("{}.action.yaml", Lsn), IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - Stopwatch SubmitTimer; + // Write decompressed input attachments + std::filesystem::path AttDir = m_Config.OutputPath / fmt::format("{}.action.attachments", Lsn); + bool AttDirCreated = false; - const int Priority = 0; + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; - if (ShouldSchedule()) + if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachCid)) { - if (m_VerboseLogging) - { - int AttachmentCount = 0; - uint64_t AttachmentBytes = 0; - eastl::hash_set<IoHash> ReferencedChunks; + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + SharedBuffer Decompressed = Compressed.Decompress(); - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachData = Field.AsAttachment(); + Entry.InputBytes += Decompressed.GetSize(); - ReferencedChunks.insert(AttachData); - ++AttachmentCount; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) - { - AttachmentBytes += ChunkData.GetSize(); - } - }); - - zen::ExtendableStringBuilder<1024> ObjStr; - zen::CompactBinaryToJson(ActionObject, ObjStr); - ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}", - ActionId, - AttachmentCount, - NiceBytes(AttachmentBytes), - ObjStr); - } - - if (m_DumpActions) + if (!AttDirCreated) { - int AttachmentCount = 0; - uint64_t AttachmentBytes = 0; - - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachData = Field.AsAttachment(); - - ++AttachmentCount; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) - { - AttachmentBytes += ChunkData.GetSize(); - } - }); - - zen::ExtendableStringBuilder<1024> ObjStr; - zen::CompactBinaryToYaml(ActionObject, ObjStr); - ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, 
NiceBytes(AttachmentBytes), ObjStr); + zen::CreateDirectories(AttDir); + AttDirCreated = true; } - if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = - ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority)) - { - const int32_t LsnField = EnqueueResult.Lsn; - - --RemainingWorkItems; - ++SubmittedWorkItems; - - if (!m_QuietLogging) - { - ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", - SubmittedWorkItems, - LsnField, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - RemainingWorkItems); - } - - if (!m_OutputPath.empty()) - { - ActionSummaryEntry Entry; - Entry.Lsn = LsnField; - Entry.RecordingIndex = CurrentRecordingIndex; - Entry.ActionId = ActionId; - Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); - - if (!m_Binary) - { - // Write action object as YAML - ExtendableStringBuilder<4096> YamlStr; - CompactBinaryToYaml(ActionObject, YamlStr); - - std::string_view Yaml = YamlStr; - zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField), - IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - - // Write decompressed input attachments - std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.action.attachments", LsnField); - bool AttDirCreated = false; - - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachCid = Field.AsAttachment(); - ++Entry.InputAttachments; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) - { - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); - SharedBuffer Decompressed = Compressed.Decompress(); - - Entry.InputBytes += Decompressed.GetSize(); - - if (!AttDirCreated) - { - zen::CreateDirectories(AttDir); - AttDirCreated = true; - } - - zen::WriteFile(AttDir / AttachCid.ToHexString(), - IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); - } - }); - - if (!m_QuietLogging) - { - 
ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", - m_OutputPath.string(), - LsnField, - Entry.InputAttachments); - } - } - else - { - // Build a CbPackage from the action and write as .pkg - CbPackage ActionPackage; - ActionPackage.SetObject(ActionObject); - - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachCid = Field.AsAttachment(); - ++Entry.InputAttachments; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) - { - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); - - Entry.InputBytes += ChunkData.GetSize(); - ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); - } - }); - - CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); - zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized)); - - if (!m_QuietLogging) - { - ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField); - } - } - - std::lock_guard Lock(SummaryLock); - SummaryEntries.emplace(LsnField, std::move(Entry)); - } + zen::WriteFile(AttDir / AttachCid.ToHexString(), IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + }); - PendingJobs.Insert(LsnField); - } - else - { - if (!m_QuietLogging) - { - std::string_view FunctionName = ActionObject["Function"sv].AsString(); - const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); - const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + if (!m_Config.Quiet) + { + ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", m_Config.OutputPath.string(), Lsn, Entry.InputAttachments); + } + } + else + { + // Build a CbPackage from the action and write as .pkg + CbPackage ActionPackage; + ActionPackage.SetObject(ActionObject); - ZEN_ERROR( - "failed to resolve function for work with 
(Function:{},FunctionVersion:{},BuildSystemVersion:{}). Work " - "descriptor " - "at: 'file://{}'", - std::string(FunctionName), - FunctionVersion, - BuildSystemVersion, - "<null>"); + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; - EmitFunctionListOnce(m_FunctionList); - } + if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); - ++FailedWorkCounter; - } + Entry.InputBytes += ChunkData.GetSize(); + ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); } + }); - // Check for completed work + CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); + zen::WriteFile(m_Config.OutputPath / fmt::format("{}.action.pkg", Lsn), std::move(Serialized)); - DrainCompletedJobs(); - SendOrchestratorHeartbeat(); - }, - TargetParallelism); + if (!m_Config.Quiet) + { + ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_Config.OutputPath.string(), Lsn); + } + } - // Wait until all pending work is complete + std::lock_guard Lock(m_SummaryLock); + m_SummaryEntries.emplace(Lsn, std::move(Entry)); +} - while (!PendingJobs.IsEmpty()) +void +ExecSessionRunner::WriteSummaryFiles() +{ + if (m_Config.OutputPath.empty() || m_SummaryEntries.empty()) { - // TODO: improve this logic - zen::Sleep(500); - - DrainCompletedJobs(); - SendOrchestratorHeartbeat(); + return; } // Merge timing data from queue history into summary entries - if (!SummaryEntries.empty()) + // RunnerAction::State indices (can't include functionrunner.h from here) + constexpr int kStateNew = 0; + constexpr int kStatePending = 1; + constexpr int kStateRunning = 3; + constexpr int kStateCompleted = 4; // first terminal state + constexpr int kStateCount = 8; + + for (const auto& HistEntry : m_Session.GetQueueHistory(m_QueueId, 0)) { - // 
RunnerAction::State indices (can't include functionrunner.h from here) - constexpr int kStateNew = 0; - constexpr int kStatePending = 1; - constexpr int kStateRunning = 3; - constexpr int kStateCompleted = 4; // first terminal state - constexpr int kStateCount = 8; - - for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0)) + std::lock_guard Lock(m_SummaryLock); + if (auto It = m_SummaryEntries.find(HistEntry.Lsn); It != m_SummaryEntries.end()) { - std::lock_guard Lock(SummaryLock); - if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end()) + // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) + uint64_t EndTick = 0; + for (int S = kStateCompleted; S < kStateCount; ++S) { - // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) - uint64_t EndTick = 0; - for (int S = kStateCompleted; S < kStateCount; ++S) + if (HistEntry.Timestamps[S] != 0) { - if (HistEntry.Timestamps[S] != 0) - { - EndTick = HistEntry.Timestamps[S]; - break; - } + EndTick = HistEntry.Timestamps[S]; + break; } - uint64_t StartTick = HistEntry.Timestamps[kStateNew]; - if (EndTick > StartTick) - { - It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); - } - It->second.CpuSeconds = HistEntry.CpuSeconds; - It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; - It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; - It->second.ExecutionLocation = HistEntry.ExecutionLocation; } + uint64_t StartTick = HistEntry.Timestamps[kStateNew]; + if (EndTick > StartTick) + { + It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); + } + It->second.CpuSeconds = HistEntry.CpuSeconds; + It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; + It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; + It->second.ExecutionLocation = HistEntry.ExecutionLocation; } } - // Write summary file if output path is set + // 
Sort entries by recording index - if (!m_OutputPath.empty() && !SummaryEntries.empty()) + std::vector<ActionSummaryEntry> Sorted; + Sorted.reserve(m_SummaryEntries.size()); + for (auto& [_, Entry] : m_SummaryEntries) { - std::vector<ActionSummaryEntry> Sorted; - Sorted.reserve(SummaryEntries.size()); - for (auto& [_, Entry] : SummaryEntries) - { - Sorted.push_back(std::move(Entry)); - } - - std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) { - return A.RecordingIndex < B.RecordingIndex; - }); + Sorted.push_back(std::move(Entry)); + } - auto FormatTimestamp = [](uint64_t Ticks) -> std::string { - if (Ticks == 0) - { - return "-"; - } - return DateTime(Ticks).ToString("%H:%M:%S.%s"); - }; - - ExtendableStringBuilder<4096> Summary; - Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", - "LSN", - "Index", - "ActionId", - "Function", - "InAtt", - "InBytes", - "OutAtt", - "OutBytes", - "Wall(s)", - "CPU(s)", - "Submitted", - "Started", - "Location")); - Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "")); + std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) { + return A.RecordingIndex < B.RecordingIndex; + }); - for (const ActionSummaryEntry& Entry : Sorted) + auto FormatTimestamp = [](uint64_t Ticks) -> std::string { + if (Ticks == 0) { - Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", - Entry.Lsn, - Entry.RecordingIndex, - Entry.ActionId, - Entry.FunctionName, - Entry.InputAttachments, - NiceBytes(Entry.InputBytes), - Entry.OutputAttachments, - NiceBytes(Entry.OutputBytes), - Entry.WallSeconds, - Entry.CpuSeconds, - FormatTimestamp(Entry.SubmittedTicks), - 
FormatTimestamp(Entry.StartedTicks), - Entry.ExecutionLocation)); + return "-"; } + return DateTime(Ticks).ToString("%H:%M:%S.%s"); + }; - std::filesystem::path SummaryPath = m_OutputPath / "summary.txt"; - std::string_view SummaryStr = Summary; - zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); + // Write summary.txt + + ExtendableStringBuilder<4096> Summary; + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", + "LSN", + "Index", + "ActionId", + "Function", + "InAtt", + "InBytes", + "OutAtt", + "OutBytes", + "Wall(s)", + "CPU(s)", + "Submitted", + "Started", + "Location")); + Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "")); + + for (const ActionSummaryEntry& Entry : Sorted) + { + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", + Entry.Lsn, + Entry.RecordingIndex, + Entry.ActionId, + Entry.FunctionName, + Entry.InputAttachments, + NiceBytes(Entry.InputBytes), + Entry.OutputAttachments, + NiceBytes(Entry.OutputBytes), + Entry.WallSeconds, + Entry.CpuSeconds, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + Entry.ExecutionLocation)); + } - ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); + std::filesystem::path SummaryPath = m_Config.OutputPath / "summary.txt"; + std::string_view SummaryStr = Summary; + zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); - if (!m_Binary) - { - auto EscapeHtml = [](std::string_view Input) -> std::string { - std::string Out; - Out.reserve(Input.size()); - for (char C : Input) - { - switch (C) - { - case '&': - Out += "&"; - break; - case '<': - Out += "<"; - break; - case '>': - Out += ">"; - break; - case '"': 
- Out += """; - break; - case '\'': - Out += "'"; - break; - default: - Out += C; - } - } - return Out; - }; + ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); - auto EscapeJson = [](std::string_view Input) -> std::string { - std::string Out; - Out.reserve(Input.size()); - for (char C : Input) - { - switch (C) - { - case '"': - Out += "\\\""; - break; - case '\\': - Out += "\\\\"; - break; - case '\n': - Out += "\\n"; - break; - case '\r': - Out += "\\r"; - break; - case '\t': - Out += "\\t"; - break; - default: - if (static_cast<unsigned char>(C) < 0x20) - { - Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C))); - } - else - { - Out += C; - } - } - } - return Out; - }; + // Write summary.html - ExtendableStringBuilder<8192> Html; + if (!m_Config.Binary) + { + ExtendableStringBuilder<8192> Html; - Html.Append(std::string_view(R"(<!DOCTYPE html> + Html.Append(std::string_view(R"(<!DOCTYPE html> <html><head><meta charset="utf-8"><title>Exec Summary</title> <style> body{font-family:system-ui,sans-serif;margin:20px;background:#fafafa} @@ -1007,51 +682,50 @@ a:hover{text-decoration:underline} const DATA=[ )")); - std::string_view ResultExt = ".result.yaml"; - std::string_view ActionExt = ".action.yaml"; + std::string_view ResultExt = ".result.yaml"; + std::string_view ActionExt = ".action.yaml"; - for (const ActionSummaryEntry& Entry : Sorted) + for (const ActionSummaryEntry& Entry : Sorted) + { + std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName)); + std::string ActionIdStr = Entry.ActionId.ToHexString(); + std::string ActionLink; + if (!ActionExt.empty()) { - std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName)); - std::string ActionIdStr = Entry.ActionId.ToHexString(); - std::string ActionLink; - if (!ActionExt.empty()) - { - ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt)); - } - - // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 
7=outBytes, - // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink, - // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay, - // 17=location - Html.Append( - fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n", - Entry.Lsn, - Entry.RecordingIndex, - ActionIdStr, - SafeName, - Entry.InputAttachments, - Entry.InputBytes, - Entry.OutputAttachments, - Entry.OutputBytes, - Entry.WallSeconds, - Entry.CpuSeconds, - EscapeJson(NiceBytes(Entry.InputBytes)), - EscapeJson(NiceBytes(Entry.OutputBytes)), - ActionLink, - Entry.SubmittedTicks, - Entry.StartedTicks, - FormatTimestamp(Entry.SubmittedTicks), - FormatTimestamp(Entry.StartedTicks), - EscapeJson(EscapeHtml(Entry.ExecutionLocation)))); + ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt)); } - Html.Append(fmt::format(R"(]; + // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes, + // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink, + // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay, + // 17=location + Html.Append(fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n", + Entry.Lsn, + Entry.RecordingIndex, + ActionIdStr, + SafeName, + Entry.InputAttachments, + Entry.InputBytes, + Entry.OutputAttachments, + Entry.OutputBytes, + Entry.WallSeconds, + Entry.CpuSeconds, + EscapeJson(NiceBytes(Entry.InputBytes)), + EscapeJson(NiceBytes(Entry.OutputBytes)), + ActionLink, + Entry.SubmittedTicks, + Entry.StartedTicks, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + EscapeJson(EscapeHtml(Entry.ExecutionLocation)))); + } + + Html.Append(fmt::format(R"(]; const RESULT_EXT="{}"; )", - ResultExt)); + ResultExt)); - Html.Append(std::string_view(R"JS((function(){ + Html.Append(std::string_view(R"JS((function(){ const ROW_H=33,BUF=20; 
const container=document.getElementById("container"); const tbody=container.querySelector("tbody"); @@ -1158,14 +832,244 @@ document.getElementById("csvBtn").addEventListener("click",()=>{ </script></body></html> )JS")); - std::filesystem::path HtmlPath = m_OutputPath / "summary.html"; - std::string_view HtmlStr = Html; - zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + std::filesystem::path HtmlPath = m_Config.OutputPath / "summary.html"; + std::string_view HtmlStr = Html; + zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + + ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + } +} + +void +ExecSessionRunner::EmitFunctionListOnce() +{ + if (!m_FunctionListEmittedOnce) + { + ExecCommand::EmitFunctionList(m_Config.FunctionList); + m_FunctionListEmittedOnce = true; + } +} + +int +ExecSessionRunner::Run() +{ + m_Session.WaitUntilReady(); + + // Register as a client with the orchestrator (best-effort) + + m_OrchestratorClientId = RegisterOrchestratorClient(); + + auto ClientCleanup = MakeGuard([&] { CompleteOrchestratorClient(); }); + + // Create a queue to group all actions from this exec session + + CbObjectWriter Metadata; + Metadata << "source"sv + << "zen-exec"sv; + + auto QueueResult = m_Session.CreateQueue("zen-exec", Metadata.Save()); + const int QueueId = QueueResult.QueueId; + if (!QueueId) + { + ZEN_ERROR("failed to create compute queue"); + return 1; + } + + m_QueueId = QueueId; + + auto QueueCleanup = MakeGuard([&] { m_Session.DeleteQueue(QueueId); }); + + if (!m_Config.OutputPath.empty()) + { + zen::CreateDirectories(m_Config.OutputPath); + } + + // Describe workers + + ZEN_CONSOLE("describing {} workers", m_Config.WorkerMap.size()); + + for (auto Kv : m_Config.WorkerMap) + { + CbPackage WorkerDesc = Kv.second; + + m_Session.RegisterWorker(WorkerDesc); + } + + // Then submit work items + + int FailedWorkCounter = 0; + size_t RemainingWorkItems = 
m_Config.RecordingReader.GetActionCount(); + int SubmittedWorkItems = 0; + + ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + + int OffsetCounter = m_Config.Offset; + int StrideCounter = m_Config.Stride; - ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + auto ShouldSchedule = [&]() -> bool { + if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit) + { + // Limit reached, ignore + + return false; + } + + if (OffsetCounter && OffsetCounter--) + { + // Still in offset, ignore + + return false; } + + if (--StrideCounter == 0) + { + StrideCounter = m_Config.Stride; + + return true; + } + + return false; + }; + + int TargetParallelism = 8; + + if (OffsetCounter || StrideCounter || m_Config.Limit) + { + TargetParallelism = 1; + } + + std::atomic<int> RecordingIndex{0}; + + m_Config.RecordingReader.IterateActions( + [&](CbObject ActionObject, const IoHash& ActionId) { + // Enqueue job + + const int CurrentRecordingIndex = RecordingIndex++; + + Stopwatch SubmitTimer; + + const int Priority = 0; + + if (ShouldSchedule()) + { + if (m_Config.Verbose) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + eastl::hash_set<IoHash> ReferencedChunks; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ReferencedChunks.insert(AttachData); + ++AttachmentCount; + + if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}", + ActionId, + AttachmentCount, + NiceBytes(AttachmentBytes), + ObjStr); + } + + if (m_Config.DumpActions) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ++AttachmentCount; + + if (IoBuffer ChunkData = 
m_Config.Resolver.FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr); + } + + if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = + m_Session.EnqueueActionToQueue(QueueId, ActionObject, Priority)) + { + const int32_t LsnField = EnqueueResult.Lsn; + + --RemainingWorkItems; + ++SubmittedWorkItems; + + if (!m_Config.Quiet) + { + ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", + SubmittedWorkItems, + LsnField, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + RemainingWorkItems); + } + + if (!m_Config.OutputPath.empty()) + { + SaveActionOutput(LsnField, CurrentRecordingIndex, ActionId, ActionObject); + } + + m_PendingJobs.Insert(LsnField); + } + else + { + if (!m_Config.Quiet) + { + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + ZEN_ERROR( + "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). 
Work " + "descriptor " + "at: 'file://{}'", + std::string(FunctionName), + FunctionVersion, + BuildSystemVersion, + "<null>"); + + EmitFunctionListOnce(); + } + + ++FailedWorkCounter; + } + } + + // Check for completed work + + DrainCompletedJobs(); + SendOrchestratorHeartbeat(); + }, + TargetParallelism); + + // Wait until all pending work is complete + + while (!m_PendingJobs.IsEmpty()) + { + // TODO: improve this logic + zen::Sleep(500); + + DrainCompletedJobs(); + SendOrchestratorHeartbeat(); } + // Write summary files + + WriteSummaryFiles(); + if (FailedWorkCounter) { return 1; @@ -1174,37 +1078,91 @@ document.getElementById("csvBtn").addEventListener("click",()=>{ return 0; } -int -ExecCommand::LocalMessagingExecute() +////////////////////////////////////////////////////////////////////////// +// ExecHttpSubCmd + +ExecHttpSubCmd::ExecHttpSubCmd(ExecCommand& Parent) : ZenSubCmdBase("http", "Forward requests to HTTP compute service"), m_Parent(Parent) { - // Non-HTTP work submission path + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>"); +} + +void +ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); - // To be reimplemented using final transport + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; - return 0; + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); + + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession); + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } } 
////////////////////////////////////////////////////////////////////////// +// ExecInprocSubCmd -int -ExecCommand::HttpExecute() +ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent) { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; +} - std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); +void +ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; zen::compute::ComputeServiceSession ComputeSession(Resolver); - ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); - return ExecUsingSession(ComputeSession); + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + ComputeSession.AddLocalRunner(Resolver, TempPath); + + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession); + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } } -int -ExecCommand::BeaconExecute() +////////////////////////////////////////////////////////////////////////// +// ExecBeaconSubCmd + +ExecBeaconSubCmd::ExecBeaconSubCmd(ExecCommand& Parent) +: ZenSubCmdBase("beacon", "Execute via beacon/orchestrator discovery") +, m_Parent(Parent) { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; + SubOptions().add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>"); + SubOptions().add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>"); +} + +void +ExecBeaconSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; 
std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); zen::compute::ComputeServiceSession ComputeSession(Resolver); @@ -1221,49 +1179,36 @@ ExecCommand::BeaconExecute() ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); } - return ExecUsingSession(ComputeSession); -} - -////////////////////////////////////////////////////////////////////////// + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession, m_OrchestratorUrl); -void -ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId) -{ - const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid(); + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); - for (auto& Item : WorkerDesc["functions"sv]) + if (!ReturnValue) { - CbObjectView Function = Item.AsObjectView(); - - std::string_view FunctionName = Function["name"sv].AsString(); - const Guid FunctionVersion = Function["version"sv].AsUuid(); - - m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, - .FunctionVersion = FunctionVersion, - .BuildSystemVersion = WorkerBuildSystemVersion, - .WorkerId = WorkerId}); + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); } } -void -ExecCommand::EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList) -{ - if (m_FunctionListEmittedOnce == false) - { - EmitFunctionList(FunctionList); +////////////////////////////////////////////////////////////////////////// +// ExecDumpSubCmd - m_FunctionListEmittedOnce = true; - } +ExecDumpSubCmd::ExecDumpSubCmd(ExecCommand& Parent) : ZenSubCmdBase("dump", "Dump high level information about actions"), m_Parent(Parent) +{ } -int -ExecCommand::DumpWorkItems() +void +ExecDumpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) { std::atomic<int> EmittedCount{0}; eastl::hash_map<IoHash, uint64_t> 
SeenAttachments; // Attachment CID -> count of references - m_RecordingReader->IterateActions( + m_Parent.m_RecordingReader->IterateActions( [&](CbObject ActionObject, const IoHash& ActionId) { eastl::hash_map<IoHash, CompressedBuffer> Attachments; @@ -1272,7 +1217,7 @@ ExecCommand::DumpWorkItems() ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); - IoBuffer AttachmentData = m_ChunkResolver->FindChunkByCid(AttachmentCid); + IoBuffer AttachmentData = m_Parent.m_ChunkResolver->FindChunkByCid(AttachmentCid); IoHash RawHash; uint64_t RawSize = 0; CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); @@ -1322,36 +1267,191 @@ ExecCommand::DumpWorkItems() { ZEN_CONSOLE("{} attachments with {} references", Cids.size(), RefCount); } +} - return 0; +////////////////////////////////////////////////////////////////////////// +// ExecBuildlogSubCmd + +ExecBuildlogSubCmd::ExecBuildlogSubCmd(ExecCommand& Parent) : ZenSubCmdBase("buildlog", "Generate build actions log"), m_Parent(Parent) +{ +} + +void +ExecBuildlogSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; + + if (std::filesystem::exists(m_Parent.m_RecordingLogPath)) + { + m_Parent.ThrowOptionError(fmt::format("recording log directory '{}' already exists!", m_Parent.m_RecordingLogPath)); + } + + ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!"); + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.StartRecording(Resolver, m_Parent.m_RecordingLogPath); + + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession); + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + 
ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } } ////////////////////////////////////////////////////////////////////////// +// ExecCommand -int -ExecCommand::BuildActionsLog() +ExecCommand::ExecCommand() { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("replay", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>"); + m_Options.add_option("replay", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>"); + m_Options.add_option("replay", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>"); + m_Options.add_option("replay", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>"); + m_Options.add_option("replay", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>"); + m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>"); + m_Options.add_option("output", + "", + "dump-actions", + "Dump each action to console as it is dispatched", + cxxopts::value(m_DumpActions), + "<bool>"); + m_Options.add_option("output", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>"); + m_Options.add_option("output", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>"); + m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), ""); + m_Options.parse_positional({"subcommand"}); + + AddSubCommand(m_HttpSubCmd); + AddSubCommand(m_InprocSubCmd); + AddSubCommand(m_BeaconSubCmd); + AddSubCommand(m_DumpSubCmd); + AddSubCommand(m_BuildlogSubCmd); +} +ExecCommand::~ExecCommand() +{ +} + +bool +ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) +{ 
if (m_RecordingPath.empty()) { - throw OptionParseException("need to specify recording path", m_Options.help()); + ThrowOptionError("replay path is required!"); } - if (std::filesystem::exists(m_RecordingLogPath)) + m_VerboseLogging = GlobalOptions.IsVerbose; + m_QuietLogging = m_Quiet && !m_VerboseLogging; + + // Gather information from recording path + + std::filesystem::path RecordingPath{m_RecordingPath}; + + if (!std::filesystem::is_directory(RecordingPath)) { - throw OptionParseException(fmt::format("recording log directory '{}' already exists!", m_RecordingLogPath), m_Options.help()); + ThrowOptionError("replay path should be a directory path!"); + } + else + { + if (std::filesystem::is_directory(RecordingPath / "cid")) + { + m_Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath); + m_WorkerMap = m_Reader->ReadWorkers(); + m_ChunkResolver = m_Reader.get(); + m_RecordingReader = m_Reader.get(); + } + else + { + m_UeReader = std::make_unique<zen::compute::UeRecordingReader>(RecordingPath); + m_WorkerMap = m_UeReader->ReadWorkers(); + m_ChunkResolver = m_UeReader.get(); + m_RecordingReader = m_UeReader.get(); + } } - ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!"); + ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount()); - std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + for (auto& Kv : m_WorkerMap) + { + CbObject WorkerDesc = Kv.second.GetObject(); + const IoHash& WorkerId = Kv.first; - zen::compute::ComputeServiceSession ComputeSession(Resolver); - ComputeSession.StartRecording(Resolver, m_RecordingLogPath); + RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId); + + if (m_VerboseLogging) + { + zen::ExtendableStringBuilder<1024> ObjStr; +# if 0 + zen::CompactBinaryToJson(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr); +# else + zen::CompactBinaryToYaml(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}:\n{}", WorkerId, 
ObjStr); +# endif + } + } + + if (m_VerboseLogging) + { + EmitFunctionList(m_FunctionList); + } + + return true; +} - return ExecUsingSession(ComputeSession); +int +ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl) +{ + ExecSessionConfig Config{ + .Resolver = *m_ChunkResolver, + .RecordingReader = *m_RecordingReader, + .WorkerMap = m_WorkerMap, + .FunctionList = m_FunctionList, + .OrchestratorUrl = OrchestratorUrl, + .OutputPath = m_OutputPath, + .Offset = m_Offset, + .Stride = m_Stride, + .Limit = m_Limit, + .Verbose = m_VerboseLogging, + .Quiet = m_QuietLogging, + .DumpActions = m_DumpActions, + .Binary = m_Binary, + }; + + ExecSessionRunner Runner(ComputeSession, Config); + return Runner.Run(); +} + +////////////////////////////////////////////////////////////////////////// + +void +ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId) +{ + const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerDesc["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } } void diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index 6311354c0..c55412780 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -11,6 +11,7 @@ #include <filesystem> #include <functional> +#include <memory> #include <unordered_map> namespace zen { @@ -28,13 +29,79 @@ class ComputeServiceSession; namespace zen { +class ExecCommand; + +struct ExecFunctionDefinition +{ + std::string FunctionName; + zen::Guid FunctionVersion; + zen::Guid BuildSystemVersion; + 
zen::IoHash WorkerId; +}; + +////////////////////////////////////////////////////////////////////////// +// Subcommands + +class ExecHttpSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecHttpSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; + std::string m_HostName; +}; + +class ExecInprocSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecInprocSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; +}; + +class ExecBeaconSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecBeaconSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; + std::string m_OrchestratorUrl; + std::filesystem::path m_BeaconPath; +}; + +class ExecDumpSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecDumpSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; +}; + +class ExecBuildlogSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecBuildlogSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; +}; + /** * Zen CLI command for executing functions from a recording * * Mostly for testing and debugging purposes */ -class ExecCommand : public ZenCmdBase +class ExecCommand : public ZenCmdWithSubCommands { public: ExecCommand(); @@ -43,57 +110,47 @@ public: static constexpr char Name[] = "exec"; static constexpr char Description[] = "Execute functions from a recording"; - virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; - virtual cxxopts::Options& Options() override { return m_Options; } + cxxopts::Options& Options() override { return m_Options; } + + // Shared state & helpers (public for subcommand access) + + using FunctionDefinition = ExecFunctionDefinition; + + static void 
EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList); + void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId); + + int RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl = {}); + + std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap; + std::vector<FunctionDefinition> m_FunctionList; + zen::ChunkResolver* m_ChunkResolver = nullptr; + zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; + bool m_VerboseLogging = false; + bool m_QuietLogging = false; + bool m_DumpActions = false; + std::filesystem::path m_OutputPath; + bool m_Binary = false; + std::filesystem::path m_RecordingLogPath; private: + bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) override; + cxxopts::Options m_Options{Name, Description}; - std::string m_HostName; - std::string m_OrchestratorUrl; - std::filesystem::path m_BeaconPath; + std::string m_SubCommand; std::filesystem::path m_RecordingPath; - std::filesystem::path m_RecordingLogPath; int m_Offset = 0; int m_Stride = 1; int m_Limit = 0; bool m_Quiet = false; - std::string m_Mode{"http"}; - std::filesystem::path m_OutputPath; - bool m_Binary = false; - - struct FunctionDefinition - { - std::string FunctionName; - zen::Guid FunctionVersion; - zen::Guid BuildSystemVersion; - zen::IoHash WorkerId; - }; - - bool m_FunctionListEmittedOnce = false; - void EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList); - void EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList); - - std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap; - std::vector<FunctionDefinition> m_FunctionList; - bool m_VerboseLogging = false; - bool m_QuietLogging = false; - bool m_DumpActions = false; - - zen::ChunkResolver* m_ChunkResolver = nullptr; - zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; - - void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, 
const zen::IoHash& WorkerId); - - int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession); - // Execution modes + std::unique_ptr<zen::compute::RecordingReader> m_Reader; + std::unique_ptr<zen::compute::UeRecordingReader> m_UeReader; - int DumpWorkItems(); - int HttpExecute(); - int InProcessExecute(); - int LocalMessagingExecute(); - int BeaconExecute(); - int BuildActionsLog(); + ExecHttpSubCmd m_HttpSubCmd{*this}; + ExecInprocSubCmd m_InprocSubCmd{*this}; + ExecBeaconSubCmd m_BeaconSubCmd{*this}; + ExecDumpSubCmd m_DumpSubCmd{*this}; + ExecBuildlogSubCmd m_BuildlogSubCmd{*this}; }; } // namespace zen diff --git a/src/zen/cmds/service_cmd.cpp b/src/zen/cmds/service_cmd.cpp index 3347f1afe..37baf5483 100644 --- a/src/zen/cmds/service_cmd.cpp +++ b/src/zen/cmds/service_cmd.cpp @@ -310,17 +310,34 @@ ServiceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::vector<char*> SubCommandArguments; cxxopts::Options* SubOption = nullptr; int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments); - if (!ParseOptions(ParentCommandArgCount, argv)) + + if (SubOption == nullptr) + { + if (!ParseOptions(ParentCommandArgCount, argv)) + { + return; + } + throw OptionParseException("'verb' option is required", m_Options.help()); + } + + // Parse subcommand permissively — forward unrecognised options to the parent parser. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { return; } - if (SubOption == nullptr) + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. 
+ std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : SubUnmatched) { - throw OptionParseException("'verb' option is required", m_Options.help()); + ParentArgs.push_back(Arg.data()); } - if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) + if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data())) { return; } diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index 220ef6a9e..9e49b464e 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -127,14 +127,36 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::vector<char*> SubCommandArguments; cxxopts::Options* SubOption = nullptr; int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments); - if (!ParseOptions(ParentCommandArgCount, argv)) + + if (SubOption == nullptr) + { + if (!ParseOptions(ParentCommandArgCount, argv)) + { + return; + } + throw OptionParseException("'verb' option is required", m_Options.help()); + } + + // Parse subcommand permissively — forward unrecognised options to the parent parser. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { return; } - if (SubOption == nullptr) + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. 
+ std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : SubUnmatched) { - throw OptionParseException("'verb' option is required", m_Options.help()); + ParentArgs.push_back(Arg.data()); + } + + if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data())) + { + return; } m_HostName = ResolveTargetHostSpec(m_HostName); @@ -150,11 +172,6 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::filesystem::path StatePath = m_SystemRootDir / "workspaces"; - if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) - { - return; - } - if (SubOption == &m_CreateOptions) { if (m_Path.empty()) @@ -376,14 +393,36 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** std::vector<char*> SubCommandArguments; cxxopts::Options* SubOption = nullptr; int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments); - if (!ParseOptions(ParentCommandArgCount, argv)) + + if (SubOption == nullptr) + { + if (!ParseOptions(ParentCommandArgCount, argv)) + { + return; + } + throw OptionParseException("'verb' option is required", m_Options.help()); + } + + // Parse subcommand permissively — forward unrecognised options to the parent parser. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { return; } - if (SubOption == nullptr) + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. 
+ std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : SubUnmatched) { - throw OptionParseException("'verb' option is required", m_Options.help()); + ParentArgs.push_back(Arg.data()); + } + + if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data())) + { + return; } m_HostName = ResolveTargetHostSpec(m_HostName); @@ -403,11 +442,6 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** std::filesystem::path StatePath = m_SystemRootDir / "workspaces"; - if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) - { - return; - } - if (SubOption == &m_CreateOptions) { if (m_WorkspaceRoot.empty()) diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 86154c291..849d7a075 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -115,6 +115,21 @@ ZenCmdBase::CommandCategory() const return DefaultCategory; } +std::string +ZenCmdBase::HelpText() +{ + std::vector<std::string> Groups = Options().groups(); + Groups.erase(std::remove(Groups.begin(), Groups.end(), std::string("__hidden__")), Groups.end()); + Options().set_width(TuiConsoleColumns(80)); + return Options().help(Groups); +} + +void +ZenCmdBase::ThrowOptionError(std::string_view Message) +{ + throw OptionParseException(std::string(Message), HelpText()); +} + bool ZenCmdBase::ParseOptions(int argc, char** argv) { @@ -166,6 +181,35 @@ ZenCmdBase::ParseOptions(cxxopts::Options& CmdOptions, int argc, char** argv) return true; } +bool +ZenCmdBase::ParseOptionsPermissive(cxxopts::Options& CmdOptions, int argc, char** argv, std::vector<std::string>& OutUnmatched) +{ + CmdOptions.set_width(TuiConsoleColumns(80)); + CmdOptions.allow_unrecognised_options(); + + cxxopts::ParseResult Result; + + try + { + Result = 
CmdOptions.parse(argc, argv); + } + catch (const std::exception& Ex) + { + throw zen::OptionParseException(Ex.what(), CmdOptions.help()); + } + + CmdOptions.show_positional_help(); + + if (Result.count("help")) + { + printf("%s\n", CmdOptions.help().c_str()); + return false; + } + + OutUnmatched = Result.unmatched(); + return true; +} + // Get the number of args including the sub command // Build an array for sub command to parse int @@ -243,6 +287,23 @@ ZenCmdWithSubCommands::PrintHelp() } void +ZenCmdWithSubCommands::PrintSubCommandHelp(cxxopts::Options& SubCmdOptions) +{ + // Show the subcommand's own options. + SubCmdOptions.set_width(TuiConsoleColumns(80)); + printf("%s\n", SubCmdOptions.help().c_str()); + + // Show the parent command's options (excluding the hidden positional group). + std::vector<std::string> ParentGroups = Options().groups(); + ParentGroups.erase(std::remove(ParentGroups.begin(), ParentGroups.end(), std::string("__hidden__")), ParentGroups.end()); + + Options().set_width(TuiConsoleColumns(80)); + printf("%s options:\n%s\n", Options().program().c_str(), Options().help(ParentGroups).c_str()); + + printf("For global options run: zen --help\n"); +} + +void ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { std::vector<cxxopts::Options*> SubOptionPtrs; @@ -270,42 +331,15 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** } } - // Parse parent options. When a subcommand was matched we strip its name from - // the arg list so the parent parser does not see it as an unmatched positional. 
- if (MatchedSubOption != nullptr) - { - std::vector<char*> ParentArgs; - ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1)); - ParentArgs.push_back(argv[0]); - std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs)); - if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data())) - { - return; - } - } - else + if (MatchedSubOption == nullptr) { if (!ParseOptions(Options(), ParentArgc, argv)) { return; } - } - - if (MatchedSubOption == nullptr) - { PrintHelp(); - - ExtendableStringBuilder<128> VerbList; - for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands) - { - if (!First) - { - VerbList.Append(", "); - } - VerbList.Append(SubCmd->SubOptions().program()); - First = false; - } - throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), {}); + fflush(stdout); + throw OptionParseException("No subcommand specified", {}); } ZenSubCmdBase* MatchedSubCmd = nullptr; @@ -319,9 +353,40 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** } ZEN_ASSERT(MatchedSubCmd != nullptr); - // Parse subcommand args before OnParentOptionsParsed so --help on the subcommand - // works without requiring parent options like --hosturl to be populated. - if (!ParseOptions(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) + // Intercept --help/-h in the subcommand args so we can show combined help + // (subcommand options + parent options) without requiring parent options to + // be populated. + for (size_t i = 1; i < SubCommandArguments.size(); ++i) + { + std::string_view Arg(SubCommandArguments[i]); + if (Arg == "--help" || Arg == "-h") + { + PrintSubCommandHelp(*MatchedSubOption); + return; + } + } + + // Parse subcommand args permissively — unrecognised options are collected + // and forwarded to the parent parser so that parent options (e.g. 
--path) + // can appear after the subcommand name on the command line. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) + { + return; + } + + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. + std::vector<std::string> UnmatchedStorage = std::move(SubUnmatched); + std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1) + UnmatchedStorage.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : UnmatchedStorage) + { + ParentArgs.push_back(Arg.data()); + } + + if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data())) { return; } diff --git a/src/zen/zen.h b/src/zen/zen.h index 05ce32d0a..97cc9af6f 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -60,6 +60,7 @@ public: bool ParseOptions(int argc, char** argv); static bool ParseOptions(cxxopts::Options& Options, int argc, char** argv); + static bool ParseOptionsPermissive(cxxopts::Options& Options, int argc, char** argv, std::vector<std::string>& OutUnmatched); static int GetSubCommand(cxxopts::Options& Options, int argc, char** argv, @@ -74,7 +75,9 @@ public: static constexpr const char* kHostUrlHelp = "Host URL or unix:///path/to/socket"; - static void LogExecutableVersionAndPid(); + std::string HelpText(); + [[noreturn]] void ThrowOptionError(std::string_view Message); + static void LogExecutableVersionAndPid(); }; class StorageCommand : public ZenCmdBase @@ -114,6 +117,7 @@ protected: void AddSubCommand(ZenSubCmdBase& SubCmd); virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions); void PrintHelp(); + void PrintSubCommandHelp(cxxopts::Options& SubCmdOptions); private: std::vector<ZenSubCmdBase*> m_SubCommands; diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index 
f5188123f..a1a39fc3c 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -46,9 +46,12 @@ Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns: - Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` - Queue map: `m_Queues` (QueueEntry objects) - Action history ring: `m_ActionHistory` (bounded deque, default 1000) +- WebSocket client (`m_OrchestratorWsClient`) subscribed to the orchestrator's `/orch/ws` push for instant worker discovery **Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. +**Convenience helpers:** `Ready()`, `Abandon()`, `SetOrchestrator(Endpoint, BasePath)` are inline wrappers for common state transitions and orchestrator configuration. + ### `RunnerAction` (runners/functionrunner.h) Shared ref-counted struct representing one action through its lifecycle. @@ -67,8 +70,11 @@ New → Pending → Submitting → Running → Completed → Failed → Abandoned → Cancelled + → Retracted ``` -`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. +`SetActionState()` rejects non-forward transitions (Retracted has the highest ordinal so runner-side transitions cannot override it). 
`ResetActionStateToPending()` uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling — it clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +**Retracted state:** An explicit, instigator-initiated request to pull an action back and reschedule it on a different runner (e.g. capacity opened up elsewhere). Unlike Failed/Abandoned auto-retry, rescheduling from Retracted does not increment `RetryCount` since nothing went wrong. Retraction is idempotent and can target Pending, Submitting, or Running actions. ### `LocalProcessRunner` (runners/localrunner.h) Base for all local execution. Platform runners subclass this and override: @@ -90,10 +96,29 @@ Base for all local execution. Platform runners subclass this and override: - macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 ### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) -Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. `SubmitActions()` supports batch submission — actions are grouped and forwarded in chunks. + +### `RemoteHttpRunner` (runners/remotehttprunner.h) +Submits actions to remote zenserver instances over HTTP. Key features: +- **WebSocket completion notifications**: connects a WS client to `/compute/ws` on the remote. When a message arrives (action completed), the monitor thread wakes immediately instead of polling. Falls back to adaptive polling (200ms→50ms) when WS is unavailable. 
+- **Batch submission**: groups actions by queue and submits in configurable chunks (`m_MaxBatchSize`, default 50), falling back to individual submission on failure. +- **Queue cancellation**: `CancelRemoteQueue()` sends cancel requests to the remote. +- **Graceful shutdown**: `Shutdown()` closes the WS client, cancels all remote queues, stops the monitor thread, and marks remaining actions as Failed. ### `HttpComputeService` (include/zencompute/httpcomputeservice.h) -Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. Supports both single-action and batch (actions array) payloads via a shared `HandleSubmitAction` helper. + +## Orchestrator Discovery + +`ComputeServiceSession` discovers remote workers via the orchestrator endpoint (`SetOrchestratorEndpoint()`). Two complementary mechanisms: + +1. **Polling** (`UpdateCoordinatorState`): `GET /orch/agents` on the scheduler thread, throttled to every 5s (500ms when no workers are known yet). Discovers new workers and removes stale/unreachable ones. + +2. **WebSocket push** (`OrchestratorWsHandler`): connects to `/orch/ws` on the orchestrator at setup time. When the orchestrator broadcasts a state change, the handler sets `m_OrchestratorQueryForced` and signals the scheduler event, bypassing the polling throttle. Falls back silently to polling if the WS connection fails. + +`NotifyOrchestratorChanged()` is the public API to trigger an immediate re-query — useful in tests and for external notification sources. + +Use `HttpToWsUrl(Endpoint, Path)` from `zenhttp/httpwsclient.h` to convert HTTP(S) endpoints to WebSocket URLs. This helper is shared across all WS client setup sites in the codebase. 
## Action Lifecycle (End to End) @@ -118,6 +143,8 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. + **Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. ## Queue System @@ -161,8 +188,9 @@ All routes registered in `HttpComputeService` constructor. Prefix is configured | GET | `jobs/running` | In-flight actions with CPU metrics | | GET | `jobs/completed` | Actions with results available | | GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{lsn}/retract` | Retract a pending/running action for rescheduling (idempotent) | | POST | `jobs/{worker}` | Submit action for specific worker | -| POST | `jobs` | Submit action (worker resolved from descriptor) | +| POST | `jobs` | Submit action (or batch via `actions` array) | | GET | `workers` | List worker IDs | | GET | `workers/all` | All workers with full descriptors | | GET/POST | `workers/{worker}` | Get/register worker | @@ -179,8 +207,9 @@ Queue ref is capture(1) in all `queues/{queueref}/...` routes. 
| GET | `queues/{queueref}/completed` | Queue's completed results | | GET | `queues/{queueref}/history` | Queue's action history | | GET | `queues/{queueref}/running` | Queue's running actions | -| POST | `queues/{queueref}/jobs` | Submit to queue | +| POST | `queues/{queueref}/jobs` | Submit to queue (or batch via `actions` array) | | GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| POST | `queues/{queueref}/jobs/{lsn}/retract` | Retract action for rescheduling | | GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 838d741b6..92901de64 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -33,6 +33,7 @@ # include <zenutil/workerpools.h> # include <zentelemetry/stats.h> # include <zenhttp/httpclient.h> +# include <zenhttp/httpwsclient.h> # include <set> # include <deque> @@ -42,6 +43,7 @@ # include <unordered_set> ZEN_THIRD_PARTY_INCLUDES_START +# include <EASTL/fixed_vector.h> # include <EASTL/hash_set.h> ZEN_THIRD_PARTY_INCLUDES_END @@ -95,6 +97,14 @@ using SessionState = ComputeServiceSession::SessionState; static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count)); +static ComputeServiceSession::EnqueueResult +MakeErrorResult(std::string_view Error) +{ + CbObjectWriter Writer; + Writer << "error"sv << Error; + return {0, Writer.Save()}; +} + ////////////////////////////////////////////////////////////////////////// struct ComputeServiceSession::Impl @@ -130,14 +140,40 @@ struct ComputeServiceSession::Impl void SetOrchestratorEndpoint(std::string_view Endpoint); void SetOrchestratorBasePath(std::filesystem::path BasePath); + void 
NotifyOrchestratorChanged(); std::string m_OrchestratorEndpoint; std::filesystem::path m_OrchestratorBasePath; Stopwatch m_OrchestratorQueryTimer; + std::atomic<bool> m_OrchestratorQueryForced{false}; std::unordered_set<std::string> m_KnownWorkerUris; void UpdateCoordinatorState(); + // WebSocket subscription to orchestrator push notifications + struct OrchestratorWsHandler : public IWsClientHandler + { + Impl& Owner; + + explicit OrchestratorWsHandler(Impl& InOwner) : Owner(InOwner) {} + + void OnWsOpen() override + { + ZEN_LOG_INFO(Owner.m_Log, "orchestrator WebSocket connected"); + Owner.NotifyOrchestratorChanged(); + } + + void OnWsMessage(const WebSocketMessage&) override { Owner.NotifyOrchestratorChanged(); } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + ZEN_LOG_WARN(Owner.m_Log, "orchestrator WebSocket closed (code {}: {})", Code, Reason); + } + }; + + std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler; + std::unique_ptr<HttpWsClient> m_OrchestratorWsClient; + // Worker registration and discovery struct FunctionDefinition @@ -157,6 +193,8 @@ struct ComputeServiceSession::Impl std::atomic<int32_t> m_ActionsCounter = 0; // sequence number metrics::Meter m_ArrivalRate; + std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; + RwLock m_PendingLock; std::map<int, Ref<RunnerAction>> m_PendingActions; @@ -267,6 +305,8 @@ struct ComputeServiceSession::Impl void DrainQueue(int QueueId); ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + ComputeServiceSession::EnqueueResult ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue); + void ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn); void GetQueueCompleted(int QueueId, CbWriter& Cbo); void NotifyQueueActionComplete(int QueueId, int Lsn, 
RunnerAction::State ActionState); void ExpireCompletedQueues(); @@ -292,11 +332,13 @@ struct ComputeServiceSession::Impl void HandleActionUpdates(); void PostUpdate(RunnerAction* Action); + void RemoveActionFromActiveMaps(int ActionLsn); static constexpr int kDefaultMaxRetries = 3; int GetMaxRetriesForQueue(int QueueId); ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + ComputeServiceSession::RescheduleResult RetractAction(int ActionLsn); ActionCounts GetActionCounts() { @@ -449,6 +491,28 @@ void ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) { m_OrchestratorEndpoint = Endpoint; + + // Subscribe to orchestrator WebSocket push so we discover worker changes + // immediately instead of waiting for the next polling cycle. + try + { + std::string WsUrl = HttpToWsUrl(Endpoint, "/orch/ws"); + + m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this); + + HttpWsClientSettings WsSettings; + WsSettings.LogCategory = "orch_disc_ws"; + WsSettings.ConnectTimeout = std::chrono::milliseconds{3000}; + + m_OrchestratorWsClient = std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, WsSettings); + m_OrchestratorWsClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to connect orchestrator WebSocket, falling back to polling: {}", Ex.what()); + m_OrchestratorWsClient.reset(); + m_OrchestratorWsHandler.reset(); + } } void @@ -458,6 +522,13 @@ ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BaseP } void +ComputeServiceSession::Impl::NotifyOrchestratorChanged() +{ + m_OrchestratorQueryForced.store(true, std::memory_order_relaxed); + m_SchedulingThreadEvent.Set(); +} + +void ComputeServiceSession::Impl::UpdateCoordinatorState() { ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); @@ -467,10 +538,14 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() } // Poll faster when we have no discovered workers yet so remote runners come online 
quickly - const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; - if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + const bool Forced = m_OrchestratorQueryForced.exchange(false, std::memory_order_relaxed); + if (!Forced) { - return; + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } } m_OrchestratorQueryTimer.Reset(); @@ -520,7 +595,24 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() continue; } - ZEN_INFO("discovered new worker at {}", UriStr); + std::string_view Hostname = Worker["hostname"sv].AsString(); + std::string_view Platform = Worker["platform"sv].AsString(); + int Cpus = Worker["cpus"sv].AsInt32(); + uint64_t MemTotal = Worker["memory_total"sv].AsUInt64(); + + if (!Hostname.empty()) + { + ZEN_INFO("discovered new worker at {} ({}, {}, {} cpus, {:.1f} GB)", + UriStr, + Hostname, + Platform, + Cpus, + static_cast<double>(MemTotal) / (1024.0 * 1024.0 * 1024.0)); + } + else + { + ZEN_INFO("discovered new worker at {}", UriStr); + } m_KnownWorkerUris.insert(UriStr); @@ -598,6 +690,15 @@ ComputeServiceSession::Impl::Shutdown() { RequestStateTransition(SessionState::Sunset); + // Close orchestrator WebSocket before stopping the scheduler thread + // to prevent callbacks into a shutting-down scheduler. + if (m_OrchestratorWsClient) + { + m_OrchestratorWsClient->Close(); + m_OrchestratorWsClient.reset(); + } + m_OrchestratorWsHandler.reset(); + m_SchedulingThreadEnabled = false; m_SchedulingThreadEvent.Set(); if (m_SchedulingThread.joinable()) @@ -720,8 +821,14 @@ ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) // different descriptor. 
Thus we only need to call this the first time, when the // worker is added - m_LocalRunnerGroup.RegisterWorker(Worker); - m_RemoteRunnerGroup.RegisterWorker(Worker); + if (!m_LocalRunnerGroup.RegisterWorker(Worker)) + { + ZEN_WARN("failed to register worker {} on one or more local runners", WorkerId); + } + if (!m_RemoteRunnerGroup.RegisterWorker(Worker)) + { + ZEN_WARN("failed to register worker {} on one or more remote runners", WorkerId); + } if (m_Recorder) { @@ -767,7 +874,10 @@ ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) for (const CbPackage& Worker : Workers) { - Runner.RegisterWorker(Worker); + if (!Runner.RegisterWorker(Worker)) + { + ZEN_WARN("failed to sync worker {} to runner", Worker.GetObjectHash()); + } } } @@ -868,9 +978,7 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) { - CbObjectWriter Writer; - Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); - return {0, Writer.Save()}; + return MakeErrorResult(fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()))); } const int ActionLsn = ++m_ActionsCounter; @@ -1258,42 +1366,51 @@ ComputeServiceSession::Impl::DrainQueue(int QueueId) } ComputeServiceSession::EnqueueResult -ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +ComputeServiceSession::Impl::ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue) { - Ref<QueueEntry> Queue = FindQueue(QueueId); + OutQueue = FindQueue(QueueId); - if (!Queue) + if (!OutQueue) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue not found"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue not found"sv); } - QueueState QState = Queue->State.load(); + QueueState QState = OutQueue->State.load(); if (QState == QueueState::Cancelled) { - CbObjectWriter 
Writer; - Writer << "error"sv - << "queue is cancelled"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue is cancelled"sv); } if (QState == QueueState::Draining) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is draining"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue is draining"sv); + } + + return {}; +} + +void +ComputeServiceSession::Impl::ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn) +{ + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref<QueueEntry> Queue; + if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage) + { + return Error; } EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority); if (Result.Lsn != 0) { - Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); - Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); - Queue->IdleSince.store(0, std::memory_order_relaxed); + ActivateActionInQueue(Queue, Result.Lsn); } return Result; @@ -1302,40 +1419,17 @@ ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionOb ComputeServiceSession::EnqueueResult ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority) { - Ref<QueueEntry> Queue = FindQueue(QueueId); - - if (!Queue) - { - CbObjectWriter Writer; - Writer << "error"sv - << "queue not found"sv; - return {0, Writer.Save()}; - } - - QueueState QState = Queue->State.load(); - if (QState == QueueState::Cancelled) + Ref<QueueEntry> Queue; + if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is 
cancelled"sv; - return {0, Writer.Save()}; - } - - if (QState == QueueState::Draining) - { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is draining"sv; - return {0, Writer.Save()}; + return Error; } EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority); if (Result.Lsn != 0) { - Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); - Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); - Queue->IdleSince.store(0, std::memory_order_relaxed); + ActivateActionInQueue(Queue, Result.Lsn); } return Result; @@ -1770,6 +1864,68 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) return {.Success = true, .RetryCount = NewRetryCount}; } +ComputeServiceSession::RescheduleResult +ComputeServiceSession::Impl::RetractAction(int ActionLsn) +{ + Ref<RunnerAction> Action; + bool WasRunning = false; + + // Look for the action in pending or running maps + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) + { + Action = It->second; + WasRunning = true; + } + }); + + if (!Action) + { + m_PendingLock.WithSharedLock([&] { + if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) + { + Action = It->second; + } + }); + } + + if (!Action) + { + return {.Success = false, .Error = "Action not found in pending or running maps"}; + } + + if (!Action->RetractAction()) + { + return {.Success = false, .Error = "Action cannot be retracted from its current state"}; + } + + // If the action was running, send a cancellation signal to the runner + if (WasRunning) + { + m_LocalRunnerGroup.CancelAction(ActionLsn); + } + + ZEN_INFO("action {} ({}) retract requested", Action->ActionId, ActionLsn); + return {.Success = true, .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; +} + +void +ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) +{ + m_RunningLock.WithExclusiveLock([&] { + 
m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); +} + void ComputeServiceSession::Impl::HandleActionUpdates() { @@ -1781,6 +1937,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() std::unordered_set<int> SeenLsn; + // Collect terminal action notifications for the completion observer. + // Inline capacity of 64 avoids heap allocation in the common case. + eastl::fixed_vector<IComputeCompletionObserver::CompletedActionNotification, 64> TerminalBatch; + // Process each action's latest state, deduplicating by LSN. // // This is safe because state transitions are monotonically increasing by enum @@ -1798,7 +1958,23 @@ ComputeServiceSession::Impl::HandleActionUpdates() { // Newly enqueued — add to pending map for scheduling case RunnerAction::State::Pending: - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + // Guard against a race where the session is abandoned between + // EnqueueAction (which calls PostUpdate) and this scheduler + // tick. AbandonAllActions() only scans m_PendingActions, so it + // misses actions still in m_UpdatedActions at the time the + // session transitions. Detect that here and immediately abandon + // rather than inserting into the pending map, where they would + // otherwise be stuck indefinitely. + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Abandoned) + { + Action->SetActionState(RunnerAction::State::Abandoned); + // SetActionState calls PostUpdate; the Abandoned action + // will be processed as a terminal on the next scheduler pass. 
+ } + else + { + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + } break; // Async submission in progress — remains in pending map @@ -1816,6 +1992,15 @@ ComputeServiceSession::Impl::HandleActionUpdates() ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; + // Retracted — pull back for rescheduling without counting against retry limit + case RunnerAction::State::Retracted: + { + RemoveActionFromActiveMaps(ActionLsn); + Action->ResetActionStateToPending(); + ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); + break; + } + // Terminal states — move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: @@ -1834,19 +2019,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - // Remove from whichever active map the action is in before resetting - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + RemoveActionFromActiveMaps(ActionLsn); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -1861,19 +2034,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() } } - // Remove from whichever active map the action is in - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + RemoveActionFromActiveMaps(ActionLsn); + + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. 
GetActionResult erases from m_ResultsMap + // under m_ResultsLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; @@ -1902,16 +2070,46 @@ ComputeServiceSession::Impl::HandleActionUpdates() }); m_RetiredCount.fetch_add(1); m_ResultRate.Mark(1); + { + using ObserverState = IComputeCompletionObserver::ActionState; + ObserverState NotifyState{}; + switch (TerminalState) + { + case RunnerAction::State::Completed: + NotifyState = ObserverState::Completed; + break; + case RunnerAction::State::Failed: + NotifyState = ObserverState::Failed; + break; + case RunnerAction::State::Abandoned: + NotifyState = ObserverState::Abandoned; + break; + case RunnerAction::State::Cancelled: + NotifyState = ObserverState::Cancelled; + break; + default: + break; + } + TerminalBatch.push_back({ActionLsn, NotifyState}); + } ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", Action->ActionId, ActionLsn, TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); break; } } } } + + // Notify the completion observer, if any, about all terminal actions in this batch. 
+ if (!TerminalBatch.empty()) + { + if (IComputeCompletionObserver* Observer = m_CompletionObserver.load(std::memory_order_acquire)) + { + Observer->OnActionsCompleted({TerminalBatch.data(), TerminalBatch.size()}); + } + } } size_t @@ -2014,6 +2212,12 @@ ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) } void +ComputeServiceSession::NotifyOrchestratorChanged() +{ + m_Impl->NotifyOrchestratorChanged(); +} + +void ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) { m_Impl->StartRecording(InResolver, RecordingPath); @@ -2194,6 +2398,12 @@ ComputeServiceSession::RescheduleAction(int ActionLsn) return m_Impl->RescheduleAction(ActionLsn); } +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RetractAction(int ActionLsn) +{ + return m_Impl->RetractAction(ActionLsn); +} + std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::GetRunningActions() { @@ -2219,6 +2429,12 @@ ComputeServiceSession::GetCompleted(CbWriter& Cbo) } void +ComputeServiceSession::SetCompletionObserver(IComputeCompletionObserver* Observer) +{ + m_Impl->m_CompletionObserver.store(Observer, std::memory_order_release); +} + +void ComputeServiceSession::PostUpdate(RunnerAction* Action) { m_Impl->PostUpdate(Action); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index e82a40781..bdfd9d197 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -16,6 +16,7 @@ # include <zencore/iobuffer.h> # include <zencore/iohash.h> # include <zencore/logging.h> +# include <zencore/string.h> # include <zencore/system.h> # include <zencore/thread.h> # include <zencore/trace.h> @@ -23,8 +24,10 @@ # include <zenstore/cidstore.h> # include <zentelemetry/stats.h> +# include <algorithm> # include <span> # include <unordered_map> +# include <vector> using namespace std::literals; @@ -50,6 +53,11 @@ struct 
HttpComputeService::Impl ComputeServiceSession m_ComputeService; SystemMetricsTracker m_MetricsTracker; + // WebSocket connections (completion push) + + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -91,6 +99,12 @@ struct HttpComputeService::Impl void HandleWorkersAllGet(HttpServerRequest& HttpReq); void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + + // WebSocket / observer + void OnWebSocketOpen(Ref<WebSocketConnection> Connection); + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code); + void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions); void RegisterRoutes(); @@ -110,6 +124,7 @@ struct HttpComputeService::Impl m_ComputeService.WaitUntilReady(); m_StatsService.RegisterHandler("compute", *m_Self); RegisterRoutes(); + m_ComputeService.SetCompletionObserver(m_Self); } }; @@ -149,7 +164,7 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + bool Success = m_ComputeService.Abandon(); if (Success) { @@ -325,6 +340,29 @@ HttpComputeService::Impl::RegisterRoutes() HttpVerb::kGet | HttpVerb::kPost); m_Router.RegisterRoute( + "jobs/{lsn}/retract", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + auto Result = m_ComputeService.RetractAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "success"sv << true; + Cbo << "lsn"sv << ActionLsn; + 
HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the // one which uses the scheduled action lsn for lookups [this](HttpRouterRequest& Req) { @@ -373,127 +411,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - // TODO: return status of all pending or executing jobs - break; - - case HttpVerb::kPost: - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - break; - - case 
HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - break; - - default: - break; - } - break; - - default: - break; - } + HandleSubmitAction(HttpReq, 0, RequestPriority, &Worker); }, HttpVerb::kPost); @@ -511,118 +429,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - // Resolve worker - - // - - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. 
This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - return; - } + HandleSubmitAction(HttpReq, 0, RequestPriority, nullptr); }, HttpVerb::kPost); @@ -1090,72 +897,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - if (!CheckAttachments(ActionObj, NeedList)) - { - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - IngestStats Stats = IngestPackageAttachments(Action); - - if (ComputeServiceSession::EnqueueResult Result 
= - m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - QueueId, - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(Stats.Bytes), - Stats.Count, - zen::NiceBytes(Stats.NewBytes), - Stats.NewCount); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - default: - break; - } + HandleSubmitAction(HttpReq, QueueId, RequestPriority, &Worker); }, HttpVerb::kPost); @@ -1178,71 +920,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - if (!CheckAttachments(ActionObj, NeedList)) - { - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - IngestStats Stats = IngestPackageAttachments(Action); - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) - { - 
ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - QueueId, - Result.Lsn, - zen::NiceBytes(Stats.Bytes), - Stats.Count, - zen::NiceBytes(Stats.NewBytes), - Stats.NewCount); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - default: - break; - } + HandleSubmitAction(HttpReq, QueueId, RequestPriority, nullptr); }, HttpVerb::kPost); @@ -1306,6 +984,45 @@ HttpComputeService::Impl::RegisterRoutes() } }, HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}/retract", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RetractAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "success"sv << true; + Cbo << "lsn"sv << ActionLsn; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + }, + HttpVerb::kPost); + + // WebSocket upgrade endpoint — the handler logic lives in + // HttpComputeService::OnWebSocket* methods; this route merely + // satisfies the router so the upgrade request isn't rejected. 
+ m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); } ////////////////////////////////////////////////////////////////////////// @@ -1320,12 +1037,17 @@ HttpComputeService::HttpComputeService(CidStore& InCidStore, HttpComputeService::~HttpComputeService() { + m_Impl->m_ComputeService.SetCompletionObserver(nullptr); m_Impl->m_StatsService.UnregisterHandler("compute", *this); } void HttpComputeService::Shutdown() { + // Null out observer before shutting down the compute session to prevent + // callbacks into a partially-torn-down service. + m_Impl->m_ComputeService.SetCompletionObserver(nullptr); + m_Impl->m_WsConnectionsLock.WithExclusiveLock([&] { m_Impl->m_WsConnections.clear(); }); m_Impl->m_ComputeService.Shutdown(); } @@ -1492,6 +1214,184 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto } void +HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker) +{ + // QueueId > 0: queue-scoped enqueue; QueueId == 0: implicit queue (global routes) + auto Enqueue = [&](CbObject ActionObj) -> ComputeServiceSession::EnqueueResult { + if (QueueId > 0) + { + if (Worker) + { + return m_ComputeService.EnqueueResolvedActionToQueue(QueueId, *Worker, ActionObj, Priority); + } + return m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, Priority); + } + else + { + if (Worker) + { + return m_ComputeService.EnqueueResolvedAction(*Worker, ActionObj, Priority); + } + return m_ComputeService.EnqueueAction(ActionObj, Priority); + } + }; + + // Read payload upfront and handle attachments based on content type + CbObject Body; + IngestStats Stats = {}; + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + Body = LoadCompactBinaryObject(Payload); + break; + } + + case HttpContentType::kCbPackage: + { + CbPackage 
Package = HttpReq.ReadPayloadPackage(); + Body = Package.GetObject(); + Stats = IngestPackageAttachments(Package); + break; + } + + default: + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + // Check for "actions" array to determine batch vs single-action path + CbArray Actions = Body.Find("actions"sv).AsArray(); + + if (Actions.Num() > 0) + { + // --- Batch path --- + + // For CbObject payloads, check all attachments upfront before enqueuing anything + if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + { + std::vector<IoHash> NeedList; + + for (CbField ActionField : Actions) + { + CbObject ActionObj = ActionField.AsObject(); + CheckAttachments(ActionObj, NeedList); + } + + if (!NeedList.empty()) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + } + + // Enqueue all actions and collect results + CbObjectWriter Cbo; + int Accepted = 0; + + Cbo.BeginArray("results"); + + for (CbField ActionField : Actions) + { + CbObject ActionObj = ActionField.AsObject(); + + ComputeServiceSession::EnqueueResult Result = Enqueue(ActionObj); + + Cbo.BeginObject(); + + if (Result) + { + Cbo << "lsn"sv << Result.Lsn; + ++Accepted; + } + else + { + Cbo << "error"sv << Result.ResponseMessage; + } + + Cbo.EndObject(); + } + + Cbo.EndArray(); + + if (Stats.Count > 0) + { + ZEN_DEBUG("queue {}: batch accepted {}/{} actions: {} in {} attachments. 
{} new ({} attachments)", + QueueId, + Accepted, + Actions.Num(), + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + } + else + { + ZEN_DEBUG("queue {}: batch accepted {}/{} actions", QueueId, Accepted, Actions.Num()); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // --- Single-action path: Body is the action itself --- + + if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + { + std::vector<IoHash> NeedList; + + if (!CheckAttachments(Body, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + } + + if (ComputeServiceSession::EnqueueResult Result = Enqueue(Body)) + { + if (Stats.Count > 0) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Body.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + } + else + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, Body.GetHash(), Result.Lsn); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } +} + +void HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) { CbObjectWriter Cbo; @@ -1632,6 +1532,136 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const } ////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + m_Impl->OnWebSocketOpen(std::move(Connection)); +} + +void +HttpComputeService::OnWebSocketMessage([[maybe_unused]] WebSocketConnection& Conn, [[maybe_unused]] const WebSocketMessage& 
Msg) +{ + // Clients are receive-only; ignore any inbound messages. +} + +void +HttpComputeService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + m_Impl->OnWebSocketClose(Conn, Code); +} + +void +HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotification> Actions) +{ + m_Impl->OnActionsCompleted(Actions); +} + +////////////////////////////////////////////////////////////////////////// +// +// Impl — WebSocket / observer +// + +void +HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + ZEN_INFO("compute WebSocket client connected"); + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); +} + +void +HttpComputeService::Impl::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code) +{ + ZEN_INFO("compute WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +void +HttpComputeService::Impl::OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions) +{ + using ActionState = IComputeCompletionObserver::ActionState; + using CompletedActionNotification = IComputeCompletionObserver::CompletedActionNotification; + + // Snapshot connections under shared lock + eastl::fixed_vector<Ref<WebSocketConnection>, 16> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = {begin(m_WsConnections), end(m_WsConnections)}; }); + + if (Connections.empty()) + { + return; + } + + // Build CompactBinary notification grouped by state: + // {"Completed": [lsn, ...], "Failed": [lsn, ...], ...} + // Each state name becomes an array key containing the LSNs in that state. 
+ CbObjectWriter Cbo; + + // Sort by state so we can emit one array per state in a single pass. + // Copy into a local vector since the span is const. + eastl::fixed_vector<CompletedActionNotification, 16> Sorted(Actions.begin(), Actions.end()); + std::sort(Sorted.begin(), Sorted.end(), [](const auto& A, const auto& B) { return A.State < B.State; }); + + ActionState CurrentState{}; + bool ArrayOpen = false; + + for (const CompletedActionNotification& Action : Sorted) + { + if (!ArrayOpen || Action.State != CurrentState) + { + if (ArrayOpen) + { + Cbo.EndArray(); + } + CurrentState = Action.State; + Cbo.BeginArray(IComputeCompletionObserver::ActionStateToString(CurrentState)); + ArrayOpen = true; + } + Cbo.AddInteger(Action.Lsn); + } + + if (ArrayOpen) + { + Cbo.EndArray(); + } + + CbObject Msg = Cbo.Save(); + MemoryView MsgView = Msg.GetView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendBinary(MsgView); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } +} + +////////////////////////////////////////////////////////////////////////// void httpcomputeservice_forcelink() diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 65ec5f9ee..1ca78738a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -13,6 +13,7 @@ # include <zenhttp/httpcommon.h> # include <filesystem> +# include <span> namespace zen { class ChunkResolver; @@ -29,6 +30,53 @@ class RemoteHttpRunner; struct RunnerAction; struct SubmitResult; +/** + * 
Observer interface for action completion notifications. + * + * Implementors receive a batch of notifications whenever actions reach a + * terminal state (Completed, Failed, Abandoned, Cancelled). The callback + * fires on the scheduler thread *after* the action result has been placed + * in m_ResultsMap, so GET /jobs/{lsn} will succeed by the time the client + * reacts to the notification. + */ +class IComputeCompletionObserver +{ +public: + virtual ~IComputeCompletionObserver() = default; + + enum class ActionState + { + Completed, + Failed, + Abandoned, + Cancelled, + }; + + struct CompletedActionNotification + { + int Lsn; + ActionState State; + }; + + static constexpr std::string_view ActionStateToString(ActionState State) + { + switch (State) + { + case ActionState::Completed: + return "Completed"; + case ActionState::Failed: + return "Failed"; + case ActionState::Abandoned: + return "Abandoned"; + case ActionState::Cancelled: + return "Cancelled"; + } + return "Unknown"; + } + + virtual void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) = 0; +}; + struct WorkerDesc { CbPackage Descriptor; @@ -91,11 +139,25 @@ public: // Sunset can be reached from any non-Sunset state. bool RequestStateTransition(SessionState NewState); + // Convenience helpers for common state transitions. + bool Ready() { return RequestStateTransition(SessionState::Ready); } + bool Abandon() { return RequestStateTransition(SessionState::Abandoned); } + // Orchestration void SetOrchestratorEndpoint(std::string_view Endpoint); void SetOrchestratorBasePath(std::filesystem::path BasePath); + void SetOrchestrator(std::string_view Endpoint, std::filesystem::path BasePath) + { + SetOrchestratorEndpoint(Endpoint); + SetOrchestratorBasePath(std::move(BasePath)); + } + + /// Immediately wake the scheduler to re-poll the orchestrator for worker changes. + /// Resets the polling throttle so the next scheduler tick calls UpdateCoordinatorState(). 
+ void NotifyOrchestratorChanged(); + // Worker registration and discovery void RegisterWorker(CbPackage Worker); @@ -182,6 +244,7 @@ public: }; [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + [[nodiscard]] RescheduleResult RetractAction(int ActionLsn); void GetCompleted(CbWriter&); @@ -215,7 +278,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[8] = {}; + uint64_t Timestamps[9] = {}; }; [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); @@ -235,6 +298,10 @@ public: void EmitStats(CbObjectWriter& Cbo); + // Completion observer (used by HttpComputeService for WebSocket push) + + void SetCompletionObserver(IComputeCompletionObserver* Observer); + // Recording void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index ee1cd2614..b58e73a0d 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -9,6 +9,7 @@ # include "zencompute/computeservice.h" # include <zenhttp/httpserver.h> +# include <zenhttp/websocket.h> # include <filesystem> # include <memory> @@ -22,7 +23,7 @@ namespace zen::compute { /** * HTTP interface for compute service */ -class HttpComputeService : public HttpService, public IHttpStatsProvider +class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver { public: HttpComputeService(CidStore& InCidStore, @@ -42,6 +43,16 @@ public: void HandleStatsRequest(HttpServerRequest& Request) override; + // IWebSocketHandler + + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& 
Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + + // IComputeCompletionObserver + + void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) override; + private: struct Impl; std::unique_ptr<Impl> m_Impl; diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 768cdf1e1..4f116e7d8 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -215,15 +215,22 @@ BaseRunnerGroup::GetSubmittedActionCount() return TotalCount; } -void +bool BaseRunnerGroup::RegisterWorker(CbPackage Worker) { RwLock::SharedLockScope _(m_RunnersLock); + bool AllSucceeded = true; + for (auto& Runner : m_Runners) { - Runner->RegisterWorker(Worker); + if (!Runner->RegisterWorker(Worker)) + { + AllSucceeded = false; + } } + + return AllSucceeded; } void @@ -276,12 +283,34 @@ RunnerAction::~RunnerAction() } bool +RunnerAction::RetractAction() +{ + State CurrentState = m_ActionState.load(); + + do + { + // Only allow retraction from pre-terminal states (idempotent if already retracted) + if (CurrentState > State::Running) + { + return CurrentState == State::Retracted; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, State::Retracted)) + { + this->Timestamps[static_cast<int>(State::Retracted)] = DateTime::Now().GetTicks(); + m_OwnerSession->PostUpdate(this); + return true; + } + } while (true); +} + +bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed or Abandoned states + // Only allow reset from Failed, Abandoned, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) { return false; } @@ -305,8 +334,11 @@ RunnerAction::ResetActionStateToPending() 
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count - RetryCount.fetch_add(1, std::memory_order_relaxed); + // Increment retry count (skip for Retracted — nothing failed) + if (CurrentState != State::Retracted) + { + RetryCount.fetch_add(1, std::memory_order_relaxed); + } // Re-enter the scheduler pipeline m_OwnerSession->PostUpdate(this); diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index f67414dbb..56c3f3af0 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -29,8 +29,8 @@ public: FunctionRunner(std::filesystem::path BasePath); virtual ~FunctionRunner() = 0; - virtual void Shutdown() = 0; - virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + virtual void Shutdown() = 0; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) = 0; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; @@ -63,7 +63,7 @@ public: SubmitResult SubmitAction(Ref<RunnerAction> Action); std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); size_t GetSubmittedActionCount(); - void RegisterWorker(CbPackage Worker); + [[nodiscard]] bool RegisterWorker(CbPackage Worker); void Shutdown(); bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); @@ -114,6 +114,30 @@ struct RunnerGroup : public BaseRunnerGroup /** * This represents an action going through different stages of scheduling and execution. 
+ * + * State machine + * ============= + * + * Normal forward flow (enforced by SetActionState rejecting backward transitions): + * + * New -> Pending -> Submitting -> Running -> Completed + * -> Failed + * -> Abandoned + * -> Cancelled + * + * Rescheduling (via ResetActionStateToPending): + * + * Failed ---> Pending (increments RetryCount, subject to retry limit) + * Abandoned ---> Pending (increments RetryCount, subject to retry limit) + * Retracted ---> Pending (does NOT increment RetryCount) + * + * Retraction (via RetractAction, idempotent): + * + * Pending/Submitting/Running -> Retracted -> Pending (rescheduled) + * + * Retracted is placed after Cancelled in enum order so that once set, + * no runner-side transition (Completed/Failed) can override it via + * SetActionState's forward-only rule. */ struct RunnerAction : public RefCounted { @@ -137,16 +161,20 @@ struct RunnerAction : public RefCounted enum class State { - New, - Pending, - Submitting, - Running, - Completed, - Failed, - Abandoned, - Cancelled, + New, // Initial state at construction, before entering the scheduler + Pending, // Queued and waiting for a runner slot + Submitting, // Being handed off to a runner (async submission in progress) + Running, // Executing on a runner process + Completed, // Finished successfully with results available + Failed, // Execution failed (transient error, eligible for retry) + Abandoned, // Infrastructure termination (e.g. 
spot eviction, session abandon) + Cancelled, // Intentional user cancellation (never retried) + Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count }; + static_assert(State::Retracted > State::Completed && State::Retracted > State::Failed && State::Retracted > State::Abandoned && + State::Retracted > State::Cancelled, + "Retracted must be the highest terminal ordinal so runner-side transitions cannot override it"); static const char* ToString(State _) { @@ -168,6 +196,8 @@ struct RunnerAction : public RefCounted return "Abandoned"; case State::Cancelled: return "Cancelled"; + case State::Retracted: + return "Retracted"; default: return "Unknown"; } @@ -191,6 +221,7 @@ struct RunnerAction : public RefCounted void SetActionState(State NewState); bool IsSuccess() const { return ActionState() == State::Completed; } + bool RetractAction(); bool ResetActionStateToPending(); bool IsCompleted() const { diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 7aaefb06e..b61e0a46f 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -7,14 +7,16 @@ # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> +# include <zencore/compactbinaryfile.h> # include <zencore/compress.h> # include <zencore/except_fmt.h> # include <zencore/filesystem.h> # include <zencore/fmtutils.h> # include <zencore/iobuffer.h> # include <zencore/iohash.h> -# include <zencore/system.h> # include <zencore/scopeguard.h> +# include <zencore/stream.h> +# include <zencore/system.h> # include <zencore/timer.h> # include <zencore/trace.h> # include <zenstore/cidstore.h> @@ -152,7 +154,7 @@ LocalProcessRunner::CreateNewSandbox() return Path; } -void +bool LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) { ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); @@ -173,6 +175,8 @@ 
LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); } + + return true; } size_t @@ -301,7 +305,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) // Write out action - zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + WriteCompactBinaryObject(SandboxPath / "build.action", ActionObj); // Manifest inputs in sandbox diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index 7493e980b..b8cff6826 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -51,7 +51,7 @@ public: ~LocalProcessRunner(); virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; [[nodiscard]] virtual bool IsHealthy() override { return true; } [[nodiscard]] virtual size_t GetSubmittedActionCount() override; diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index 672636d06..ce6a81173 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -42,6 +42,20 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) { + // Attempt to connect a WebSocket for push-based completion notifications. + // If the remote doesn't support WS, OnWsClose fires and we fall back to polling. 
+ { + std::string WsUrl = HttpToWsUrl(HostName, "/compute/ws"); + + HttpWsClientSettings WsSettings; + WsSettings.LogCategory = "http_exec_ws"; + WsSettings.ConnectTimeout = std::chrono::milliseconds{3000}; + + IWsClientHandler& Handler = *this; + m_WsClient = std::make_unique<HttpWsClient>(WsUrl, Handler, WsSettings); + m_WsClient->Connect(); + } + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } @@ -53,7 +67,29 @@ RemoteHttpRunner::~RemoteHttpRunner() void RemoteHttpRunner::Shutdown() { - // TODO: should cleanly drain/cancel pending work + m_AcceptNewActions = false; + + // Close the WebSocket client first, so no more wakeup signals arrive. + if (m_WsClient) + { + m_WsClient->Close(); + } + + // Cancel all known remote queues so the remote side stops scheduling new + // work and cancels in-flight actions belonging to those queues. + + { + std::vector<std::pair<int, Oid>> Queues; + + m_QueueTokenLock.WithSharedLock([&] { Queues.assign(m_RemoteQueueTokens.begin(), m_RemoteQueueTokens.end()); }); + + for (const auto& [QueueId, Token] : Queues) + { + CancelRemoteQueue(QueueId); + } + } + + // Stop the monitor thread so it no longer polls the remote. m_MonitorThreadEnabled = false; m_MonitorThreadEvent.Set(); @@ -61,9 +97,22 @@ RemoteHttpRunner::Shutdown() { m_MonitorThread.join(); } + + // Drain the running map and mark all remaining actions as Failed so the + // scheduler can reschedule or finalize them. 
+ + std::unordered_map<int, HttpRunningAction> Remaining; + + m_RunningLock.WithExclusiveLock([&] { Remaining.swap(m_RemoteRunningMap); }); + + for (auto& [RemoteLsn, HttpAction] : Remaining) + { + ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->SetActionState(RunnerAction::State::Failed); + } } -void +bool RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) { ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); @@ -125,15 +174,13 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) { ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error + return false; } } else if (!IsHttpSuccessCode(DescResponse.StatusCode)) { ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error + return false; } else { @@ -152,14 +199,20 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) WorkerUrl, (int)WorkerResponse.StatusCode, ToString(WorkerResponse.StatusCode)); - - // TODO: propagate error + return false; } + + return true; } size_t RemoteHttpRunner::QueryCapacity() { + if (!m_AcceptNewActions) + { + return 0; + } + // Estimate how much more work we're ready to accept RwLock::SharedLockScope _{m_RunningLock}; @@ -191,24 +244,68 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) return Results; } - // For larger batches, submit HTTP requests in parallel via the shared worker pool + // Collect distinct QueueIds and ensure remote queues exist once per queue - std::vector<std::future<SubmitResult>> Futures; - Futures.reserve(Actions.size()); + std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero) for (const Ref<RunnerAction>& Action : Actions) { - std::packaged_task<SubmitResult()> 
Task([this, Action]() { return SubmitAction(Action); }); + const int QueueId = Action->QueueId; + if (QueueId != 0 && QueueTokens.find(QueueId) == QueueTokens.end()) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + QueueTokens[QueueId] = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); + } + } - Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + // Group actions by QueueId + + struct QueueGroup + { + std::vector<Ref<RunnerAction>> Actions; + std::vector<size_t> OriginalIndices; + }; + + std::unordered_map<int, QueueGroup> Groups; + + for (size_t i = 0; i < Actions.size(); ++i) + { + auto& Group = Groups[Actions[i]->QueueId]; + Group.Actions.push_back(Actions[i]); + Group.OriginalIndices.push_back(i); } - std::vector<SubmitResult> Results; - Results.reserve(Futures.size()); + // Submit each group as a batch and map results back to original indices - for (auto& Future : Futures) + std::vector<SubmitResult> Results(Actions.size()); + + for (auto& [QueueId, Group] : Groups) { - Results.push_back(Future.get()); + std::string SubmitUrl = "/jobs"; + if (QueueId != 0) + { + if (Oid Token = QueueTokens[QueueId]; Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } + + const size_t BatchLimit = size_t(m_MaxBatchSize); + + for (size_t Offset = 0; Offset < Group.Actions.size(); Offset += BatchLimit) + { + size_t End = zen::Min(Offset + BatchLimit, Group.Actions.size()); + + std::vector<Ref<RunnerAction>> Chunk(Group.Actions.begin() + Offset, Group.Actions.begin() + End); + + std::vector<SubmitResult> ChunkResults = SubmitActionBatch(SubmitUrl, Chunk); + + for (size_t j = 0; j < ChunkResults.size(); ++j) + { + Results[Group.OriginalIndices[Offset + j]] = std::move(ChunkResults[j]); + } + } } return Results; @@ -221,6 +318,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> 
Action) // Verify whether we can accept more work + if (!m_AcceptNewActions) + { + return SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"}; + } + { RwLock::SharedLockScope _{m_RunningLock}; if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) @@ -275,7 +377,7 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) m_Http.GetBaseUri(), ActionId); - RegisterWorker(Action->Worker.Descriptor); + (void)RegisterWorker(Action->Worker.Descriptor); } else { @@ -384,6 +486,194 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) return {}; } +std::vector<SubmitResult> +RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActionBatch"); + + if (!m_AcceptNewActions) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"}); + } + + // Capacity check + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false}); + return Results; + } + } + + // Per-action setup and build batch body + + CbObjectWriter Body; + Body.BeginArray("actions"sv); + + for (const Ref<RunnerAction>& Action : Actions) + { + Action->ExecutionLocation = m_HostName; + MaybeDumpAction(Action->ActionLsn, Action->ActionObj); + Body.AddObject(Action->ActionObj); + } + + Body.EndArray(); + + // POST the batch + + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + + if (Response.StatusCode == HttpResponseCode::NotFound) + { + // Server needs attachments — resolve them and retry with a CbPackage + + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const 
IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return FallbackToIndividualSubmit(Actions); + } + + // Unexpected status or connection error — fall back to individual submission + + if (Response) + { + ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", + m_Http.GetBaseUri(), + SubmitUrl, + (int)Response.StatusCode, + ToString(Response.StatusCode)); + } + else + { + ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + } + + return FallbackToIndividualSubmit(Actions); +} + +std::vector<SubmitResult> +RemoteHttpRunner::ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + CbObject ResponseObj = Response.AsObject(); + CbArrayView ResultArray = ResponseObj["results"sv].AsArrayView(); + + size_t Index = 0; + for (CbFieldView Field : ResultArray) + { + if (Index >= Actions.size()) + { + break; + } + + CbObjectView Entry = Field.AsObjectView(); + const int32_t LsnField = Entry["lsn"sv].AsInt32(0); + + if (LsnField 
> 0) + { + HttpRunningAction NewAction; + NewAction.Action = Actions[Index]; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("batch: scheduled action {} with remote LSN {} (local LSN {})", + Actions[Index]->ActionObj.GetHash(), + LsnField, + Actions[Index]->ActionLsn); + + Actions[Index]->SetActionState(RunnerAction::State::Running); + + Results.push_back(SubmitResult{.IsAccepted = true}); + } + else + { + std::string_view ErrorMsg = Entry["error"sv].AsString(); + Results.push_back(SubmitResult{.IsAccepted = false, .Reason = std::string(ErrorMsg)}); + } + + ++Index; + } + + // If the server returned fewer results than actions, mark the rest as not accepted + while (Results.size() < Actions.size()) + { + Results.push_back(SubmitResult{.IsAccepted = false, .Reason = "no result from server"}); + } + + return Results; +} + +std::vector<SubmitResult> +RemoteHttpRunner::FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); + } + + return Results; +} + Oid RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) { @@ -481,6 +771,35 @@ RemoteHttpRunner::GetSubmittedActionCount() return m_RemoteRunningMap.size(); } +////////////////////////////////////////////////////////////////////////// +// +// IWsClientHandler +// + +void +RemoteHttpRunner::OnWsOpen() +{ + ZEN_INFO("WebSocket connected to {}", m_HostName); + 
m_WsConnected.store(true, std::memory_order_release); +} + +void +RemoteHttpRunner::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg) +{ + // The message content is a wakeup signal; no parsing needed. + // Signal the monitor thread to sweep completed actions immediately. + m_MonitorThreadEvent.Set(); +} + +void +RemoteHttpRunner::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_WARN("WebSocket disconnected from {} (code {})", m_HostName, Code); + m_WsConnected.store(false, std::memory_order_release); +} + +////////////////////////////////////////////////////////////////////////// + void RemoteHttpRunner::MonitorThreadFunction() { @@ -489,28 +808,40 @@ RemoteHttpRunner::MonitorThreadFunction() do { const int NormalWaitingTime = 200; - int WaitTimeMs = NormalWaitingTime; - auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; - auto SweepOnce = [&] { + const int WsWaitingTime = 2000; // Safety-net interval when WS is connected + + int WaitTimeMs = m_WsConnected.load(std::memory_order_relaxed) ? WsWaitingTime : NormalWaitingTime; + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { const size_t RetiredCount = SweepRunningActions(); - m_RunningLock.WithSharedLock([&] { - if (m_RemoteRunningMap.size() > 16) - { - WaitTimeMs = NormalWaitingTime / 4; - } - else - { - if (RetiredCount) + if (m_WsConnected.load(std::memory_order_relaxed)) + { + // WS connected: use long safety-net interval; the WS message + // will wake us immediately for the real work. 
+ WaitTimeMs = WsWaitingTime; + } + else + { + // No WS: adaptive polling as before + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) { - WaitTimeMs = NormalWaitingTime / 2; + WaitTimeMs = NormalWaitingTime / 4; } else { - WaitTimeMs = NormalWaitingTime; + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } } - } - }); + }); + } }; while (!WaitOnce()) @@ -518,7 +849,7 @@ RemoteHttpRunner::MonitorThreadFunction() SweepOnce(); } - // Signal received - this may mean we should quit + // Signal received — may be a WS wakeup or a quit signal SweepOnce(); } while (m_MonitorThreadEnabled); diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index 9119992a9..c17d0cf2a 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -14,9 +14,11 @@ # include <zencore/workthreadpool.h> # include <zencore/zencore.h> # include <zenhttp/httpclient.h> +# include <zenhttp/httpwsclient.h> # include <atomic> # include <filesystem> +# include <memory> # include <thread> # include <unordered_map> @@ -32,7 +34,7 @@ namespace zen::compute { */ -class RemoteHttpRunner : public FunctionRunner +class RemoteHttpRunner : public FunctionRunner, private IWsClientHandler { RemoteHttpRunner(RemoteHttpRunner&&) = delete; RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; @@ -45,7 +47,7 @@ public: ~RemoteHttpRunner(); virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; [[nodiscard]] virtual bool IsHealthy() override; [[nodiscard]] virtual size_t GetSubmittedActionCount() override; @@ -66,7 +68,9 @@ private: std::string m_BaseUrl; HttpClient m_Http; - int32_t m_MaxRunningActions = 256; // arbitrary 
limit for testing + std::atomic<bool> m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; struct HttpRunningAction { @@ -92,7 +96,20 @@ private: // creating remote queues. Generated once at construction and never changes. Oid m_InstanceId; + // WebSocket completion notification client + std::unique_ptr<HttpWsClient> m_WsClient; + std::atomic<bool> m_WsConnected{false}; + + // IWsClientHandler + void OnWsOpen() override; + void OnWsMessage(const WebSocketMessage& Msg) override; + void OnWsClose(uint16_t Code, std::string_view Reason) override; + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); + + std::vector<SubmitResult> SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions); + std::vector<SubmitResult> ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions); + std::vector<SubmitResult> FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions); }; } // namespace zen::compute diff --git a/src/zencore/compactbinaryfile.cpp b/src/zencore/compactbinaryfile.cpp index ec2fc3cd5..9ddafbe15 100644 --- a/src/zencore/compactbinaryfile.cpp +++ b/src/zencore/compactbinaryfile.cpp @@ -30,4 +30,15 @@ LoadCompactBinaryObject(const std::filesystem::path& FilePath) return {.Hash = IoHash::Zero}; } +void +WriteCompactBinaryObject(const std::filesystem::path& Path, const CbObject& Object) +{ + // We cannot use CbObject::GetView() here because it may not return a complete + // view since the type byte can be omitted in arrays. 
+ CbWriter Writer; + Writer.AddObject(Object); + CbFieldIterator Fields = Writer.Save(); + zen::WriteFile(Path, IoBufferBuilder::MakeFromMemory(Fields.GetRangeView())); +} + } // namespace zen diff --git a/src/zencore/include/zencore/compactbinaryfile.h b/src/zencore/include/zencore/compactbinaryfile.h index 33f3e7bea..a06524549 100644 --- a/src/zencore/include/zencore/compactbinaryfile.h +++ b/src/zencore/include/zencore/compactbinaryfile.h @@ -15,5 +15,6 @@ struct CbObjectFromFile }; CbObjectFromFile LoadCompactBinaryObject(const std::filesystem::path& FilePath); +void WriteCompactBinaryObject(const std::filesystem::path& Path, const CbObject& Object); } // namespace zen diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h index 5060a8393..939c70d12 100644 --- a/src/zencore/include/zencore/logging/ansicolorsink.h +++ b/src/zencore/include/zencore/logging/ansicolorsink.h @@ -15,6 +15,9 @@ enum class ColorMode Auto }; +bool IsColorTerminal(); +bool ResolveColorMode(ColorMode Mode); + class AnsiColorStdoutSink : public Sink { public: diff --git a/src/zencore/include/zencore/logging/formatter.h b/src/zencore/include/zencore/logging/formatter.h index 11904d71d..e605b22b8 100644 --- a/src/zencore/include/zencore/logging/formatter.h +++ b/src/zencore/include/zencore/logging/formatter.h @@ -15,6 +15,12 @@ public: virtual ~Formatter() = default; virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) = 0; virtual std::unique_ptr<Formatter> Clone() const = 0; + + void SetColorEnabled(bool Enabled) { m_UseColor = Enabled; } + bool IsColorEnabled() const { return m_UseColor; } + +private: + bool m_UseColor = false; }; } // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h index ce021e1a5..765aa59e3 100644 --- a/src/zencore/include/zencore/logging/helpers.h +++ b/src/zencore/include/zencore/logging/helpers.h @@ -119,4 +119,81 
@@ LevelToShortString(LogLevel Level) return ToStringView(Level); } +inline std::string_view +AnsiColorForLevel(LogLevel Level) +{ + using namespace std::string_view_literals; + switch (Level) + { + case Trace: + return "\033[37m"sv; // white + case Debug: + return "\033[36m"sv; // cyan + case Info: + return "\033[32m"sv; // green + case Warn: + return "\033[33m\033[1m"sv; // bold yellow + case Err: + return "\033[31m\033[1m"sv; // bold red + case Critical: + return "\033[1m\033[41m"sv; // bold on red background + default: + return "\033[m"sv; + } +} + +inline constexpr std::string_view kAnsiReset = "\033[m"; + +inline void +AppendAnsiColor(LogLevel Level, MemoryBuffer& Dest) +{ + std::string_view Color = AnsiColorForLevel(Level); + Dest.append(Color.data(), Color.data() + Color.size()); +} + +inline void +AppendAnsiReset(MemoryBuffer& Dest) +{ + Dest.append(kAnsiReset.data(), kAnsiReset.data() + kAnsiReset.size()); +} + +// Strip ANSI SGR escape sequences (\033[...m) from the buffer in-place. +// Only sequences terminated by 'm' are removed (colors, bold, underline, etc.). +// Other CSI sequences (cursor movement, erase, etc.) are left intact. 
+inline void +StripAnsiSgrSequences(MemoryBuffer& Buf) +{ + const char* Src = Buf.data(); + const char* End = Src + Buf.size(); + char* Dst = Buf.data(); + + while (Src < End) + { + if (Src[0] == '\033' && (Src + 1) < End && Src[1] == '[') + { + const char* Seq = Src + 2; + while (Seq < End && *Seq != 'm') + { + ++Seq; + } + if (Seq < End) + { + ++Seq; // skip 'm' + } + Src = Seq; + } + else + { + if (Dst != Src) + { + *Dst = *Src; + } + ++Dst; + ++Src; + } + } + + Buf.resize(static_cast<size_t>(Dst - Buf.data())); +} + } // namespace zen::logging::helpers diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h index 1d8b6b1b7..a1acb503b 100644 --- a/src/zencore/include/zencore/logging/logmsg.h +++ b/src/zencore/include/zencore/logging/logmsg.h @@ -40,9 +40,6 @@ struct LogMessage void SetTime(LogClock::time_point InTime) { m_Time = InTime; } void SetSource(const SourceLocation& InSource) { m_Source = InSource; } - mutable size_t ColorRangeStart = 0; - mutable size_t ColorRangeEnd = 0; - private: static constexpr LogPoint s_DefaultPoints[LogLevelCount] = { {{}, Trace, {}}, diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h index 8410216c4..01356fa00 100644 --- a/src/zencore/include/zencore/testing.h +++ b/src/zencore/include/zencore/testing.h @@ -43,9 +43,8 @@ public: TestRunner(); ~TestRunner(); - void SetDefaultSuiteFilter(const char* Pattern); - int ApplyCommandLine(int Argc, char const* const* Argv); - int Run(); + int ApplyCommandLine(int Argc, char const* const* Argv, const char* DefaultSuiteFilter = nullptr); + int Run(); private: struct Impl; diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp index 540d22359..03aae068a 100644 --- a/src/zencore/logging/ansicolorsink.cpp +++ b/src/zencore/logging/ansicolorsink.cpp @@ -4,12 +4,14 @@ #include <zencore/logging/helpers.h> #include <zencore/logging/messageonlyformatter.h> +#include 
<zencore/thread.h> + #include <cstdio> #include <cstdlib> -#include <mutex> #if defined(_WIN32) # include <io.h> +# include <zencore/windows.h> # define ZEN_ISATTY _isatty # define ZEN_FILENO _fileno #else @@ -62,188 +64,225 @@ public: Dest.push_back(' '); } - // level (colored range) + // level Dest.push_back('['); - Msg.ColorRangeStart = Dest.size(); + if (IsColorEnabled()) + { + helpers::AppendAnsiColor(Msg.GetLevel(), Dest); + } helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); - Msg.ColorRangeEnd = Dest.size(); + if (IsColorEnabled()) + { + helpers::AppendAnsiReset(Dest); + } Dest.push_back(']'); Dest.push_back(' '); - // message - helpers::AppendStringView(Msg.GetPayload(), Dest); + // message (align continuation lines with the first line) + size_t AnsiBytes = IsColorEnabled() ? (helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0; + size_t LinePrefixCount = Dest.size() - AnsiBytes; + + auto MsgPayload = Msg.GetPayload(); + auto ItLineBegin = MsgPayload.begin(); + auto ItMessageEnd = MsgPayload.end(); + bool IsFirstLine = true; + + auto ItLineEnd = ItLineBegin; + + auto EmitLine = [&] { + if (IsFirstLine) + { + IsFirstLine = false; + } + else + { + for (size_t i = 0; i < LinePrefixCount; ++i) + { + Dest.push_back(' '); + } + } + helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), Dest); + }; + + while (ItLineEnd != ItMessageEnd) + { + if (*ItLineEnd++ == '\n') + { + EmitLine(); + ItLineBegin = ItLineEnd; + } + } + + if (ItLineBegin != ItMessageEnd) + { + EmitLine(); + } Dest.push_back('\n'); } - std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultConsoleFormatter>(); } + std::unique_ptr<Formatter> Clone() const override + { + auto Copy = std::make_unique<DefaultConsoleFormatter>(); + Copy->SetColorEnabled(IsColorEnabled()); + return Copy; + } private: std::chrono::seconds m_LastLogSecs{0}; std::tm m_CachedLocalTm{}; }; -static constexpr 
std::string_view s_Reset = "\033[m"; - -static std::string_view -GetColorForLevel(LogLevel InLevel) +bool +IsColorTerminal() { - using namespace std::string_view_literals; - switch (InLevel) + // If stdout is not a TTY, no color + if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0) { - case Trace: - return "\033[37m"sv; // white - case Debug: - return "\033[36m"sv; // cyan - case Info: - return "\033[32m"sv; // green - case Warn: - return "\033[33m\033[1m"sv; // bold yellow - case Err: - return "\033[31m\033[1m"sv; // bold red - case Critical: - return "\033[1m\033[41m"sv; // bold on red background - default: - return s_Reset; + return false; } -} -struct AnsiColorStdoutSink::Impl -{ - explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) {} + // NO_COLOR convention (https://no-color.org/) + if (std::getenv("NO_COLOR") != nullptr) + { + return false; + } - static bool IsColorTerminal() + // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit") + if (std::getenv("COLORTERM") != nullptr) { - // If stdout is not a TTY, no color - if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0) - { - return false; - } + return true; + } - // NO_COLOR convention (https://no-color.org/) - if (std::getenv("NO_COLOR") != nullptr) + // Check TERM for known color-capable values + const char* Term = std::getenv("TERM"); + if (Term != nullptr) + { + std::string_view TermView(Term); + // "dumb" terminals do not support color + if (TermView == "dumb") { return false; } - - // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit") - if (std::getenv("COLORTERM") != nullptr) + // Match against known color-capable terminal types. + // TERM often includes suffixes like "-256color", so we use substring matching. 
+ constexpr std::string_view ColorTerms[] = { + "alacritty", + "ansi", + "color", + "console", + "cygwin", + "gnome", + "konsole", + "kterm", + "linux", + "msys", + "putty", + "rxvt", + "screen", + "tmux", + "vt100", + "vt102", + "xterm", + }; + for (std::string_view Candidate : ColorTerms) { - return true; - } - - // Check TERM for known color-capable values - const char* Term = std::getenv("TERM"); - if (Term != nullptr) - { - std::string_view TermView(Term); - // "dumb" terminals do not support color - if (TermView == "dumb") + if (TermView.find(Candidate) != std::string_view::npos) { - return false; - } - // Match against known color-capable terminal types. - // TERM often includes suffixes like "-256color", so we use substring matching. - constexpr std::string_view ColorTerms[] = { - "alacritty", - "ansi", - "color", - "console", - "cygwin", - "gnome", - "konsole", - "kterm", - "linux", - "msys", - "putty", - "rxvt", - "screen", - "tmux", - "vt100", - "vt102", - "xterm", - }; - for (std::string_view Candidate : ColorTerms) - { - if (TermView.find(Candidate) != std::string_view::npos) - { - return true; - } + return true; } } + } #if defined(_WIN32) - // Windows console supports ANSI color by default in modern versions - return true; + // Windows console supports ANSI color by default in modern versions + return true; #else - // Unknown terminal — be conservative - return false; + // Unknown terminal — be conservative + return false; #endif - } +} - static bool ResolveColorMode(ColorMode Mode) +bool +ResolveColorMode(ColorMode Mode) +{ + switch (Mode) { - switch (Mode) - { - case ColorMode::On: - return true; - case ColorMode::Off: - return false; - case ColorMode::Auto: - default: - return IsColorTerminal(); - } + case ColorMode::On: + return true; + case ColorMode::Off: + return false; + case ColorMode::Auto: + default: + return IsColorTerminal(); } +} - void Log(const LogMessage& Msg) +struct AnsiColorStdoutSink::Impl +{ + explicit Impl(ColorMode Mode) : 
m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) { - std::lock_guard<std::mutex> Lock(m_Mutex); - - MemoryBuffer Formatted; - m_Formatter->Format(Msg, Formatted); + m_Formatter->SetColorEnabled(m_UseColor); + } - if (m_UseColor && Msg.ColorRangeEnd > Msg.ColorRangeStart) - { - // Print pre-color range - fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File); + void WriteOutput(const MemoryBuffer& Buf) + { + RwLock::ExclusiveLockScope Lock(m_Lock); - // Print color - std::string_view Color = GetColorForLevel(Msg.GetLevel()); - fwrite(Color.data(), 1, Color.size(), m_File); +#if defined(_WIN32) + DWORD Written; + WriteFile(m_Handle, Buf.data(), static_cast<DWORD>(Buf.size()), &Written, nullptr); +#else + fwrite(Buf.data(), 1, Buf.size(), m_File); +#endif - // Print colored range - fwrite(Formatted.data() + Msg.ColorRangeStart, 1, Msg.ColorRangeEnd - Msg.ColorRangeStart, m_File); + m_Dirty.store(false, std::memory_order_relaxed); + } - // Reset color - fwrite(s_Reset.data(), 1, s_Reset.size(), m_File); + void Log(const LogMessage& Msg) + { + MemoryBuffer Formatted; + m_Formatter->Format(Msg, Formatted); - // Print remainder - fwrite(Formatted.data() + Msg.ColorRangeEnd, 1, Formatted.size() - Msg.ColorRangeEnd, m_File); - } - else + if (!m_UseColor) { - fwrite(Formatted.data(), 1, Formatted.size(), m_File); + helpers::StripAnsiSgrSequences(Formatted); } - fflush(m_File); + WriteOutput(Formatted); } void Flush() { - std::lock_guard<std::mutex> Lock(m_Mutex); + if (!m_Dirty.load(std::memory_order_relaxed)) + { + return; + } + RwLock::ExclusiveLockScope Lock(m_Lock); + m_Dirty.store(false, std::memory_order_relaxed); +#if defined(_WIN32) + FlushFileBuffers(m_Handle); +#else fflush(m_File); +#endif } void SetFormatter(std::unique_ptr<Formatter> InFormatter) { - std::lock_guard<std::mutex> Lock(m_Mutex); + RwLock::ExclusiveLockScope Lock(m_Lock); + InFormatter->SetColorEnabled(m_UseColor); m_Formatter = std::move(InFormatter); } 
private: - std::mutex m_Mutex; + RwLock m_Lock; std::unique_ptr<Formatter> m_Formatter; - FILE* m_File = stdout; - bool m_UseColor = true; +#if defined(_WIN32) + HANDLE m_Handle = GetStdHandle(STD_OUTPUT_HANDLE); +#else + FILE* m_File = stdout; +#endif + bool m_UseColor = true; + std::atomic<bool> m_Dirty = false; }; AnsiColorStdoutSink::AnsiColorStdoutSink(ColorMode Mode) : m_Impl(std::make_unique<Impl>(Mode)) diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index f5bc723b1..c6ee5ee6b 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -181,6 +181,15 @@ struct TestListener : public doctest::IReporter void test_case_start(const doctest::TestCaseData& in) override { Current = &in; + + if (in.m_test_suite && in.m_test_suite != CurrentSuite) + { + CurrentSuite = in.m_test_suite; + ZEN_CONSOLE("{}==============================================================================={}", ColorYellow, ColorNone); + ZEN_CONSOLE("{} TEST_SUITE: {}{}", ColorYellow, CurrentSuite, ColorNone); + ZEN_CONSOLE("{}==============================================================================={}", ColorYellow, ColorNone); + } + ZEN_CONSOLE("{}======== TEST_CASE: {:<50} ========{}", ColorYellow, Current->m_name, ColorNone); } @@ -217,8 +226,9 @@ struct TestListener : public doctest::IReporter void test_case_skipped(const doctest::TestCaseData& /*in*/) override {} - const doctest::TestCaseData* Current = nullptr; - std::chrono::steady_clock::time_point RunStart = {}; + const doctest::TestCaseData* Current = nullptr; + std::string_view CurrentSuite = {}; + std::chrono::steady_clock::time_point RunStart = {}; struct FailedTestInfo { @@ -244,15 +254,29 @@ TestRunner::~TestRunner() { } -void -TestRunner::SetDefaultSuiteFilter(const char* Pattern) -{ - m_Impl->Session.setOption("test-suite", Pattern); -} - int -TestRunner::ApplyCommandLine(int Argc, char const* const* Argv) +TestRunner::ApplyCommandLine(int Argc, char const* const* Argv, const char* 
DefaultSuiteFilter) { + // Apply the default suite filter only when the command line doesn't provide + // an explicit --test-suite / --ts override. + if (DefaultSuiteFilter) + { + bool HasExplicitSuiteFilter = false; + for (int i = 1; i < Argc; ++i) + { + std::string_view Arg = Argv[i]; + if (Arg.starts_with("--test-suite=") || Arg.starts_with("--ts=") || Arg.starts_with("-test-suite=") || Arg.starts_with("-ts=")) + { + HasExplicitSuiteFilter = true; + break; + } + } + if (!HasExplicitSuiteFilter) + { + m_Impl->Session.setOption("test-suite", DefaultSuiteFilter); + } + } + m_Impl->Session.applyCommandLine(Argc, Argv); for (int i = 1; i < Argc; ++i) @@ -316,6 +340,7 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink TestRunner Runner; // Derive default suite filter from ExecutableName: "zencore-test" -> "core.*" + std::string DefaultSuiteFilter; if (ExecutableName) { std::string_view Name = ExecutableName; @@ -329,13 +354,12 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink } if (!Name.empty()) { - std::string Filter(Name); - Filter += ".*"; - Runner.SetDefaultSuiteFilter(Filter.c_str()); + DefaultSuiteFilter = Name; + DefaultSuiteFilter += ".*"; } } - Runner.ApplyCommandLine(Argc, Argv); + Runner.ApplyCommandLine(Argc, Argv, DefaultSuiteFilter.empty() ? 
nullptr : DefaultSuiteFilter.c_str()); return Runner.Run(); } diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index fbae9f5fe..770213738 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -638,4 +638,34 @@ HttpWsClient::IsOpen() const return m_Impl->m_IsOpen.load(std::memory_order_relaxed); } +std::string +HttpToWsUrl(std::string_view Endpoint, std::string_view Path) +{ + std::string_view Scheme = "ws://"; + std::string_view Host = Endpoint; + + if (Endpoint.starts_with("http://")) + { + Host = Endpoint.substr(7); + } + else if (Endpoint.starts_with("https://")) + { + Scheme = "wss://"; + Host = Endpoint.substr(8); + } + + // Strip trailing slash from host to avoid double-slash when Path starts with '/' + if (!Host.empty() && Host.back() == '/') + { + Host = Host.substr(0, Host.size() - 1); + } + + std::string Url; + Url.reserve(Scheme.size() + Host.size() + Path.size()); + Url.append(Scheme); + Url.append(Host); + Url.append(Path); + return Url; +} + } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 2ca9b7ab1..9c3b909a2 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -80,4 +80,14 @@ private: std::unique_ptr<Impl> m_Impl; }; +/// Convert an HTTP(S) endpoint to a WebSocket URL by replacing the scheme +/// and appending the given path. If the endpoint has no recognised scheme, +/// it is treated as a plain host:port and gets the ws:// prefix. 
+/// +/// Examples: +/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws" +/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo" +/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar" +std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path); + } // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index bc3293282..710579faa 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -43,6 +43,8 @@ public: virtual void SendBinary(std::span<const uint8_t> Data) = 0; virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; virtual bool IsOpen() const = 0; + + void SendBinary(MemoryView Data) { SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize())); } }; /** diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp index c90ac5d8b..021052a3b 100644 --- a/src/zenserver-test/compute-tests.cpp +++ b/src/zenserver-test/compute-tests.cpp @@ -19,6 +19,7 @@ # include <zencore/timer.h> # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> +# include <zenhttp/websocket.h> # include <zencompute/computeservice.h> # include <zenstore/zenstore.h> # include <zenutil/zenserverprocess.h> @@ -291,7 +292,9 @@ GetRot13Output(const CbPackage& ResultPackage) } // Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response. -class MockOrchestratorService : public HttpService +// Also implements IWebSocketHandler so the compute session's WS subscription receives +// push notifications when the worker list changes. 
+class MockOrchestratorService : public HttpService, public IWebSocketHandler { public: MockOrchestratorService() @@ -318,13 +321,48 @@ public: void SetWorkerList(CbObject WorkerList) { - RwLock::ExclusiveLockScope Lock(m_Lock); - m_WorkerList = std::move(WorkerList); + { + RwLock::ExclusiveLockScope Lock(m_Lock); + m_WorkerList = std::move(WorkerList); + } + + // Broadcast a poke to all connected WebSocket clients so they + // immediately re-query the orchestrator instead of waiting for the poll. + std::vector<Ref<WebSocketConnection>> Snapshot; + m_WsLock.WithSharedLock([&] { Snapshot = m_WsConnections; }); + for (auto& Conn : Snapshot) + { + if (Conn->IsOpen()) + { + Conn->SendText("updated"sv); + } + } + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + { + m_WsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + } + + void OnWebSocketMessage(WebSocketConnection&, const WebSocketMessage&) override {} + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t, std::string_view) override + { + m_WsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); } private: RwLock m_Lock; CbObject m_WorkerList; + + RwLock m_WsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; }; // Manages in-process ASIO HTTP server lifecycle for mock orchestrator. 
@@ -1089,9 +1127,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session (stored locally, no runners yet) CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); @@ -1100,8 +1137,9 @@ TEST_CASE("function.remote.worker_sync_on_discovery") // Update mock orchestrator to advertise the real server MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}})); - // Wait for scheduler to discover the runner (~5s throttle + margin) - Sleep(7'000); + // Trigger immediate orchestrator re-query and wait for runner setup + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Submit Rot13 action via session CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); @@ -1153,15 +1191,15 @@ TEST_CASE("function.remote.late_runner_discovery") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for W1 discovery - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Baseline: submit Rot13 action and verify it completes on W1 { @@ -1202,7 +1240,8 @@ 
TEST_CASE("function.remote.late_runner_discovery") MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}})); // Wait for W2 discovery - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Verify W2 received the worker by querying its /compute/workers endpoint directly { @@ -1274,16 +1313,16 @@ TEST_CASE("function.remote.queue_association") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a local queue and submit action to it auto QueueResult = Session.CreateQueue(); @@ -1353,16 +1392,16 @@ TEST_CASE("function.remote.queue_cancel_propagation") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a local queue and submit a long-running Sleep 
action auto QueueResult = Session.CreateQueue(); @@ -1496,7 +1535,7 @@ TEST_CASE("function.session.abandon_pending") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.Ready(); CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); @@ -1515,19 +1554,29 @@ TEST_CASE("function.session.abandon_pending") REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3"); // Transition to Abandoned — should mark all pending actions as Abandoned - bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned); CHECK(!Session.IsHealthy()); - // Give the scheduler thread time to process the state changes - Sleep(2'000); - - // All three actions should now be in the results map as abandoned + // Poll until the scheduler thread has processed all abandoned actions into + // the results map. The abandon is asynchronous: AbandonAllActions() sets + // action state and posts updates, but HandleActionUpdates() on the + // scheduler thread must run before results are queryable. 
+ Stopwatch Timer; for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn}) { CbPackage Result; - HttpResponseCode Code = Session.GetActionResult(Lsn, Result); + HttpResponseCode Code = HttpResponseCode::Accepted; + while (Timer.GetElapsedTimeMs() < 10'000) + { + Code = Session.GetActionResult(Lsn, Result); + if (Code == HttpResponseCode::OK) + { + break; + } + Sleep(100); + } CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); } @@ -1561,15 +1610,15 @@ TEST_CASE("function.session.abandon_running") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a queue and submit a long-running Sleep action auto QueueResult = Session.CreateQueue(); @@ -1585,7 +1634,7 @@ TEST_CASE("function.session.abandon_running") Sleep(2'000); // Transition to Abandoned — should abandon the running action - bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); CHECK(!Session.IsHealthy()); @@ -1631,16 +1680,16 @@ TEST_CASE("function.remote.abandon_propagation") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - 
Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a local queue and submit a long-running Sleep action auto QueueResult = Session.CreateQueue(); @@ -1656,7 +1705,7 @@ TEST_CASE("function.remote.abandon_propagation") Sleep(2'000); // Transition to Abandoned — should abandon the running action and propagate - bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); // Poll for the action to complete @@ -1693,6 +1742,278 @@ TEST_CASE("function.remote.abandon_propagation") Session.Shutdown(); } +TEST_CASE("function.remote.shutdown_cancels_queues") +{ + // Verify that Session.Shutdown() cancels remote queues on the compute node. 
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + Session.NotifyOrchestratorChanged(); + Sleep(2'000); + + // Create a queue and submit a long-running action so the remote queue is established + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Verify the remote has a non-implicit queue before shutdown + HttpClient RemoteClient(Instance.GetBaseUri() + "/compute"); + { + HttpClient::Response QueuesResp = RemoteClient.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server before shutdown"); + + bool RemoteQueueFound = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueFound = true; + break; + } + } + REQUIRE_MESSAGE(RemoteQueueFound, "Expected remote queue to exist before shutdown"); + } + + // Shut down the session — this should cancel all remote queues + Session.Shutdown(); + + // Verify the 
remote queue is now cancelled + { + HttpClient::Response QueuesResp = RemoteClient.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server after shutdown"); + + bool RemoteQueueCancelled = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled"; + break; + } + } + CHECK_MESSAGE(RemoteQueueCancelled, "Expected remote queue to be cancelled after session shutdown"); + } +} + +TEST_CASE("function.remote.shutdown_rejects_new_work") +{ + // Verify that after Shutdown() the remote runner rejects new submissions. + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for runner discovery + Session.NotifyOrchestratorChanged(); + Sleep(2'000); + + // Baseline: submit an action and verify it completes + { + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + auto EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if 
(ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance.GetLogOutput())); + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + } + + // Shut down — the remote runner should now reject new work + Session.Shutdown(); + + // Attempting to enqueue after shutdown should fail (session is in Sunset state) + CbObject ActionObj = BuildRot13ActionForSession("rejected"sv, Resolver); + auto Rejected = Session.EnqueueAction(ActionObj, 0); + CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected after shutdown"); +} + +TEST_CASE("function.session.retract_pending") +{ + // Create a session with no runners so actions stay pending + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + + CbObject ActionObj = BuildRot13ActionForSession("retract-test"sv, Resolver); + + auto Enqueued = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action"); + + // Let the scheduler process the pending action + Sleep(500); + + // Retract the pending action + auto Result = Session.RetractAction(Enqueued.Lsn); + CHECK_MESSAGE(Result.Success, fmt::format("RetractAction failed: {}", Result.Error)); + CHECK_EQ(Result.RetryCount, 0); // Retract should NOT increment retry count + + // The action should be re-enqueued as pending (still no runners, so stays pending). + // Let the scheduler process the retracted action back to pending. 
+ Sleep(500); + + // Queue should still show 1 active (the action was rescheduled, not completed) + auto Status = Session.GetQueueStatus(QueueResult.QueueId); + CHECK_EQ(Status.ActiveCount, 1); + CHECK_EQ(Status.CompletedCount, 0); + CHECK_EQ(Status.AbandonedCount, 0); + CHECK_EQ(Status.CancelledCount, 0); + + // The action result should NOT be in the results map (it's pending again) + CbPackage ResultPackage; + HttpResponseCode Code = Session.GetActionResult(Enqueued.Lsn, ResultPackage); + CHECK(Code != HttpResponseCode::OK); + + Session.Shutdown(); +} + +TEST_CASE("function.session.retract_not_terminal") +{ + // Verify that a completed action cannot be retracted + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.AddLocalRunner(Resolver, SessionBaseDir.Path()); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + CbObject ActionObj = BuildRot13ActionForSession("retract-completed"sv, Resolver); + + auto Enqueued = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action"); + + // Wait for the action to complete + CbPackage ResultPackage; + HttpResponseCode Code = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + Code = Session.GetActionResult(Enqueued.Lsn, ResultPackage); + if (Code == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(Code == HttpResponseCode::OK, "Action did not complete within timeout"); + + // Retract should fail — action already completed (no longer in pending/running maps) + auto RetractResult = Session.RetractAction(Enqueued.Lsn); + CHECK(!RetractResult.Success); + + Session.Shutdown(); +} + +TEST_CASE("function.retract_http") +{ + // Use max-actions=1 with a long sleep to hold the slot, then submit a second 
+ // action that will stay pending and can be retracted via the HTTP endpoint. + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1"); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Submit a long-running Sleep action to occupy the single execution slot + const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); + HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode))); + + // Submit a second action — it will stay pending because the slot is occupied + HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode))); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); + + // Wait for the scheduler to process the pending action into m_PendingActions + Sleep(1'000); + + // Retract the pending action via POST /jobs/{lsn}/retract + const std::string RetractUrl = fmt::format("/jobs/{}/retract", Lsn); + HttpClient::Response RetractResp = Client.Post(RetractUrl); + CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK, + fmt::format("Retract failed: status={}, body={}\nServer log:\n{}", + int(RetractResp.StatusCode), + RetractResp.ToText(), + Instance.GetLogOutput())); + + if (RetractResp.StatusCode == HttpResponseCode::OK) + { + CbObject RespObj = RetractResp.AsObject(); + 
CHECK(RespObj["success"sv].AsBool()); + CHECK_EQ(RespObj["lsn"sv].AsInt32(), Lsn); + } + + // A second retract should also succeed (action is back to pending) + Sleep(500); + HttpClient::Response RetractResp2 = Client.Post(RetractUrl); + CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK, + fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText())); +} + TEST_SUITE_END(); } // namespace zen::tests::compute diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index e89812c1f..42632682b 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -15,6 +15,7 @@ # include <zencore/testutils.h> # include <zencore/thread.h> # include <zencore/timer.h> +# include <zencore/trace.h> # include <zenhttp/httpclient.h> # include <zenhttp/packageformat.h> # include <zenutil/config/commandlineoptions.h> @@ -134,8 +135,7 @@ main(int argc, char** argv) ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir); zen::testing::TestRunner Runner; - Runner.SetDefaultSuiteFilter("server.*"); - Runner.ApplyCommandLine(argc, argv); + Runner.ApplyCommandLine(argc, argv, "server.*"); return Runner.Run(); } diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 724ef9ad2..d1875f41a 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -721,21 +721,7 @@ ZenComputeServer::InitializeOrchestratorWebSocket() return; } - // Convert http://host:port → ws://host:port/orch/ws - std::string WsUrl = m_CoordinatorEndpoint; - if (WsUrl.starts_with("http://")) - { - WsUrl = "ws://" + WsUrl.substr(7); - } - else if (WsUrl.starts_with("https://")) - { - WsUrl = "wss://" + WsUrl.substr(8); - } - if (!WsUrl.empty() && WsUrl.back() != '/') - { - WsUrl += '/'; - } - WsUrl += "orch/ws"; + std::string WsUrl = HttpToWsUrl(m_CoordinatorEndpoint, "/orch/ws"); ZEN_INFO("establishing 
WebSocket link to orchestrator at {}", WsUrl); diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html index 66c20175f..c07bbb692 100644 --- a/src/zenserver/frontend/html/compute/compute.html +++ b/src/zenserver/frontend/html/compute/compute.html @@ -6,6 +6,7 @@ <title>Zen Compute Dashboard</title> <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script> <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../util/sanitize.js"></script> <script src="../theme.js"></script> <script src="../banner.js" defer></script> <script src="../nav.js" defer></script> @@ -456,11 +457,6 @@ }); // Helper functions - function escapeHtml(text) { - var div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; - } function formatBytes(bytes) { if (bytes === 0) return '0 B'; diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html index 32e1b05db..620349a2b 100644 --- a/src/zenserver/frontend/html/compute/hub.html +++ b/src/zenserver/frontend/html/compute/hub.html @@ -4,6 +4,7 @@ <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../util/sanitize.js"></script> <script src="../theme.js"></script> <script src="../banner.js" defer></script> <script src="../nav.js" defer></script> @@ -62,12 +63,6 @@ const BASE_URL = window.location.origin; const REFRESH_INTERVAL = 2000; - function escapeHtml(text) { - var div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; - } - function showError(message) { document.getElementById('error-container').innerHTML = '<div class="error">Error: ' + escapeHtml(message) + '</div>'; diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html index 
a519dee18..d1a2bb015 100644 --- a/src/zenserver/frontend/html/compute/orchestrator.html +++ b/src/zenserver/frontend/html/compute/orchestrator.html @@ -4,6 +4,7 @@ <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../util/sanitize.js"></script> <script src="../theme.js"></script> <script src="../banner.js" defer></script> <script src="../nav.js" defer></script> @@ -128,12 +129,6 @@ const BASE_URL = window.location.origin; const REFRESH_INTERVAL = 2000; - function escapeHtml(text) { - var div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; - } - function showError(message) { document.getElementById('error-container').innerHTML = '<div class="error">Error: ' + escapeHtml(message) + '</div>'; diff --git a/src/zenserver/frontend/html/util/sanitize.js b/src/zenserver/frontend/html/util/sanitize.js new file mode 100644 index 000000000..1b0f32e38 --- /dev/null +++ b/src/zenserver/frontend/html/util/sanitize.js @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +// Shared utility functions for compute dashboard pages. + +function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; +} diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 26ae85ae1..9d786c209 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -319,6 +319,11 @@ main(int argc, char* argv[]) { ServerMode = kTest; } + else if (argv[1][0] != '-') + { + fprintf(stderr, "unknown mode '%s'. 
Available modes: hub, store, compute, proxy, test\n", argv[1]); + return 1; + } } switch (ServerMode) diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index fe279ebb2..b619c5548 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -19,7 +19,7 @@ target("zenserver") add_headerfiles("**.h") add_rules("utils.bin2c", {extensions = {".zip"}}) add_files("**.cpp") - add_files("frontend/html.zip") + add_files("$(buildir)/frontend/html.zip") add_files("zenserver.cpp", {unity_ignored = true }) if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then @@ -84,7 +84,8 @@ target("zenserver") on_load(function(target) local html_dir = path.join(os.projectdir(), "src/zenserver/frontend/html") - local zip_path = path.join(os.projectdir(), "src/zenserver/frontend/html.zip") + local zip_dir = path.join(os.projectdir(), get_config("buildir") or "build", "frontend") + local zip_path = path.join(zip_dir, "html.zip") -- Check if zip needs regeneration local need_update = not os.isfile(zip_path) @@ -100,18 +101,19 @@ target("zenserver") if need_update then print("Regenerating frontend zip...") + os.mkdir(zip_dir) os.tryrm(zip_path) import("detect.tools.find_7z") local cmd_7z = find_7z() if cmd_7z then - os.execv(cmd_7z, {"a", "-mx0", zip_path, path.join(html_dir, ".")}) + os.execv(cmd_7z, {"a", "-mx0", "-bso0", zip_path, path.join(html_dir, ".")}) else import("detect.tools.find_zip") local zip_cmd = find_zip() if zip_cmd then local oldir = os.cd(html_dir) - os.execv(zip_cmd, {"-r", "-0", zip_path, "."}) + os.execv(zip_cmd, {"-r", "-0", "-q", zip_path, "."}) os.cd(oldir) else raise("Unable to find a suitable zip tool (need 7z or zip)") diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp index 4410d463d..124132aed 100644 --- a/src/zenutil/consoletui.cpp +++ b/src/zenutil/consoletui.cpp @@ -480,4 +480,69 @@ TuiPollQuit() #endif } +void +TuiSetScrollRegion(uint32_t Top, uint32_t Bottom) +{ + printf("\033[%u;%ur", Top, 
Bottom); +} + +void +TuiResetScrollRegion() +{ + printf("\033[r"); +} + +void +TuiMoveCursor(uint32_t Row, uint32_t Col) +{ + printf("\033[%u;%uH", Row, Col); +} + +void +TuiSaveCursor() +{ + printf( + "\033" + "7"); +} + +void +TuiRestoreCursor() +{ + printf( + "\033" + "8"); +} + +void +TuiEraseLine() +{ + printf("\033[2K"); +} + +void +TuiWrite(std::string_view Text) +{ + fwrite(Text.data(), 1, Text.size(), stdout); +} + +void +TuiFlush() +{ + fflush(stdout); +} + +void +TuiShowCursor(bool Show) +{ + if (Show) + { + printf("\033[?25h"); + } + else + { + printf("\033[?25l"); + } +} + } // namespace zen diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 5f74fa82b..22737589b 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ b/src/zenutil/include/zenutil/consoletui.h @@ -57,4 +57,32 @@ uint32_t TuiConsoleRows(uint32_t Default = 40); // Should only be called while in alternate screen mode. bool TuiPollQuit(); +// Set the scrollable region of the terminal to rows [Top, Bottom] (1-based, inclusive). +// Emits DECSTBM: ESC[<top>;<bottom>r +void TuiSetScrollRegion(uint32_t Top, uint32_t Bottom); + +// Reset the scroll region to the full terminal. Emits ESC[r +void TuiResetScrollRegion(); + +// Move cursor to the given 1-based row and column. Emits ESC[<row>;<col>H +void TuiMoveCursor(uint32_t Row, uint32_t Col); + +// Save cursor position. Emits ESC 7 +void TuiSaveCursor(); + +// Restore cursor position. Emits ESC 8 +void TuiRestoreCursor(); + +// Erase the entire current line. Emits ESC[2K +void TuiEraseLine(); + +// Write raw text to stdout without any formatting or newline. +void TuiWrite(std::string_view Text); + +// Flush stdout. +void TuiFlush(); + +// Show or hide the cursor. 
Emits ESC[?25h or ESC[?25l +void TuiShowCursor(bool Show); + } // namespace zen diff --git a/src/zenutil/include/zenutil/logging/fullformatter.h b/src/zenutil/include/zenutil/logging/fullformatter.h index 33cb94dae..0d026ed72 100644 --- a/src/zenutil/include/zenutil/logging/fullformatter.h +++ b/src/zenutil/include/zenutil/logging/fullformatter.h @@ -3,10 +3,8 @@ #pragma once #include <zencore/logging/formatter.h> -#include <zencore/logging/helpers.h> -#include <zencore/memory/llm.h> -#include <zencore/zencore.h> +#include <memory> #include <string_view> namespace zen::logging { @@ -14,195 +12,16 @@ namespace zen::logging { class FullFormatter final : public Formatter { public: - FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) - : m_Epoch(Epoch) - , m_LogId(LogId) - , m_LinePrefix(128, ' ') - , m_UseFullDate(false) - { - } + FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch); + explicit FullFormatter(std::string_view LogId); + ~FullFormatter() override; - FullFormatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} - - virtual std::unique_ptr<Formatter> Clone() const override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - if (m_UseFullDate) - { - return std::make_unique<FullFormatter>(m_LogId); - } - return std::make_unique<FullFormatter>(m_LogId, m_Epoch); - } - - virtual void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - // Note that the sink is responsible for ensuring there is only ever a - // single caller in here - - using namespace std::literals; - - std::chrono::seconds TimestampSeconds; - - std::chrono::milliseconds Millis; - - if (m_UseFullDate) - { - TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); - if (TimestampSeconds != m_LastLogSecs) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - m_LastLogSecs = 
TimestampSeconds; - - m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); - m_CachedDatetime.clear(); - m_CachedDatetime.push_back('['); - helpers::Pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - helpers::Pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - helpers::Pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime); - m_CachedDatetime.push_back(' '); - helpers::Pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(m_CachedLocalTm.tm_min, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime); - m_CachedDatetime.push_back('.'); - } - - Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); - } - else - { - auto ElapsedTime = Msg.GetTime() - m_Epoch; - TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime); - - if (m_CacheTimestamp.load() != TimestampSeconds) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - - m_CacheTimestamp = TimestampSeconds; - - int Count = int(TimestampSeconds.count()); - const int LogSecs = Count % 60; - Count /= 60; - const int LogMins = Count % 60; - Count /= 60; - const int LogHours = Count; - - m_CachedDatetime.clear(); - m_CachedDatetime.push_back('['); - helpers::Pad2(LogHours, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(LogMins, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(LogSecs, m_CachedDatetime); - m_CachedDatetime.push_back('.'); - } - - Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds); - } - - { - RwLock::SharedLockScope _(m_TimestampLock); - OutBuffer.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - } - - helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - - if (!m_LogId.empty()) - { - 
OutBuffer.push_back('['); - helpers::AppendStringView(m_LogId, OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - - // append logger name if exists - if (Msg.GetLoggerName().size() > 0) - { - OutBuffer.push_back('['); - helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - - OutBuffer.push_back('['); - // wrap the level name with color - Msg.ColorRangeStart = OutBuffer.size(); - helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer); - Msg.ColorRangeEnd = OutBuffer.size(); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - - // add source location if present - if (Msg.GetSource()) - { - OutBuffer.push_back('['); - const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename); - helpers::AppendStringView(Filename, OutBuffer); - OutBuffer.push_back(':'); - helpers::AppendInt(Msg.GetSource().Line, OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - - // Handle newlines in single log call by prefixing each additional line to make - // subsequent lines align with the first line in the message - - const size_t LinePrefixCount = Min<size_t>(OutBuffer.size(), m_LinePrefix.size()); - - auto MsgPayload = Msg.GetPayload(); - auto ItLineBegin = MsgPayload.begin(); - auto ItMessageEnd = MsgPayload.end(); - bool IsFirstline = true; - - { - auto ItLineEnd = ItLineBegin; - - auto EmitLine = [&] { - if (IsFirstline) - { - IsFirstline = false; - } - else - { - helpers::AppendStringView(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer); - } - helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); - }; - - while (ItLineEnd != ItMessageEnd) - { - if (*ItLineEnd++ == '\n') - { - EmitLine(); - ItLineBegin = ItLineEnd; - } - } - - if (ItLineBegin != ItMessageEnd) - { - EmitLine(); - helpers::AppendStringView("\n"sv, OutBuffer); - } - } - } + std::unique_ptr<Formatter> Clone() 
const override; + void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override; private: - std::chrono::time_point<std::chrono::system_clock> m_Epoch; - std::tm m_CachedLocalTm; - std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; - std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)}; - MemoryBuffer m_CachedDatetime; - std::string m_LogId; - std::string m_LinePrefix; - bool m_UseFullDate = true; - RwLock m_TimestampLock; + struct Impl; + std::unique_ptr<Impl> m_Impl; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/jsonformatter.h b/src/zenutil/include/zenutil/logging/jsonformatter.h index 216b1b5e5..fb3193f3e 100644 --- a/src/zenutil/include/zenutil/logging/jsonformatter.h +++ b/src/zenutil/include/zenutil/logging/jsonformatter.h @@ -3,158 +3,24 @@ #pragma once #include <zencore/logging/formatter.h> -#include <zencore/logging/helpers.h> -#include <zencore/memory/llm.h> -#include <zencore/zencore.h> +#include <memory> #include <string_view> -#include <unordered_map> namespace zen::logging { -using namespace std::literals; - class JsonFormatter final : public Formatter { public: - JsonFormatter(std::string_view LogId) : m_LogId(LogId) {} - - virtual std::unique_ptr<Formatter> Clone() const override { return std::make_unique<JsonFormatter>(m_LogId); } - - virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - using std::chrono::duration_cast; - using std::chrono::milliseconds; - using std::chrono::seconds; - - auto Secs = std::chrono::duration_cast<seconds>(Msg.GetTime().time_since_epoch()); - if (Secs != m_LastLogSecs) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); - m_LastLogSecs = Secs; - - // cache the date/time part for the next second. 
- m_CachedDatetime.clear(); - - helpers::AppendInt(m_CachedTm.tm_year + 1900, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - helpers::Pad2(m_CachedTm.tm_mon + 1, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - helpers::Pad2(m_CachedTm.tm_mday, m_CachedDatetime); - m_CachedDatetime.push_back(' '); - - helpers::Pad2(m_CachedTm.tm_hour, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - helpers::Pad2(m_CachedTm.tm_min, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - helpers::Pad2(m_CachedTm.tm_sec, m_CachedDatetime); + explicit JsonFormatter(std::string_view LogId); + ~JsonFormatter() override; - m_CachedDatetime.push_back('.'); - } - helpers::AppendStringView("{"sv, Dest); - helpers::AppendStringView("\"time\": \""sv, Dest); - { - RwLock::SharedLockScope _(m_TimestampLock); - Dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - } - auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime()); - helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); - helpers::AppendStringView("\", "sv, Dest); - - helpers::AppendStringView("\"status\": \""sv, Dest); - helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); - helpers::AppendStringView("\", "sv, Dest); - - helpers::AppendStringView("\"source\": \""sv, Dest); - helpers::AppendStringView("zenserver"sv, Dest); - helpers::AppendStringView("\", "sv, Dest); - - helpers::AppendStringView("\"service\": \""sv, Dest); - helpers::AppendStringView("zencache"sv, Dest); - helpers::AppendStringView("\", "sv, Dest); - - if (!m_LogId.empty()) - { - helpers::AppendStringView("\"id\": \""sv, Dest); - helpers::AppendStringView(m_LogId, Dest); - helpers::AppendStringView("\", "sv, Dest); - } - - if (Msg.GetLoggerName().size() > 0) - { - helpers::AppendStringView("\"logger.name\": \""sv, Dest); - helpers::AppendStringView(Msg.GetLoggerName(), Dest); - helpers::AppendStringView("\", "sv, Dest); - } - - if (Msg.GetThreadId() != 0) - { - 
helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest); - helpers::PadUint(Msg.GetThreadId(), 0, Dest); - helpers::AppendStringView("\", "sv, Dest); - } - - if (Msg.GetSource()) - { - helpers::AppendStringView("\"file\": \""sv, Dest); - WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename)); - helpers::AppendStringView("\","sv, Dest); - - helpers::AppendStringView("\"line\": \""sv, Dest); - helpers::AppendInt(Msg.GetSource().Line, Dest); - helpers::AppendStringView("\","sv, Dest); - } - - helpers::AppendStringView("\"message\": \""sv, Dest); - WriteEscapedString(Dest, Msg.GetPayload()); - helpers::AppendStringView("\""sv, Dest); - - helpers::AppendStringView("}\n"sv, Dest); - } + std::unique_ptr<Formatter> Clone() const override; + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override; private: - static inline const std::unordered_map<char, std::string_view> s_SpecialCharacterMap{{'\b', "\\b"sv}, - {'\f', "\\f"sv}, - {'\n', "\\n"sv}, - {'\r', "\\r"sv}, - {'\t', "\\t"sv}, - {'"', "\\\""sv}, - {'\\', "\\\\"sv}}; - - static void WriteEscapedString(MemoryBuffer& Dest, const std::string_view& Text) - { - const char* RangeStart = Text.data(); - const char* End = Text.data() + Text.size(); - for (const char* It = RangeStart; It != End; ++It) - { - if (auto SpecialIt = s_SpecialCharacterMap.find(*It); SpecialIt != s_SpecialCharacterMap.end()) - { - if (RangeStart != It) - { - Dest.append(RangeStart, It); - } - helpers::AppendStringView(SpecialIt->second, Dest); - RangeStart = It + 1; - } - } - if (RangeStart != End) - { - Dest.append(RangeStart, End); - } - }; - - std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; - std::chrono::seconds m_LastLogSecs{0}; - MemoryBuffer m_CachedDatetime; - std::string m_LogId; - RwLock m_TimestampLock; + struct Impl; + std::unique_ptr<Impl> m_Impl; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/rotatingfilesink.h b/src/zenutil/include/zenutil/logging/rotatingfilesink.h 
index cebc5b110..e0ff7eca1 100644 --- a/src/zenutil/include/zenutil/logging/rotatingfilesink.h +++ b/src/zenutil/include/zenutil/logging/rotatingfilesink.h @@ -2,14 +2,11 @@ #pragma once -#include <zencore/basicfile.h> -#include <zencore/logging/formatter.h> -#include <zencore/logging/messageonlyformatter.h> #include <zencore/logging/sink.h> -#include <zencore/memory/llm.h> -#include <atomic> +#include <cstddef> #include <filesystem> +#include <memory> namespace zen::logging { @@ -19,230 +16,21 @@ namespace zen::logging { class RotatingFileSink : public Sink { public: - RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false) - : m_BaseFilename(BaseFilename) - , m_MaxSize(MaxSize) - , m_MaxFiles(MaxFiles) - , m_Formatter(std::make_unique<MessageOnlyFormatter>()) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - std::error_code Ec; - if (RotateOnOpen) - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - Rotate(RotateLock, Ec); - } - else - { - m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, Ec); - if (!Ec) - { - m_CurrentSize = m_CurrentFile.FileSize(Ec); - } - if (!Ec) - { - if (m_CurrentSize > m_MaxSize) - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - Rotate(RotateLock, Ec); - } - } - } - - if (Ec) - { - throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", m_BaseFilename.string())); - } - } - - virtual ~RotatingFileSink() - { - try - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - m_CurrentFile.Close(); - } - catch (const std::exception&) - { - } - } + RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false); + ~RotatingFileSink() override; RotatingFileSink(const RotatingFileSink&) = delete; RotatingFileSink(RotatingFileSink&&) = delete; - RotatingFileSink& operator=(const RotatingFileSink&) = delete; RotatingFileSink& operator=(RotatingFileSink&&) = delete; - virtual void Log(const 
LogMessage& Msg) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - MemoryBuffer Formatted; - if (TrySinkIt(Msg, Formatted)) - { - return; - } - - // This intentionally has no limit on the number of retries, see - // comment above. - for (;;) - { - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - // Only rotate if no-one else has rotated before us - if (m_CurrentSize > m_MaxSize || !m_CurrentFile.IsOpen()) - { - std::error_code Ec; - Rotate(RotateLock, Ec); - if (Ec) - { - return; - } - } - } - if (TrySinkIt(Formatted)) - { - return; - } - } - } - catch (const std::exception&) - { - // Silently eat errors - } - } - virtual void Flush() override - { - if (!m_NeedFlush) - { - return; - } - - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - RwLock::SharedLockScope Lock(m_Lock); - if (m_CurrentFile.IsOpen()) - { - m_CurrentFile.Flush(); - } - } - catch (const std::exception&) - { - // Silently eat errors - } - - m_NeedFlush = false; - } - - virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - RwLock::ExclusiveLockScope _(m_Lock); - m_Formatter = std::move(InFormatter); - } - catch (const std::exception&) - { - // Silently eat errors - } - } + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr<Formatter> InFormatter) override; private: - void Rotate(RwLock::ExclusiveLockScope&, std::error_code& OutEc) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - m_CurrentFile.Close(); - - OutEc = RotateFiles(m_BaseFilename, m_MaxFiles); - if (OutEc) - { - return; - } - - m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, OutEc); - if (OutEc) - { - return; - } - - m_CurrentSize = m_CurrentFile.FileSize(OutEc); - if (OutEc) - { - // FileSize failed but we have an open file — reset to 0 - // so we can at least attempt writes from the start - m_CurrentSize = 0; - OutEc.clear(); - } - } - - bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& 
OutFormatted) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - RwLock::SharedLockScope Lock(m_Lock); - if (!m_CurrentFile.IsOpen()) - { - return false; - } - m_Formatter->Format(Msg, OutFormatted); - size_t AddSize = OutFormatted.size(); - size_t WritePos = m_CurrentSize.fetch_add(AddSize); - if (WritePos + AddSize > m_MaxSize) - { - return false; - } - std::error_code Ec; - m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec); - if (Ec) - { - return false; - } - m_NeedFlush = true; - return true; - } - - bool TrySinkIt(const MemoryBuffer& Formatted) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - RwLock::SharedLockScope Lock(m_Lock); - if (!m_CurrentFile.IsOpen()) - { - return false; - } - size_t AddSize = Formatted.size(); - size_t WritePos = m_CurrentSize.fetch_add(AddSize); - if (WritePos + AddSize > m_MaxSize) - { - return false; - } - - std::error_code Ec; - m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec); - if (Ec) - { - return false; - } - m_NeedFlush = true; - return true; - } - - RwLock m_Lock; - const std::filesystem::path m_BaseFilename; - const std::size_t m_MaxSize; - const std::size_t m_MaxFiles; - std::unique_ptr<Formatter> m_Formatter; - std::atomic_size_t m_CurrentSize; - BasicFile m_CurrentFile; - std::atomic<bool> m_NeedFlush = false; + struct Impl; + std::unique_ptr<Impl> m_Impl; }; } // namespace zen::logging diff --git a/src/zenutil/logging/fullformatter.cpp b/src/zenutil/logging/fullformatter.cpp new file mode 100644 index 000000000..2a4840241 --- /dev/null +++ b/src/zenutil/logging/fullformatter.cpp @@ -0,0 +1,235 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/logging/fullformatter.h> + +#include <zencore/intmath.h> +#include <zencore/logging/helpers.h> +#include <zencore/logging/memorybuffer.h> +#include <zencore/memory/llm.h> +#include <zencore/thread.h> +#include <zencore/zencore.h> + +#include <atomic> +#include <chrono> +#include <string> + +namespace zen::logging { + +struct FullFormatter::Impl +{ + Impl(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) + : m_Epoch(Epoch) + , m_LogId(LogId) + , m_LinePrefix(128, ' ') + , m_UseFullDate(false) + { + } + + explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} + + std::chrono::time_point<std::chrono::system_clock> m_Epoch; + std::tm m_CachedLocalTm{}; + std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; + std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)}; + MemoryBuffer m_CachedDatetime; + std::string m_LogId; + std::string m_LinePrefix; + bool m_UseFullDate = true; + RwLock m_TimestampLock; +}; + +FullFormatter::FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) +: m_Impl(std::make_unique<Impl>(LogId, Epoch)) +{ +} + +FullFormatter::FullFormatter(std::string_view LogId) : m_Impl(std::make_unique<Impl>(LogId)) +{ +} + +FullFormatter::~FullFormatter() = default; + +std::unique_ptr<Formatter> +FullFormatter::Clone() const +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + std::unique_ptr<FullFormatter> Copy; + if (m_Impl->m_UseFullDate) + { + Copy = std::make_unique<FullFormatter>(m_Impl->m_LogId); + } + else + { + Copy = std::make_unique<FullFormatter>(m_Impl->m_LogId, m_Impl->m_Epoch); + } + Copy->SetColorEnabled(IsColorEnabled()); + return Copy; +} + +void +FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + // Note that the sink is responsible for ensuring there is only ever a + // single caller in here + + using namespace 
std::literals; + + std::chrono::seconds TimestampSeconds; + + std::chrono::milliseconds Millis; + + if (m_Impl->m_UseFullDate) + { + TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); + if (TimestampSeconds != m_Impl->m_LastLogSecs) + { + RwLock::ExclusiveLockScope _(m_Impl->m_TimestampLock); + m_Impl->m_LastLogSecs = TimestampSeconds; + + m_Impl->m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + m_Impl->m_CachedDatetime.clear(); + m_Impl->m_CachedDatetime.push_back('['); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_year % 100, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_mon + 1, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_mday, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(' '); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_hour, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_min, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_sec, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('.'); + } + + Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); + } + else + { + auto ElapsedTime = Msg.GetTime() - m_Impl->m_Epoch; + TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime); + + if (m_Impl->m_CacheTimestamp.load() != TimestampSeconds) + { + RwLock::ExclusiveLockScope _(m_Impl->m_TimestampLock); + + m_Impl->m_CacheTimestamp = TimestampSeconds; + + int Count = int(TimestampSeconds.count()); + const int LogSecs = Count % 60; + Count /= 60; + const int LogMins = Count % 60; + Count /= 60; + const int LogHours = Count; + + m_Impl->m_CachedDatetime.clear(); + m_Impl->m_CachedDatetime.push_back('['); + helpers::Pad2(LogHours, m_Impl->m_CachedDatetime); + 
m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(LogMins, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(LogSecs, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('.'); + } + + Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds); + } + + { + RwLock::SharedLockScope _(m_Impl->m_TimestampLock); + OutBuffer.append(m_Impl->m_CachedDatetime.begin(), m_Impl->m_CachedDatetime.end()); + } + + helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + + if (!m_Impl->m_LogId.empty()) + { + OutBuffer.push_back('['); + helpers::AppendStringView(m_Impl->m_LogId, OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + } + + // append logger name if exists + if (Msg.GetLoggerName().size() > 0) + { + OutBuffer.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + } + + OutBuffer.push_back('['); + if (IsColorEnabled()) + { + helpers::AppendAnsiColor(Msg.GetLevel(), OutBuffer); + } + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer); + if (IsColorEnabled()) + { + helpers::AppendAnsiReset(OutBuffer); + } + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + + // add source location if present + if (Msg.GetSource()) + { + OutBuffer.push_back('['); + const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename); + helpers::AppendStringView(Filename, OutBuffer); + OutBuffer.push_back(':'); + helpers::AppendInt(Msg.GetSource().Line, OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + } + + // Handle newlines in single log call by prefixing each additional line to make + // subsequent lines align with the first line in the message + + size_t AnsiBytes = IsColorEnabled() ? 
(helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0; + const size_t LinePrefixCount = Min<size_t>(OutBuffer.size() - AnsiBytes, m_Impl->m_LinePrefix.size()); + + auto MsgPayload = Msg.GetPayload(); + auto ItLineBegin = MsgPayload.begin(); + auto ItMessageEnd = MsgPayload.end(); + bool IsFirstline = true; + + { + auto ItLineEnd = ItLineBegin; + + auto EmitLine = [&] { + if (IsFirstline) + { + IsFirstline = false; + } + else + { + helpers::AppendStringView(std::string_view(m_Impl->m_LinePrefix.data(), LinePrefixCount), OutBuffer); + } + helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); + }; + + while (ItLineEnd != ItMessageEnd) + { + if (*ItLineEnd++ == '\n') + { + EmitLine(); + ItLineBegin = ItLineEnd; + } + } + + if (ItLineBegin != ItMessageEnd) + { + EmitLine(); + helpers::AppendStringView("\n"sv, OutBuffer); + } + } +} + +} // namespace zen::logging diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp new file mode 100644 index 000000000..673a03c94 --- /dev/null +++ b/src/zenutil/logging/jsonformatter.cpp @@ -0,0 +1,198 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/logging/jsonformatter.h> + +#include <zencore/logging/helpers.h> +#include <zencore/memory/llm.h> +#include <zencore/thread.h> +#include <zencore/zencore.h> + +#include <chrono> +#include <string> +#include <unordered_map> + +namespace zen::logging { + +using namespace std::literals; + +static void +WriteEscapedString(MemoryBuffer& Dest, std::string_view Text) +{ + // Strip ANSI SGR sequences before escaping so they don't appear in JSON output + static const auto IsEscapeStart = [](char C) { return C == '\033'; }; + + const char* RangeStart = Text.data(); + const char* End = Text.data() + Text.size(); + + static const std::unordered_map<char, std::string_view> SpecialCharacterMap{ + {'\b', "\\b"sv}, + {'\f', "\\f"sv}, + {'\n', "\\n"sv}, + {'\r', "\\r"sv}, + {'\t', "\\t"sv}, + {'"', "\\\""sv}, + {'\\', "\\\\"sv}, + }; + + for (const char* It = RangeStart; It != End; ++It) + { + // Skip ANSI SGR escape sequences (\033[...m) + if (*It == '\033' && (It + 1) < End && *(It + 1) == '[') + { + if (RangeStart != It) + { + Dest.append(RangeStart, It); + } + const char* Seq = It + 2; + while (Seq < End && *Seq != 'm') + { + ++Seq; + } + if (Seq < End) + { + ++Seq; // skip 'm' + } + It = Seq - 1; // -1 because the for loop will ++It + RangeStart = Seq; + continue; + } + + if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end()) + { + if (RangeStart != It) + { + Dest.append(RangeStart, It); + } + helpers::AppendStringView(SpecialIt->second, Dest); + RangeStart = It + 1; + } + } + if (RangeStart != End) + { + Dest.append(RangeStart, End); + } +} + +struct JsonFormatter::Impl +{ + explicit Impl(std::string_view LogId) : m_LogId(LogId) {} + + std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::chrono::seconds m_LastLogSecs{0}; + MemoryBuffer m_CachedDatetime; + std::string m_LogId; + RwLock m_TimestampLock; +}; + +JsonFormatter::JsonFormatter(std::string_view LogId) : m_Impl(std::make_unique<Impl>(LogId)) +{ +} + 
+JsonFormatter::~JsonFormatter() = default; + +std::unique_ptr<Formatter> +JsonFormatter::Clone() const +{ + return std::make_unique<JsonFormatter>(m_Impl->m_LogId); +} + +void +JsonFormatter::Format(const LogMessage& Msg, MemoryBuffer& Dest) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + auto Secs = duration_cast<seconds>(Msg.GetTime().time_since_epoch()); + if (Secs != m_Impl->m_LastLogSecs) + { + RwLock::ExclusiveLockScope _(m_Impl->m_TimestampLock); + m_Impl->m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + m_Impl->m_LastLogSecs = Secs; + + // cache the date/time part for the next second. + m_Impl->m_CachedDatetime.clear(); + + helpers::AppendInt(m_Impl->m_CachedTm.tm_year + 1900, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_mon + 1, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_mday, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(' '); + + helpers::Pad2(m_Impl->m_CachedTm.tm_hour, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_min, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_sec, m_Impl->m_CachedDatetime); + + m_Impl->m_CachedDatetime.push_back('.'); + } + helpers::AppendStringView("{"sv, Dest); + helpers::AppendStringView("\"time\": \""sv, Dest); + { + RwLock::SharedLockScope _(m_Impl->m_TimestampLock); + Dest.append(m_Impl->m_CachedDatetime.begin(), m_Impl->m_CachedDatetime.end()); + } + auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime()); + helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); + helpers::AppendStringView("\", "sv, Dest); + + helpers::AppendStringView("\"status\": \""sv, Dest); + 
helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + helpers::AppendStringView("\", "sv, Dest); + + helpers::AppendStringView("\"source\": \""sv, Dest); + helpers::AppendStringView("zenserver"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); + + helpers::AppendStringView("\"service\": \""sv, Dest); + helpers::AppendStringView("zencache"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); + + if (!m_Impl->m_LogId.empty()) + { + helpers::AppendStringView("\"id\": \""sv, Dest); + helpers::AppendStringView(m_Impl->m_LogId, Dest); + helpers::AppendStringView("\", "sv, Dest); + } + + if (Msg.GetLoggerName().size() > 0) + { + helpers::AppendStringView("\"logger.name\": \""sv, Dest); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + helpers::AppendStringView("\", "sv, Dest); + } + + if (Msg.GetThreadId() != 0) + { + helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest); + helpers::PadUint(Msg.GetThreadId(), 0, Dest); + helpers::AppendStringView("\", "sv, Dest); + } + + if (Msg.GetSource()) + { + helpers::AppendStringView("\"file\": \""sv, Dest); + WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename)); + helpers::AppendStringView("\","sv, Dest); + + helpers::AppendStringView("\"line\": \""sv, Dest); + helpers::AppendInt(Msg.GetSource().Line, Dest); + helpers::AppendStringView("\","sv, Dest); + } + + helpers::AppendStringView("\"message\": \""sv, Dest); + WriteEscapedString(Dest, Msg.GetPayload()); + helpers::AppendStringView("\""sv, Dest); + + helpers::AppendStringView("}\n"sv, Dest); +} + +} // namespace zen::logging diff --git a/src/zenutil/logging.cpp b/src/zenutil/logging/logging.cpp index 1258ca155..ea2448a42 100644 --- a/src/zenutil/logging.cpp +++ b/src/zenutil/logging/logging.cpp @@ -238,8 +238,10 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); - static constinit logging::LogPoint LogStartPoint{{}, logging::Info, 
"log starting at {}"}; - logging::Registry::Instance().ApplyAll([&](auto Logger) { Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); + logging::Registry::Instance().ApplyAll([&](auto Logger) { + static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"}; + Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); + }); } g_IsLoggingInitialized = true; diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp new file mode 100644 index 000000000..23cf60d16 --- /dev/null +++ b/src/zenutil/logging/rotatingfilesink.cpp @@ -0,0 +1,249 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/logging/rotatingfilesink.h> + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/logging/helpers.h> +#include <zencore/logging/messageonlyformatter.h> +#include <zencore/memory/llm.h> +#include <zencore/thread.h> + +#include <atomic> + +namespace zen::logging { + +struct RotatingFileSink::Impl +{ + Impl(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen) + : m_BaseFilename(BaseFilename) + , m_MaxSize(MaxSize) + , m_MaxFiles(MaxFiles) + , m_Formatter(std::make_unique<MessageOnlyFormatter>()) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + std::error_code Ec; + if (RotateOnOpen) + { + RwLock::ExclusiveLockScope RotateLock(m_Lock); + Rotate(RotateLock, Ec); + } + else + { + m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, Ec); + if (!Ec) + { + m_CurrentSize = m_CurrentFile.FileSize(Ec); + } + if (!Ec) + { + if (m_CurrentSize > m_MaxSize) + { + RwLock::ExclusiveLockScope RotateLock(m_Lock); + Rotate(RotateLock, Ec); + } + } + } + + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", m_BaseFilename.string())); + } + } + + ~Impl() + { + try + { + RwLock::ExclusiveLockScope RotateLock(m_Lock); + m_CurrentFile.Close(); + } + catch (const std::exception&) + 
{ + } + } + + void Rotate(RwLock::ExclusiveLockScope&, std::error_code& OutEc) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + m_CurrentFile.Close(); + + OutEc = RotateFiles(m_BaseFilename, m_MaxFiles); + if (OutEc) + { + return; + } + + m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, OutEc); + if (OutEc) + { + return; + } + + m_CurrentSize = m_CurrentFile.FileSize(OutEc); + if (OutEc) + { + // FileSize failed but we have an open file — reset to 0 + // so we can at least attempt writes from the start + m_CurrentSize = 0; + OutEc.clear(); + } + } + + bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + RwLock::SharedLockScope Lock(m_Lock); + if (!m_CurrentFile.IsOpen()) + { + return false; + } + m_Formatter->Format(Msg, OutFormatted); + helpers::StripAnsiSgrSequences(OutFormatted); + size_t AddSize = OutFormatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) + { + return false; + } + std::error_code Ec; + m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec); + if (Ec) + { + return false; + } + m_NeedFlush = true; + return true; + } + + bool TrySinkIt(const MemoryBuffer& Formatted) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + RwLock::SharedLockScope Lock(m_Lock); + if (!m_CurrentFile.IsOpen()) + { + return false; + } + size_t AddSize = Formatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) + { + return false; + } + + std::error_code Ec; + m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec); + if (Ec) + { + return false; + } + m_NeedFlush = true; + return true; + } + + RwLock m_Lock; + const std::filesystem::path m_BaseFilename; + const std::size_t m_MaxSize; + const std::size_t m_MaxFiles; + std::unique_ptr<Formatter> m_Formatter; + std::atomic_size_t m_CurrentSize; + BasicFile m_CurrentFile; + std::atomic<bool> m_NeedFlush = false; +}; + 
+RotatingFileSink::RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen) +: m_Impl(std::make_unique<Impl>(BaseFilename, MaxSize, MaxFiles, RotateOnOpen)) +{ +} + +RotatingFileSink::~RotatingFileSink() = default; + +void +RotatingFileSink::Log(const LogMessage& Msg) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + try + { + MemoryBuffer Formatted; + if (m_Impl->TrySinkIt(Msg, Formatted)) + { + return; + } + + // This intentionally has no limit on the number of retries, see + // comment above. + for (;;) + { + { + RwLock::ExclusiveLockScope RotateLock(m_Impl->m_Lock); + // Only rotate if no-one else has rotated before us + if (m_Impl->m_CurrentSize > m_Impl->m_MaxSize || !m_Impl->m_CurrentFile.IsOpen()) + { + std::error_code Ec; + m_Impl->Rotate(RotateLock, Ec); + if (Ec) + { + return; + } + } + } + if (m_Impl->TrySinkIt(Formatted)) + { + return; + } + } + } + catch (const std::exception&) + { + // Silently eat errors + } +} + +void +RotatingFileSink::Flush() +{ + if (!m_Impl->m_NeedFlush) + { + return; + } + + ZEN_MEMSCOPE(ELLMTag::Logging); + + try + { + RwLock::SharedLockScope Lock(m_Impl->m_Lock); + if (m_Impl->m_CurrentFile.IsOpen()) + { + m_Impl->m_CurrentFile.Flush(); + } + } + catch (const std::exception&) + { + // Silently eat errors + } + + m_Impl->m_NeedFlush = false; +} + +void +RotatingFileSink::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + try + { + RwLock::ExclusiveLockScope _(m_Impl->m_Lock); + m_Impl->m_Formatter = std::move(InFormatter); + } + catch (const std::exception&) + { + // Silently eat errors + } +} + +} // namespace zen::logging |