diff options
| author | Stefan Boberg <[email protected]> | 2026-03-04 14:13:46 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-04 14:13:46 +0100 |
| commit | 0763d09a81e5a1d3df11763a7ec75e7860c9510a (patch) | |
| tree | 074575ba6ea259044a179eab0bb396d37268fb09 | |
| parent | native xmake toolchain definition for UE-clang (#805) (diff) | |
| download | zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.tar.xz zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.zip | |
compute orchestration (#763)
- Added local process runners for Linux/Wine, Mac with some sandboxing support
- Horde & Nomad provisioning for development and testing
- Client session queues with lifecycle management (active/draining/cancelled), automatic retry with configurable limits, and manual reschedule API
- Improved web UI for orchestrator, compute, and hub dashboards with WebSocket push updates
- Some security hardening
- Improved scalability and improvements to the `zen exec` command
Still experimental - compute support is disabled by default
121 files changed, 24372 insertions, 2901 deletions
diff --git a/docs/compute.md b/docs/compute.md index 417622f94..df8a22870 100644 --- a/docs/compute.md +++ b/docs/compute.md @@ -122,31 +122,82 @@ functions: version: '83027356-2cf7-41ca-aba5-c81ab0ff2129' ``` -## API (WIP not final) +## API -The compute interfaces are currently exposed on the `/apply` endpoint but this -will be subject to change as we adapt the interfaces during development. The LSN +The compute interfaces are exposed on the `/compute` endpoint. The LSN APIs below are intended to replace the action ID oriented APIs. The POST APIs typically involve a two-step dance where a descriptor is POSTed and -the service responds with a list of `needs` chunks (identified via `IoHash`) which -it does not have yet. The client can then follow up with a POST of a Compact Binary +the service responds with a list of `needs` chunks (identified via `IoHash`) which +it does not have yet. The client can then follow up with a POST of a Compact Binary Package containing the descriptor along with the needed chunks. 
-`/apply/ready` - health check endpoint returns HTTP 200 OK or HTTP 503 +`/compute/ready` - health check endpoint returns HTTP 200 OK or HTTP 503 -`/apply/sysinfo` - system information endpoint +`/compute/sysinfo` - system information endpoint -`/apply/record/start`, `/apply/record/stop` - start/stop action recording +`/compute/record/start`, `/compute/record/stop` - start/stop action recording -`/apply/workers/{worker}` - GET/POST worker descriptors and payloads +`/compute/workers/{worker}` - GET/POST worker descriptors and payloads -`/apply/jobs/completed` - GET list of completed actions +`/compute/jobs/completed` - GET list of completed actions -`/apply/jobs/{lsn}` - GET completed action results from LSN, POST action cancellation by LSN, priority changes by LSN +`/compute/jobs/{lsn}` - GET completed action results from LSN, POST action cancellation by LSN, priority changes by LSN -`/apply/jobs/{worker}/{action}` - GET completed action (job) results by action ID +`/compute/jobs/{worker}/{action}` - GET completed action (job) results by action ID -`/apply/jobs/{worker}` - GET pending/running jobs for worker, POST requests to schedule action as a job +`/compute/jobs/{worker}` - GET pending/running jobs for worker, POST requests to schedule action as a job -`/apply/jobs` - POST request to schedule action as a job +`/compute/jobs` - POST request to schedule action as a job + +### Queues + +Queues provide a way to logically group actions submitted by a client session. This enables +per-session cancellation and completion polling without affecting actions submitted by other +sessions. + +#### Local access (integer ID routes) + +These routes use sequential integer queue IDs and are restricted to local (loopback) +connections only. Remote requests receive HTTP 403 Forbidden. + +`/compute/queues` - POST to create a new queue. Returns a `queue_id` which is used to +reference the queue in subsequent requests. 
+ +`/compute/queues/{queue}` - GET queue status (active, completed, failed, and cancelled +action counts, plus `is_complete` flag indicating all actions have finished). DELETE to +cancel all pending and running actions in the queue. + +`/compute/queues/{queue}/completed` - GET list of completed action LSNs for this queue +whose results have not yet been retired. A queue-scoped alternative to `/compute/jobs/completed`. + +`/compute/queues/{queue}/jobs` - POST to submit an action to a queue with automatic worker +resolution. Accepts an optional `priority` query parameter. + +`/compute/queues/{queue}/jobs/{worker}` - POST to submit an action to a queue targeting a +specific worker. Accepts an optional `priority` query parameter. + +`/compute/queues/{queue}/jobs/{lsn}` - GET action result by LSN, scoped to the queue + +#### Remote access (OID token routes) + +These routes use cryptographically generated 24-character hex tokens (OIDs) instead of +integer queue IDs. Tokens are unguessable and safe to use over the network. The token +mapping lives entirely in the HTTP service layer; the underlying compute service only +knows about integer queue IDs. + +`/compute/queues/remote` - POST to create a new queue with token-based access. Returns +`queue_token` (24-char hex string) and `queue_id` (integer, for internal visibility). + +`/compute/queues/{oidtoken}` - GET queue status or DELETE to cancel, same semantics as +the integer ID variant but using the OID token for identification. + +`/compute/queues/{oidtoken}/completed` - GET list of completed action LSNs for this queue. + +`/compute/queues/{oidtoken}/jobs` - POST to submit an action to a queue with automatic +worker resolution. + +`/compute/queues/{oidtoken}/jobs/{worker}` - POST to submit an action targeting a specific +worker. 
+ +`/compute/queues/{oidtoken}/jobs/{lsn}` - GET action result by LSN, scoped to the queue diff --git a/repo/packages/n/nomad/xmake.lua b/repo/packages/n/nomad/xmake.lua new file mode 100644 index 000000000..85ea10985 --- /dev/null +++ b/repo/packages/n/nomad/xmake.lua @@ -0,0 +1,37 @@ +-- this package only provides the nomad binary, to be used for testing nomad provisioning + +package("nomad") + set_homepage("https://www.nomadproject.io/") + set_description("Nomad is a workload orchestrator that deploys and manages containers and non-containerized applications.") + + if is_plat("windows") then + add_urls("https://releases.hashicorp.com/nomad/$(version)/nomad_$(version)_windows_amd64.zip") + add_versions("1.9.7", "419e417d33f94888e176f2cccf1a101a16fc283bf721021f2a11f1b74570db97") + elseif is_plat("linux") then + add_urls("https://releases.hashicorp.com/nomad/$(version)/nomad_$(version)_linux_amd64.zip") + add_versions("1.9.7", "e9c7337893eceb549557ef9ad341b3ae64f5f43e29ff1fb167b70cfd16748d2d") + elseif is_plat("macosx") then + if is_arch("arm64") then + add_urls("https://releases.hashicorp.com/nomad/$(version)/nomad_$(version)_darwin_arm64.zip") + add_versions("1.9.7", "90f87dffb3669a842a8428899088f3a0ec5a0d204e5278dbb0c1ac16ab295935") + else + add_urls("https://releases.hashicorp.com/nomad/$(version)/nomad_$(version)_darwin_amd64.zip") + add_versions("1.9.7", "8f5befe1e11ef5664c0c212053aa3fc3e095e52a86e90c1315d7580f19ad7997") + end + end + + on_install(function (package) + if is_plat("windows") then + os.cp("nomad.exe", package:installdir("bin")) + else + os.cp("nomad", package:installdir("bin")) + end + end) + + on_test(function (package) + if is_plat("windows") then + os.run("%s version", package:installdir("bin", "nomad.exe")) + elseif is_plat("linux") then + os.run("%s version", package:installdir("bin", "nomad")) + end + end) diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 407f42ee3..42c7119e7 100644 --- a/src/zen/cmds/exec_cmd.cpp 
+++ b/src/zen/cmds/exec_cmd.cpp @@ -2,7 +2,7 @@ #include "exec_cmd.h" -#include <zencompute/functionservice.h> +#include <zencompute/computeservice.h> #include <zencompute/recordingreader.h> #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> @@ -14,9 +14,13 @@ #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/scopeguard.h> +#include <zencore/session.h> #include <zencore/stream.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/timer.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/packageformat.h> #include <EASTL/hash_map.h> #include <EASTL/hash_set.h> @@ -47,12 +51,17 @@ ExecCommand::ExecCommand() m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>"); m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>"); m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>"); + m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>"); m_Options.add_option("", "", "mode", "Select execution mode (http,inproc,dump,direct,beacon,buildlog)", cxxopts::value(m_Mode)->default_value("http"), "<string>"); + m_Options + .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), "<bool>"); + m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>"); + m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>"); m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>"); m_Options.parse_positional("mode"); } @@ -236,16 +245,16 @@ ExecCommand::InProcessExecute() ZEN_ASSERT(m_ChunkResolver); ChunkResolver& Resolver = *m_ChunkResolver; - zen::compute::FunctionServiceSession 
FunctionSession(Resolver); + zen::compute::ComputeServiceSession ComputeSession(Resolver); std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - FunctionSession.AddLocalRunner(Resolver, TempPath); + ComputeSession.AddLocalRunner(Resolver, TempPath); - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } int -ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSession) +ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession) { struct JobTracker { @@ -281,6 +290,117 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess JobTracker PendingJobs; + struct ActionSummaryEntry + { + int32_t Lsn = 0; + int RecordingIndex = 0; + IoHash ActionId; + std::string FunctionName; + int InputAttachments = 0; + uint64_t InputBytes = 0; + int OutputAttachments = 0; + uint64_t OutputBytes = 0; + float WallSeconds = 0.0f; + float CpuSeconds = 0.0f; + uint64_t SubmittedTicks = 0; + uint64_t StartedTicks = 0; + std::string ExecutionLocation; + }; + + std::mutex SummaryLock; + std::unordered_map<int32_t, ActionSummaryEntry> SummaryEntries; + + ComputeSession.WaitUntilReady(); + + // Register as a client with the orchestrator (best-effort) + + std::string OrchestratorClientId; + + if (!m_OrchestratorUrl.empty()) + { + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + + CbObjectWriter Ann; + Ann << "session_id"sv << GetSessionId(); + Ann << "hostname"sv << std::string_view(GetMachineName()); + + CbObjectWriter Meta; + Meta << "source"sv + << "zen-exec"sv; + Ann << "metadata"sv << Meta.Save(); + + auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); + if (Resp.IsSuccess()) + { + OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString()); + ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId); + } + else + { + ZEN_WARN("failed to register with orchestrator (status {})", 
static_cast<int>(Resp.StatusCode)); + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + } + } + + Stopwatch OrchestratorHeartbeatTimer; + + auto SendOrchestratorHeartbeat = [&] { + if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) + { + return; + } + OrchestratorHeartbeatTimer.Reset(); + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId)); + } + catch (...) + { + } + }; + + auto ClientCleanup = MakeGuard([&] { + if (!OrchestratorClientId.empty()) + { + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId)); + } + catch (...) + { + } + } + }); + + // Create a queue to group all actions from this exec session + + CbObjectWriter Metadata; + Metadata << "source"sv + << "zen-exec"sv; + + auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save()); + const int QueueId = QueueResult.QueueId; + if (!QueueId) + { + ZEN_ERROR("failed to create compute queue"); + return 1; + } + + auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); }); + + if (!m_OutputPath.empty()) + { + zen::CreateDirectories(m_OutputPath); + } + std::atomic<int> IsDraining{0}; auto DrainCompletedJobs = [&] { @@ -292,7 +412,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); }); CbObjectWriter Cbo; - FunctionSession.GetCompleted(Cbo); + ComputeSession.GetQueueCompleted(QueueId, Cbo); if (CbObject Completed = Cbo.Save()) { @@ -301,10 +421,89 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess int32_t CompleteLsn = It.AsInt32(); CbPackage ResultPackage; - HttpResponseCode Response = 
FunctionSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); + HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); if (Response == HttpResponseCode::OK) { + if (!m_OutputPath.empty() && ResultPackage) + { + int OutputAttachments = 0; + uint64_t OutputBytes = 0; + + if (!m_Binary) + { + // Write the root object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); + + std::string_view Yaml = YamlStr; + zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); + + // Write decompressed attachments + auto Attachments = ResultPackage.GetAttachments(); + + if (!Attachments.empty()) + { + std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn); + zen::CreateDirectories(AttDir); + + for (const CbAttachment& Att : Attachments) + { + ++OutputAttachments; + + IoHash AttHash = Att.GetHash(); + + if (Att.IsCompressedBinary()) + { + SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress(); + OutputBytes += Decompressed.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + else + { + SharedBuffer Binary = Att.AsBinary(); + OutputBytes += Binary.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); + } + } + } + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", + m_OutputPath.string(), + CompleteLsn, + OutputAttachments); + } + } + else + { + CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); + zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); + + for (const CbAttachment& Att : ResultPackage.GetAttachments()) + { + ++OutputAttachments; + OutputBytes += Att.AsBinary().GetSize(); + } + + 
if (!m_QuietLogging) + { + ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn); + } + } + + std::lock_guard Lock(SummaryLock); + if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end()) + { + It2->second.OutputAttachments = OutputAttachments; + It2->second.OutputBytes = OutputBytes; + } + } + PendingJobs.Remove(CompleteLsn); ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, PendingJobs.GetSize()); @@ -321,7 +520,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess { CbPackage WorkerDesc = Kv.second; - FunctionSession.RegisterWorker(WorkerDesc); + ComputeSession.RegisterWorker(WorkerDesc); } // Then submit work items @@ -367,10 +566,14 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess TargetParallelism = 1; } + std::atomic<int> RecordingIndex{0}; + m_RecordingReader->IterateActions( [&](CbObject ActionObject, const IoHash& ActionId) { // Enqueue job + const int CurrentRecordingIndex = RecordingIndex++; + Stopwatch SubmitTimer; const int Priority = 0; @@ -404,8 +607,29 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess ObjStr); } - if (zen::compute::FunctionServiceSession::EnqueueResult EnqueueResult = - FunctionSession.EnqueueAction(ActionObject, Priority)) + if (m_DumpActions) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ++AttachmentCount; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr); + } + + if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = + 
ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority)) { const int32_t LsnField = EnqueueResult.Lsn; @@ -421,6 +645,96 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess RemainingWorkItems); } + if (!m_OutputPath.empty()) + { + ActionSummaryEntry Entry; + Entry.Lsn = LsnField; + Entry.RecordingIndex = CurrentRecordingIndex; + Entry.ActionId = ActionId; + Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); + + if (!m_Binary) + { + // Write action object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ActionObject, YamlStr); + + std::string_view Yaml = YamlStr; + zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); + + // Write decompressed input attachments + std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.action.attachments", LsnField); + bool AttDirCreated = false; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + SharedBuffer Decompressed = Compressed.Decompress(); + + Entry.InputBytes += Decompressed.GetSize(); + + if (!AttDirCreated) + { + zen::CreateDirectories(AttDir); + AttDirCreated = true; + } + + zen::WriteFile(AttDir / AttachCid.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + }); + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", + m_OutputPath.string(), + LsnField, + Entry.InputAttachments); + } + } + else + { + // Build a CbPackage from the action and write as .pkg + CbPackage ActionPackage; + ActionPackage.SetObject(ActionObject); + + 
ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + + Entry.InputBytes += ChunkData.GetSize(); + ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); + } + }); + + CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); + zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized)); + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField); + } + } + + std::lock_guard Lock(SummaryLock); + SummaryEntries.emplace(LsnField, std::move(Entry)); + } + PendingJobs.Insert(LsnField); } else @@ -450,6 +764,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess // Check for completed work DrainCompletedJobs(); + SendOrchestratorHeartbeat(); }, TargetParallelism); @@ -461,6 +776,394 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess zen::Sleep(500); DrainCompletedJobs(); + SendOrchestratorHeartbeat(); + } + + // Merge timing data from queue history into summary entries + + if (!SummaryEntries.empty()) + { + // RunnerAction::State indices (can't include functionrunner.h from here) + constexpr int kStateNew = 0; + constexpr int kStatePending = 1; + constexpr int kStateRunning = 3; + constexpr int kStateCompleted = 4; // first terminal state + constexpr int kStateCount = 8; + + for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0)) + { + std::lock_guard Lock(SummaryLock); + if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end()) + { + // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) + uint64_t EndTick = 0; + for (int S = 
kStateCompleted; S < kStateCount; ++S) + { + if (HistEntry.Timestamps[S] != 0) + { + EndTick = HistEntry.Timestamps[S]; + break; + } + } + uint64_t StartTick = HistEntry.Timestamps[kStateNew]; + if (EndTick > StartTick) + { + It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); + } + It->second.CpuSeconds = HistEntry.CpuSeconds; + It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; + It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; + It->second.ExecutionLocation = HistEntry.ExecutionLocation; + } + } + } + + // Write summary file if output path is set + + if (!m_OutputPath.empty() && !SummaryEntries.empty()) + { + std::vector<ActionSummaryEntry> Sorted; + Sorted.reserve(SummaryEntries.size()); + for (auto& [_, Entry] : SummaryEntries) + { + Sorted.push_back(std::move(Entry)); + } + + std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) { + return A.RecordingIndex < B.RecordingIndex; + }); + + auto FormatTimestamp = [](uint64_t Ticks) -> std::string { + if (Ticks == 0) + { + return "-"; + } + return DateTime(Ticks).ToString("%H:%M:%S.%s"); + }; + + ExtendableStringBuilder<4096> Summary; + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", + "LSN", + "Index", + "ActionId", + "Function", + "InAtt", + "InBytes", + "OutAtt", + "OutBytes", + "Wall(s)", + "CPU(s)", + "Submitted", + "Started", + "Location")); + Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "")); + + for (const ActionSummaryEntry& Entry : Sorted) + { + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", + Entry.Lsn, + Entry.RecordingIndex, + Entry.ActionId, + Entry.FunctionName, + 
Entry.InputAttachments, + NiceBytes(Entry.InputBytes), + Entry.OutputAttachments, + NiceBytes(Entry.OutputBytes), + Entry.WallSeconds, + Entry.CpuSeconds, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + Entry.ExecutionLocation)); + } + + std::filesystem::path SummaryPath = m_OutputPath / "summary.txt"; + std::string_view SummaryStr = Summary; + zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); + + ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); + + if (!m_Binary) + { + auto EscapeHtml = [](std::string_view Input) -> std::string { + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) + { + switch (C) + { + case '&': + Out += "&"; + break; + case '<': + Out += "<"; + break; + case '>': + Out += ">"; + break; + case '"': + Out += """; + break; + case '\'': + Out += "'"; + break; + default: + Out += C; + } + } + return Out; + }; + + auto EscapeJson = [](std::string_view Input) -> std::string { + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) + { + switch (C) + { + case '"': + Out += "\\\""; + break; + case '\\': + Out += "\\\\"; + break; + case '\n': + Out += "\\n"; + break; + case '\r': + Out += "\\r"; + break; + case '\t': + Out += "\\t"; + break; + default: + if (static_cast<unsigned char>(C) < 0x20) + { + Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C))); + } + else + { + Out += C; + } + } + } + return Out; + }; + + ExtendableStringBuilder<8192> Html; + + Html.Append(std::string_view(R"(<!DOCTYPE html> +<html><head><meta charset="utf-8"><title>Exec Summary</title> +<style> +body{font-family:system-ui,sans-serif;margin:20px;background:#fafafa} +#container{overflow-y:auto;height:calc(100vh - 120px)} +table{border-collapse:collapse;width:100%} +th,td{border:1px solid #ddd;padding:6px 10px;text-align:left;white-space:nowrap} 
+th{background:#f0f0f0;cursor:pointer;user-select:none;position:sticky;top:0;z-index:1} +th:hover{background:#e0e0e0} +th .arrow{font-size:0.7em;margin-left:4px} +tr:hover{background:#e8f0fe} +input{padding:6px 10px;margin-bottom:12px;width:300px;border:1px solid #ccc;border-radius:4px} +button{padding:6px 14px;margin-left:8px;margin-bottom:12px;border:1px solid #ccc;border-radius:4px;background:#f0f0f0;cursor:pointer} +button:hover{background:#e0e0e0} +a{color:#1a73e8;text-decoration:none} +a:hover{text-decoration:underline} +.num{text-align:right} +</style></head><body> +<h2>Exec Summary</h2> +<input type="text" id="filter" placeholder="Filter by function name..."><button id="csvBtn">Export CSV</button> +<div id="container"> +<table><thead><tr> +<th data-col="0">LSN <span class="arrow"></span></th> +<th data-col="1">Index <span class="arrow"></span></th> +<th data-col="2">Action ID <span class="arrow"></span></th> +<th data-col="3">Function <span class="arrow"></span></th> +<th data-col="4">In Attachments <span class="arrow"></span></th> +<th data-col="5">In Bytes <span class="arrow"></span></th> +<th data-col="6">Out Attachments <span class="arrow"></span></th> +<th data-col="7">Out Bytes <span class="arrow"></span></th> +<th data-col="8">Wall(s) <span class="arrow"></span></th> +<th data-col="9">CPU(s) <span class="arrow"></span></th> +<th data-col="10">Submitted <span class="arrow"></span></th> +<th data-col="11">Started <span class="arrow"></span></th> +<th data-col="12">Location <span class="arrow"></span></th> +</tr></thead><tbody> +<tr id="spacerTop"><td colspan="13"></td></tr> +<tr id="spacerBot"><td colspan="13"></td></tr> +</tbody></table></div> +<script> +const DATA=[ +)")); + + std::string_view ResultExt = ".result.yaml"; + std::string_view ActionExt = ".action.yaml"; + + for (const ActionSummaryEntry& Entry : Sorted) + { + std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName)); + std::string ActionIdStr = Entry.ActionId.ToHexString(); + 
std::string ActionLink; + if (!ActionExt.empty()) + { + ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt)); + } + + // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes, + // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink, + // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay, + // 17=location + Html.Append( + fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n", + Entry.Lsn, + Entry.RecordingIndex, + ActionIdStr, + SafeName, + Entry.InputAttachments, + Entry.InputBytes, + Entry.OutputAttachments, + Entry.OutputBytes, + Entry.WallSeconds, + Entry.CpuSeconds, + EscapeJson(NiceBytes(Entry.InputBytes)), + EscapeJson(NiceBytes(Entry.OutputBytes)), + ActionLink, + Entry.SubmittedTicks, + Entry.StartedTicks, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + EscapeJson(EscapeHtml(Entry.ExecutionLocation)))); + } + + Html.Append(fmt::format(R"(]; +const RESULT_EXT="{}"; +)", + ResultExt)); + + Html.Append(std::string_view(R"JS((function(){ +const ROW_H=33,BUF=20; +const container=document.getElementById("container"); +const tbody=container.querySelector("tbody"); +const headers=container.querySelectorAll("th"); +const filterInput=document.getElementById("filter"); +const spacerTop=document.getElementById("spacerTop"); +const spacerBot=document.getElementById("spacerBot"); +let view=[...DATA.keys()]; +let sortCol=-1,sortAsc=true; +const COLS=[ + {f:0,t:"n"},{f:1,t:"n"},{f:2,t:"s"},{f:3,t:"s"}, + {f:4,t:"n"},{f:5,t:"n"},{f:6,t:"n"},{f:7,t:"n"}, + {f:8,t:"n"},{f:9,t:"n"},{f:13,t:"n"},{f:14,t:"n"},{f:17,t:"s"} +]; +function rowHtml(i){ + const d=DATA[view[i]]; + const bg=i%2?' 
style="background:#f9f9f9"':''; + return '<tr'+bg+'>'+ + '<td class="num"><a href="'+d[0]+RESULT_EXT+'">'+d[0]+'</a></td>'+ + '<td class="num">'+d[1]+'</td>'+ + '<td><code>'+d[2]+'</code></td>'+ + '<td>'+d[3]+d[12]+'</td>'+ + '<td class="num">'+d[4]+'</td>'+ + '<td class="num">'+d[10]+'</td>'+ + '<td class="num">'+d[6]+'</td>'+ + '<td class="num">'+d[11]+'</td>'+ + '<td class="num">'+d[8].toFixed(2)+'</td>'+ + '<td class="num">'+d[9].toFixed(2)+'</td>'+ + '<td class="num">'+d[15]+'</td>'+ + '<td class="num">'+d[16]+'</td>'+ + '<td>'+d[17]+'</td></tr>'; +} +let lastFirst=-1,lastLast=-1; +function render(){ + const scrollTop=container.scrollTop; + const viewH=container.clientHeight; + let first=Math.floor(scrollTop/ROW_H)-BUF; + let last=Math.ceil((scrollTop+viewH)/ROW_H)+BUF; + if(first<0) first=0; + if(last>=view.length) last=view.length-1; + if(first===lastFirst&&last===lastLast) return; + lastFirst=first;lastLast=last; + const rows=[]; + for(let i=first;i<=last;i++) rows.push(rowHtml(i)); + spacerTop.style.height=(first*ROW_H)+'px'; + spacerBot.style.height=((view.length-1-last)*ROW_H)+'px'; + const mid=rows.join(''); + const topTr='<tr id="spacerTop"><td colspan="13" style="border:0;padding:0;height:'+spacerTop.style.height+'"></td></tr>'; + const botTr='<tr id="spacerBot"><td colspan="13" style="border:0;padding:0;height:'+spacerBot.style.height+'"></td></tr>'; + tbody.innerHTML=topTr+mid+botTr; +} +function applySort(){ + if(sortCol<0) return; + const c=COLS[sortCol]; + view.sort((a,b)=>{ + const va=DATA[a][c.f],vb=DATA[b][c.f]; + if(c.t==="n") return sortAsc?va-vb:vb-va; + return sortAsc?(va<vb?-1:va>vb?1:0):(va>vb?-1:va<vb?1:0); + }); +} +function rebuild(){ + const q=filterInput.value.toLowerCase(); + view=[]; + for(let i=0;i<DATA.length;i++){ + if(!q||DATA[i][3].toLowerCase().includes(q)) view.push(i); + } + applySort(); + lastFirst=lastLast=-1; + render(); +} +headers.forEach(th=>{ + th.addEventListener("click",()=>{ + const col=parseInt(th.dataset.col); 
+ if(sortCol===col){sortAsc=!sortAsc}else{sortCol=col;sortAsc=true} + headers.forEach(h=>h.querySelector(".arrow").textContent=""); + th.querySelector(".arrow").textContent=sortAsc?"\u25B2":"\u25BC"; + applySort(); + lastFirst=lastLast=-1; + render(); + }); +}); +filterInput.addEventListener("input",()=>rebuild()); +let ticking=false; +container.addEventListener("scroll",()=>{ + if(!ticking){ticking=true;requestAnimationFrame(()=>{render();ticking=false})} +}); +rebuild(); +document.getElementById("csvBtn").addEventListener("click",()=>{ + const H=["LSN","Index","Action ID","Function","In Attachments","In Bytes","Out Attachments","Out Bytes","Wall(s)","CPU(s)","Submitted","Started","Location"]; + const esc=v=>{const s=String(v);return s.includes(',')||s.includes('"')||s.includes('\n')?'"'+s.replace(/"/g,'""')+'"':s}; + const rows=[H.join(",")]; + for(let i=0;i<view.length;i++){ + const d=DATA[view[i]]; + rows.push([d[0],d[1],d[2],d[3],d[4],d[5],d[6],d[7],d[8],d[9],d[15],d[16],d[17]].map(esc).join(",")); + } + const blob=new Blob([rows.join("\n")],{type:"text/csv"}); + const a=document.createElement("a"); + a.href=URL.createObjectURL(blob); + a.download="summary.csv"; + a.click(); + URL.revokeObjectURL(a.href); +}); +})(); +</script></body></html> +)JS")); + + std::filesystem::path HtmlPath = m_OutputPath / "summary.html"; + std::string_view HtmlStr = Html; + zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + + ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + } } if (FailedWorkCounter) @@ -491,10 +1194,10 @@ ExecCommand::HttpExecute() std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - zen::compute::FunctionServiceSession FunctionSession(Resolver); - FunctionSession.AddRemoteRunner(Resolver, TempPath, m_HostName); + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); - return ExecUsingSession(FunctionSession); + return 
ExecUsingSession(ComputeSession); } int @@ -504,11 +1207,21 @@ ExecCommand::BeaconExecute() ChunkResolver& Resolver = *m_ChunkResolver; std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - zen::compute::FunctionServiceSession FunctionSession(Resolver); - FunctionSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); - // FunctionSession.AddRemoteRunner(Resolver, TempPath, "http://10.99.9.246:8558"); + zen::compute::ComputeServiceSession ComputeSession(Resolver); + + if (!m_OrchestratorUrl.empty()) + { + ZEN_CONSOLE_INFO("using orchestrator at {}", m_OrchestratorUrl); + ComputeSession.SetOrchestratorEndpoint(m_OrchestratorUrl); + ComputeSession.SetOrchestratorBasePath(TempPath); + } + else + { + ZEN_CONSOLE_INFO("note: using hard-coded local worker path"); + ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); + } - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } ////////////////////////////////////////////////////////////////////////// @@ -635,10 +1348,10 @@ ExecCommand::BuildActionsLog() std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - zen::compute::FunctionServiceSession FunctionSession(Resolver); - FunctionSession.StartRecording(Resolver, m_RecordingLogPath); + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.StartRecording(Resolver, m_RecordingLogPath); - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } void diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index 43d092144..6311354c0 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -23,7 +23,7 @@ class ChunkResolver; #if ZEN_WITH_COMPUTE_SERVICES namespace zen::compute { -class FunctionServiceSession; +class ComputeServiceSession; } namespace zen { @@ -49,6 +49,7 @@ public: private: cxxopts::Options m_Options{Name, Description}; std::string m_HostName; + std::string m_OrchestratorUrl; 
std::filesystem::path m_BeaconPath; std::filesystem::path m_RecordingPath; std::filesystem::path m_RecordingLogPath; @@ -57,6 +58,8 @@ private: int m_Limit = 0; bool m_Quiet = false; std::string m_Mode{"http"}; + std::filesystem::path m_OutputPath; + bool m_Binary = false; struct FunctionDefinition { @@ -74,13 +77,14 @@ private: std::vector<FunctionDefinition> m_FunctionList; bool m_VerboseLogging = false; bool m_QuietLogging = false; + bool m_DumpActions = false; zen::ChunkResolver* m_ChunkResolver = nullptr; zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId); - int ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSession); + int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession); // Execution modes diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md new file mode 100644 index 000000000..f5188123f --- /dev/null +++ b/src/zencompute/CLAUDE.md @@ -0,0 +1,232 @@ +# zencompute Module + +Lambda-style compute function service. Accepts execution requests from HTTP clients, schedules them across local and remote runners, and tracks results. 
+ +## Directory Structure + +``` +src/zencompute/ +├── include/zencompute/ # Public headers +│ ├── computeservice.h # ComputeServiceSession public API +│ ├── httpcomputeservice.h # HTTP service wrapper +│ ├── orchestratorservice.h # Worker registry and orchestration +│ ├── httporchestrator.h # HTTP orchestrator with WebSocket push +│ ├── recordingreader.h # Recording/replay reader API +│ ├── cloudmetadata.h # Cloud provider detection (AWS/Azure/GCP) +│ └── mockimds.h # Test helper for cloud metadata +├── runners/ # Execution backends +│ ├── functionrunner.h/.cpp # Abstract base + BaseRunnerGroup/RunnerGroup +│ ├── localrunner.h/.cpp # LocalProcessRunner (sandbox, monitoring, CPU sampling) +│ ├── windowsrunner.h/.cpp # Windows AppContainer sandboxing + CreateProcessW +│ ├── linuxrunner.h/.cpp # Linux user/mount/network namespace isolation +│ ├── macrunner.h/.cpp # macOS Seatbelt sandboxing +│ ├── winerunner.h/.cpp # Wine runner for Windows executables on Linux +│ ├── remotehttprunner.h/.cpp # Remote HTTP submission to other zenserver instances +│ └── deferreddeleter.h/.cpp # Background deletion of sandbox directories +├── recording/ # Recording/replay subsystem +│ ├── actionrecorder.h/.cpp # Write actions+attachments to disk +│ └── recordingreader.cpp # Read recordings back +├── timeline/ +│ └── workertimeline.h/.cpp # Per-worker action lifecycle event tracking +├── testing/ +│ └── mockimds.cpp # Mock IMDS for cloud metadata tests +├── computeservice.cpp # ComputeServiceSession::Impl (~1700 lines) +├── httpcomputeservice.cpp # HTTP route registration and handlers (~900 lines) +├── httporchestrator.cpp # Orchestrator HTTP API + WebSocket push +├── orchestratorservice.cpp # Worker registry, health probing +└── cloudmetadata.cpp # IMDS probing, termination monitoring +``` + +## Key Classes + +### `ComputeServiceSession` (computeservice.h) +Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). 
Owns: +- Two `RunnerGroup`s: `m_LocalRunnerGroup`, `m_RemoteRunnerGroup` +- Scheduler thread that drains `m_UpdatedActions` and drives state transitions +- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` +- Queue map: `m_Queues` (QueueEntry objects) +- Action history ring: `m_ActionHistory` (bounded deque, default 1000) + +**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. + +### `RunnerAction` (runners/functionrunner.h) +Shared ref-counted struct representing one action through its lifecycle. + +**Key fields:** +- `ActionLsn` — global unique sequence number +- `QueueId` — 0 for implicit/unqueued actions +- `Worker` — descriptor + content hash +- `ActionObj` — CbObject with the action spec +- `CpuUsagePercent` / `CpuSeconds` — atomics updated by monitor thread +- `RetryCount` — atomic int tracking how many times the action has been rescheduled +- `Timestamps[State::_Count]` — timestamp of each state transition + +**State machine (forward-only under normal flow, atomic):** +``` +New → Pending → Submitting → Running → Completed + → Failed + → Abandoned + → Cancelled +``` +`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +### `LocalProcessRunner` (runners/localrunner.h) +Base for all local execution. 
Platform runners subclass this and override: +- `SubmitAction()` — fork/exec the worker process +- `SweepRunningActions()` — poll for process exit (waitpid / WaitForSingleObject) +- `CancelRunningActions()` — signal all processes during shutdown +- `SampleProcessCpu(RunningAction&)` — read platform CPU usage (no-op default) + +**Infrastructure owned by LocalProcessRunner:** +- Monitor thread — calls `SweepRunningActions()` then `SampleRunningProcessCpu()` in a loop +- `m_RunningMap` — `RwLock`-guarded map of `Lsn → RunningAction` +- `DeferredDirectoryDeleter` — sandbox directories are queued for async deletion +- `PrepareActionSubmission()` — shared preamble (capacity check, sandbox creation, worker manifesting, input decompression) +- `ProcessCompletedActions()` — shared post-processing (gather outputs, set state, enqueue deletion) + +**CPU sampling:** `SampleRunningProcessCpu()` iterates `m_RunningMap` under shared lock, calls `SampleProcessCpu()` per entry, throttled to every 5 seconds per action. Platform implementations: +- Linux: `/proc/{pid}/stat` utime+stime jiffies ÷ `_SC_CLK_TCK` +- Windows: `GetProcessTimes()` in 100ns intervals ÷ 10,000,000 +- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 + +### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. + +### `HttpComputeService` (include/zencompute/httpcomputeservice.h) +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. + +## Action Lifecycle (End to End) + +1. **HTTP POST** → `HttpComputeService` ingests attachments, calls `EnqueueAction()` +2. 
**Enqueue** → creates `RunnerAction` (New → Pending), calls `PostUpdate()` +3. **PostUpdate** → appends to `m_UpdatedActions`, signals scheduler thread event +4. **Scheduler thread** → drains `m_UpdatedActions`, drives pending actions to runners +5. **Runner `SubmitAction()`** → Pending → Submitting (on runner's worker pool thread) +6. **Process launch** → Submitting → Running, added to `m_RunningMap` +7. **Monitor thread `SweepRunningActions()`** → detects exit, gathers outputs +8. **`ProcessCompletedActions()`** → Running → Completed/Failed/Abandoned, `PostUpdate()` +9. **Scheduler thread `HandleActionUpdates()`** — for Failed/Abandoned actions, checks retry limit; if retries remain, calls `ResetActionStateToPending()` which loops back to step 3. Otherwise moves to `m_ResultsMap`, records history, notifies queue. +10. **Client `GET /jobs/{lsn}`** → returns result from `m_ResultsMap`, schedules retirement + +### Action Rescheduling + +Actions that fail or are abandoned can be automatically retried or manually rescheduled via the API. + +**Automatic retry (scheduler path):** In `HandleActionUpdates()`, when a Failed or Abandoned state is detected, the scheduler checks `RetryCount < GetMaxRetriesForQueue(QueueId)`. If retries remain, the action is removed from active maps and `ResetActionStateToPending()` is called, which re-enters it into the scheduler pipeline. The action keeps its original LSN so clients can continue polling with the same identifier. + +**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. 
+ +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. + +**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. + +## Queue System + +Queues group actions from a single client session. A `QueueEntry` (internal) tracks: +- `State` — `std::atomic<QueueState>` lifecycle state (Active → Draining → Cancelled) +- `ActiveCount` — pending + running actions (atomic) +- `CompletedCount / FailedCount / AbandonedCount / CancelledCount` (atomics) +- `ActiveLsns` — for cancellation lookup (under `m_Lock`) +- `FinishedLsns` — moved here when actions complete +- `IdleSince` — used for 15-minute automatic expiry +- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit + +**Queue state machine (`QueueState` enum):** +``` +Active → Draining → Cancelled + \ ↑ + ─────────────────────/ +``` +- **Active** — accepts new work, schedules pending work, finishes running work (initial state) +- **Draining** — rejects new work, finishes existing work (one-way via CAS from Active; cannot override Cancelled) +- **Cancelled** — rejects new work, actively cancels in-flight work (reachable from Active or Draining) + +Key operations: +- `CreateQueue(Tag)` → returns `QueueId` +- `EnqueueActionToQueue(QueueId, ...)` → action's `QueueId` field is set at creation +- `CancelQueue(QueueId)` → marks all active LSNs for cancellation +- `DrainQueue(QueueId)` → stops accepting new submissions; existing work finishes naturally (irreversible) +- `GetQueueCompleted(QueueId)` → CbWriter output of finished results +- Queue references in HTTP routes accept either a decimal ID or an Oid token (24-hex), resolved by `ResolveQueueRef()` + +## HTTP API + +All routes registered in `HttpComputeService` constructor. Prefix is configured externally (typically `/compute`). 
+ +### Global endpoints +| Method | Path | Description | +|--------|------|-------------| +| POST | `abandon` | Transition session to Abandoned state (409 if invalid) | +| GET | `jobs/history` | Action history (last N, with timestamps per state) | +| GET | `jobs/running` | In-flight actions with CPU metrics | +| GET | `jobs/completed` | Actions with results available | +| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{worker}` | Submit action for specific worker | +| POST | `jobs` | Submit action (worker resolved from descriptor) | +| GET | `workers` | List worker IDs | +| GET | `workers/all` | All workers with full descriptors | +| GET/POST | `workers/{worker}` | Get/register worker | + +### Queue-scoped endpoints +Queue ref is capture(1) in all `queues/{queueref}/...` routes. + +| Method | Path | Description | +|--------|------|-------------| +| GET | `queues` | List queue IDs | +| POST | `queues` | Create queue | +| GET/DELETE | `queues/{queueref}` | Status / delete | +| POST | `queues/{queueref}/drain` | Drain queue (irreversible; rejects new submissions) | +| GET | `queues/{queueref}/completed` | Queue's completed results | +| GET | `queues/{queueref}/history` | Queue's action history | +| GET | `queues/{queueref}/running` | Queue's running actions | +| POST | `queues/{queueref}/jobs` | Submit to queue | +| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | + +Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. + +## Concurrency Model + +**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: +1. `m_ResultsLock` +2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") +3. 
`m_PendingLock` +4. `m_QueueLock` + +**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. + +**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. + +**Thread ownership:** +- Scheduler thread — drives state transitions, owns `m_PendingActions` +- Monitor thread (per runner) — polls process completion, owns `m_RunningMap` via shared lock +- Worker pool threads — async submission, brief `SubmitAction()` calls +- HTTP threads — read-only access to results, queue status + +## Sandbox Layout + +Each action gets a unique numbered directory under `m_SandboxPath`: +``` +scratch/{counter}/ + worker/ ← worker binaries (or bind-mounted on Linux) + inputs/ ← decompressed action inputs + outputs/ ← written by worker process +``` + +On Linux with sandboxing enabled, the process runs in a pivot-rooted namespace with `/usr`, `/lib`, `/etc`, `/worker` bind-mounted read-only and a tmpfs `/dev`. + +## Adding a New HTTP Endpoint + +1. Register the route in the `HttpComputeService` constructor in `httpcomputeservice.cpp` +2. If the handler is shared between top-level and a `queues/{queueref}/...` variant, extract it as a private helper method declared in `httpcomputeservice.h` +3. Queue-scoped routes validate the queue ref with `ResolveQueueRef(HttpReq, Req.GetCapture(1))` which writes an error response and returns 0 on failure +4. Use `CbObjectWriter` for response bodies; emit via `HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save())` +5. Conditional fields (e.g., optional CPU metrics): emit inside `if (value > 0.0f)` / `if (value >= 0.0f)` guards to omit absent values rather than emitting sentinel values + +## Adding a New Runner Platform + +1. Subclass `LocalProcessRunner`, add `h`/`cpp` files in `runners/` +2. 
Override `SubmitAction()`, `SweepRunningActions()`, `CancelRunningActions()`, and optionally `CancelAction(int)` and `SampleProcessCpu(RunningAction&)` +3. `SampleProcessCpu()` must update both `Running.Action->CpuSeconds` (unconditionally from the absolute OS value) and `Running.Action->CpuUsagePercent` (delta-based, only after second sample) +4. `ProcessHandle` convention: store pid as `reinterpret_cast<void*>(static_cast<intptr_t>(pid))` for consistency with the base class +5. Register in `ComputeServiceSession::AddLocalRunner()` in `computeservice.cpp` diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp new file mode 100644 index 000000000..b3b3210d9 --- /dev/null +++ b/src/zencompute/cloudmetadata.cpp @@ -0,0 +1,1010 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/cloudmetadata.h> + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> +#include <zencore/trace.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +// All major cloud providers expose instance metadata at this link-local address. +// It is only routable from within a cloud VM; on bare-metal the TCP connect will +// fail, which is how we distinguish cloud from non-cloud environments. +static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; + +// Short connect timeout so that detection on non-cloud machines is fast. The IMDS +// is a local service on the hypervisor so 200ms is generous for actual cloud VMs. 
+static constexpr auto kImdsTimeout = std::chrono::milliseconds{200}; + +std::string_view +ToString(CloudProvider Provider) +{ + switch (Provider) + { + case CloudProvider::AWS: + return "AWS"; + case CloudProvider::Azure: + return "Azure"; + case CloudProvider::GCP: + return "GCP"; + default: + return "None"; + } +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint)) +{ +} + +CloudMetadata::CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint) +: m_Log(logging::Get("cloud")) +, m_DataDir(std::move(DataDir)) +, m_ImdsEndpoint(std::move(ImdsEndpoint)) +{ + ZEN_TRACE_CPU("CloudMetadata::CloudMetadata"); + + std::error_code Ec; + std::filesystem::create_directories(m_DataDir, Ec); + + DetectProvider(); + + if (m_Info.Provider != CloudProvider::None) + { + StartTerminationMonitor(); + } +} + +CloudMetadata::~CloudMetadata() +{ + ZEN_TRACE_CPU("CloudMetadata::~CloudMetadata"); + m_MonitorEnabled = false; + m_MonitorEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +CloudProvider +CloudMetadata::GetProvider() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); +} + +CloudInstanceInfo +CloudMetadata::GetInstanceInfo() const +{ + return m_InfoLock.WithSharedLock([&] { return m_Info; }); +} + +bool +CloudMetadata::IsTerminationPending() const +{ + return m_TerminationPending.load(std::memory_order_relaxed); +} + +std::string +CloudMetadata::GetTerminationReason() const +{ + return m_ReasonLock.WithSharedLock([&] { return m_TerminationReason; }); +} + +void +CloudMetadata::Describe(CbWriter& Writer) const +{ + ZEN_TRACE_CPU("CloudMetadata::Describe"); + CloudInstanceInfo Info = GetInstanceInfo(); + + if (Info.Provider == CloudProvider::None) + { + return; + } + + Writer.BeginObject("cloud"); + Writer << "provider" << ToString(Info.Provider); + Writer << "instance_id" << Info.InstanceId; + Writer << "availability_zone" << 
Info.AvailabilityZone; + Writer << "is_spot" << Info.IsSpot; + Writer << "is_autoscaling" << Info.IsAutoscaling; + Writer << "termination_pending" << IsTerminationPending(); + + if (IsTerminationPending()) + { + Writer << "termination_reason" << GetTerminationReason(); + } + + Writer.EndObject(); +} + +void +CloudMetadata::DetectProvider() +{ + ZEN_TRACE_CPU("CloudMetadata::DetectProvider"); + + if (TryDetectAWS()) + { + return; + } + + if (TryDetectAzure()) + { + return; + } + + if (TryDetectGCP()) + { + return; + } + + ZEN_DEBUG("no cloud provider detected"); +} + +// AWS detection uses IMDSv2 which requires a session token obtained via PUT before +// any GET requests are allowed. This is more secure than IMDSv1 (which allowed +// unauthenticated GETs) and is the default on modern EC2 instances. The token has +// a 300-second TTL and is reused for termination polling. +bool +CloudMetadata::TryDetectAWS() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAWS"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAWS"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping AWS detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing AWS IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + // IMDSv2: acquire session token. The TTL header is mandatory; we request + // 300s which is sufficient for the detection phase. The token is also + // stored in m_AwsToken for reuse by the termination polling thread. 
+ HttpClient::KeyValueMap TokenHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token-ttl-seconds", "300"}); + HttpClient::Response TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders); + + if (!TokenResponse.IsSuccess()) + { + ZEN_DEBUG("AWS IMDS token request failed ({}), not on AWS", static_cast<int>(TokenResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_AwsToken = std::string(TokenResponse.AsText()); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response IdResponse = ImdsClient.Get("/latest/meta-data/instance-id", AuthHeaders); + if (IdResponse.IsSuccess()) + { + m_Info.InstanceId = std::string(IdResponse.AsText()); + } + + HttpClient::Response AzResponse = ImdsClient.Get("/latest/meta-data/placement/availability-zone", AuthHeaders); + if (AzResponse.IsSuccess()) + { + m_Info.AvailabilityZone = std::string(AzResponse.AsText()); + } + + // "spot" vs "on-demand" — determines whether the instance can be + // reclaimed by AWS with a 2-minute warning + HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders); + if (LifecycleResponse.IsSuccess()) + { + m_Info.IsSpot = (LifecycleResponse.AsText() == "spot"); + } + + // This endpoint only exists on instances managed by an Auto Scaling + // Group. A successful response (regardless of value) means autoscaling. 
+ HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + m_Info.IsAutoscaling = true; + } + + m_Info.Provider = CloudProvider::AWS; + + ZEN_INFO("detected AWS instance: id={}, az={}, spot={}, autoscaling={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("AWS IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Azure IMDS returns a single JSON document for the entire instance metadata, +// unlike AWS and GCP which use separate plain-text endpoints per field. The +// "Metadata: true" header is required; requests without it are rejected. +// The api-version parameter is mandatory and pins the response schema. +bool +CloudMetadata::TryDetectAzure() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectAzure"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotAzure"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping Azure detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing Azure IMDS"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response InstanceResponse = ImdsClient.Get("/metadata/instance?api-version=2021-02-01", MetadataHeaders); + + if (!InstanceResponse.IsSuccess()) + { + ZEN_DEBUG("Azure IMDS request failed ({}), not on Azure", static_cast<int>(InstanceResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(InstanceResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + 
ZEN_DEBUG("Azure IMDS returned invalid JSON: {}", JsonError); + WriteSentinelFile(SentinelPath); + return false; + } + + const json11::Json& Compute = Json["compute"]; + + m_Info.InstanceId = Compute["vmId"].string_value(); + m_Info.AvailabilityZone = Compute["location"].string_value(); + + // Azure spot VMs have priority "Spot"; regular VMs have "Regular" + std::string Priority = Compute["priority"].string_value(); + m_Info.IsSpot = (Priority == "Spot"); + + // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling + std::string VmssName = Compute["vmScaleSetName"].string_value(); + m_Info.IsAutoscaling = !VmssName.empty(); + + m_Info.Provider = CloudProvider::Azure; + + ZEN_INFO("detected Azure instance: id={}, location={}, spot={}, vmss={}", + m_Info.InstanceId, + m_Info.AvailabilityZone, + m_Info.IsSpot, + m_Info.IsAutoscaling); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("Azure IMDS probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// GCP requires the "Metadata-Flavor: Google" header on all IMDS requests. +// Unlike AWS, there is no session token; the header itself is the auth mechanism +// (it prevents SSRF attacks since browsers won't send custom headers to the +// metadata endpoint). Each metadata field is fetched from a separate URL. 
+bool +CloudMetadata::TryDetectGCP() +{ + ZEN_TRACE_CPU("CloudMetadata::TryDetectGCP"); + + std::filesystem::path SentinelPath = m_DataDir / ".isNotGCP"; + + if (HasSentinelFile(SentinelPath)) + { + ZEN_DEBUG("skipping GCP detection - negative cache hit"); + return false; + } + + ZEN_DEBUG("probing GCP metadata service"); + + try + { + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}}); + + HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"}); + + // Fetch instance ID + HttpClient::Response IdResponse = ImdsClient.Get("/computeMetadata/v1/instance/id", MetadataHeaders); + + if (!IdResponse.IsSuccess()) + { + ZEN_DEBUG("GCP metadata request failed ({}), not on GCP", static_cast<int>(IdResponse.StatusCode)); + WriteSentinelFile(SentinelPath); + return false; + } + + m_Info.InstanceId = std::string(IdResponse.AsText()); + + // GCP returns the fully-qualified zone path "projects/<num>/zones/<zone>". + // Strip the prefix to get just the zone name (e.g. "us-central1-a"). 
+ HttpClient::Response ZoneResponse = ImdsClient.Get("/computeMetadata/v1/instance/zone", MetadataHeaders); + if (ZoneResponse.IsSuccess()) + { + std::string_view Zone = ZoneResponse.AsText(); + if (auto Pos = Zone.rfind('/'); Pos != std::string_view::npos) + { + Zone = Zone.substr(Pos + 1); + } + m_Info.AvailabilityZone = std::string(Zone); + } + + // Check for preemptible/spot (scheduling/preemptible returns "TRUE" or "FALSE") + HttpClient::Response PreemptibleResponse = ImdsClient.Get("/computeMetadata/v1/instance/scheduling/preemptible", MetadataHeaders); + if (PreemptibleResponse.IsSuccess()) + { + m_Info.IsSpot = (PreemptibleResponse.AsText() == "TRUE"); + } + + // Check for maintenance event + HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders); + if (MaintenanceResponse.IsSuccess()) + { + std::string_view Event = MaintenanceResponse.AsText(); + if (!Event.empty() && Event != "NONE") + { + m_TerminationPending = true; + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); }); + } + } + + m_Info.Provider = CloudProvider::GCP; + + ZEN_INFO("detected GCP instance: id={}, az={}, spot={}", m_Info.InstanceId, m_Info.AvailabilityZone, m_Info.IsSpot); + + return true; + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("GCP metadata probe failed: {}", Ex.what()); + WriteSentinelFile(SentinelPath); + return false; + } +} + +// Sentinel files are empty marker files whose mere existence signals that a +// previous detection attempt for a given provider failed. This avoids paying +// the connect-timeout cost on every startup for providers that are known to +// be absent. The files persist across process restarts; delete them manually +// (or remove the DataDir) to force re-detection. 
+void +CloudMetadata::WriteSentinelFile(const std::filesystem::path& Path) +{ + try + { + BasicFile File; + File.Open(Path, BasicFile::Mode::kTruncate); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to write sentinel file '{}': {}", Path.string(), Ex.what()); + } +} + +bool +CloudMetadata::HasSentinelFile(const std::filesystem::path& Path) const +{ + return zen::IsFile(Path); +} + +void +CloudMetadata::ClearSentinelFiles() +{ + std::error_code Ec; + std::filesystem::remove(m_DataDir / ".isNotAWS", Ec); + std::filesystem::remove(m_DataDir / ".isNotAzure", Ec); + std::filesystem::remove(m_DataDir / ".isNotGCP", Ec); +} + +void +CloudMetadata::StartTerminationMonitor() +{ + ZEN_INFO("starting cloud termination monitor for {} instance {}", ToString(m_Info.Provider), m_Info.InstanceId); + + m_MonitorThread = std::thread{&CloudMetadata::TerminationMonitorThread, this}; +} + +void +CloudMetadata::TerminationMonitorThread() +{ + SetCurrentThreadName("cloud_term_mon"); + + // Poll every 5 seconds. The Event is used as an interruptible sleep so + // that the destructor can wake us up immediately for a clean shutdown. + while (m_MonitorEnabled) + { + m_MonitorEvent.Wait(5000); + m_MonitorEvent.Reset(); + + if (!m_MonitorEnabled) + { + return; + } + + PollTermination(); + } +} + +void +CloudMetadata::PollTermination() +{ + try + { + CloudProvider Provider = m_InfoLock.WithSharedLock([&] { return m_Info.Provider; }); + + if (Provider == CloudProvider::AWS) + { + PollAWSTermination(); + } + else if (Provider == CloudProvider::Azure) + { + PollAzureTermination(); + } + else if (Provider == CloudProvider::GCP) + { + PollGCPTermination(); + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("termination poll error: {}", Ex.what()); + } +} + +// AWS termination signals: +// - /spot/instance-action: returns 200 with a JSON body ~2 minutes before +// a spot instance is reclaimed. Returns 404 when no action is pending. 
+// - /autoscaling/target-lifecycle-state: returns the ASG lifecycle state. +// "InService" is normal; anything else (e.g. "Terminated:Wait") means +// the instance is being cycled out. +void +CloudMetadata::PollAWSTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAWSTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap AuthHeaders(std::pair<std::string_view, std::string_view>{"X-aws-ec2-metadata-token", m_AwsToken}); + + HttpClient::Response SpotResponse = ImdsClient.Get("/latest/meta-data/spot/instance-action", AuthHeaders); + if (SpotResponse.IsSuccess()) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS spot interruption: {}", SpotResponse.AsText()); }); + ZEN_WARN("AWS spot interruption detected: {}", SpotResponse.AsText()); + } + return; + } + + HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders); + if (AutoscaleResponse.IsSuccess()) + { + std::string_view State = AutoscaleResponse.AsText(); + if (State.find("InService") == std::string_view::npos) + { + if (!m_TerminationPending.exchange(true)) + { + m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS autoscaling lifecycle: {}", State); }); + ZEN_WARN("AWS autoscaling termination detected: {}", State); + } + } + } +} + +// Azure Scheduled Events API returns a JSON array of upcoming platform events. +// We care about "Preempt" (spot eviction), "Terminate", and "Reboot" events. +// Other event types like "Freeze" (live migration) are non-destructive and +// ignored. The Events array is empty when nothing is pending. 
+void +CloudMetadata::PollAzureTermination() +{ + ZEN_TRACE_CPU("CloudMetadata::PollAzureTermination"); + + HttpClient ImdsClient(m_ImdsEndpoint, + {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}}); + + HttpClient::KeyValueMap MetadataHeaders({ + std::pair<std::string_view, std::string_view>{"Metadata", "true"}, + }); + + HttpClient::Response EventsResponse = ImdsClient.Get("/metadata/scheduledevents?api-version=2020-07-01", MetadataHeaders); + + if (!EventsResponse.IsSuccess()) + { + return; + } + + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(EventsResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + return; + } + + const json11::Json::array& Events = Json["Events"].array_items(); + for (const auto& Evt : Events) + { + std::string EventType = Evt["EventType"].string_value(); + if (EventType == "Preempt" || EventType == "Terminate" || EventType == "Reboot") + { + if (!m_TerminationPending.exchange(true)) + { + std::string EventStatus = Evt["EventStatus"].string_value(); + m_ReasonLock.WithExclusiveLock( + [&] { m_TerminationReason = fmt::format("Azure scheduled event: {} ({})", EventType, EventStatus); }); + ZEN_WARN("Azure termination event detected: {} ({})", EventType, EventStatus); + } + return; + } + } +} + +// GCP maintenance-event returns "NONE" when nothing is pending, and a +// descriptive string like "TERMINATE_ON_HOST_MAINTENANCE" when the VM is +// about to be live-migrated or terminated. Preemptible/spot VMs get a +// 30-second warning before termination. 
// Polls the GCP metadata server's maintenance-event key. Any non-empty
// value other than "NONE" is treated as a pending termination and recorded
// once (guarded by the atomic exchange on m_TerminationPending).
void
CloudMetadata::PollGCPTermination()
{
    ZEN_TRACE_CPU("CloudMetadata::PollGCPTermination");

    HttpClient ImdsClient(m_ImdsEndpoint,
                          {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});

    // GCP metadata server requires the Metadata-Flavor: Google header
    HttpClient::KeyValueMap MetadataHeaders(std::pair<std::string_view, std::string_view>{"Metadata-Flavor", "Google"});

    HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders);
    if (MaintenanceResponse.IsSuccess())
    {
        std::string_view Event = MaintenanceResponse.AsText();
        if (!Event.empty() && Event != "NONE")
        {
            // Record reason/warning only on the first detection
            if (!m_TerminationPending.exchange(true))
            {
                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); });
                ZEN_WARN("GCP maintenance event detected: {}", Event);
            }
        }
    }
}

} // namespace zen::compute

//////////////////////////////////////////////////////////////////////////

#if ZEN_WITH_TESTS

# include <zencompute/mockimds.h>

# include <zencore/filesystem.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
# include <zenhttp/httpserver.h>

# include <memory>
# include <thread>

namespace zen::compute {

// ---------------------------------------------------------------------------
// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService
// ---------------------------------------------------------------------------

struct TestImdsServer
{
    // Mutable mock state; tests configure this before (and during) a run.
    MockImdsService Mock;

    // Creates the temp dir, starts the HTTP server on an ephemeral port,
    // registers the mock IMDS service, and runs the server on its own thread.
    void Start()
    {
        m_TmpDir.emplace();
        m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
        // 7575 is a hint; Initialize returns the actual bound port (-1 on failure)
        m_Port = m_Server->Initialize(7575, m_TmpDir->Path() / "http");
        REQUIRE(m_Port != -1);
        m_Server->RegisterService(Mock);
        m_ServerThread = std::thread([this]() { m_Server->Run(false); });
    }

    // Base URI clients should use to reach the mock IMDS endpoint.
    std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); }

    std::filesystem::path DataDir() const { return 
m_TmpDir->Path() / "cloud"; } + + std::unique_ptr<CloudMetadata> CreateCloud() { return std::make_unique<CloudMetadata>(DataDir(), Endpoint()); } + + ~TestImdsServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.aws") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::AWS; + + SUBCASE("detection basics") + { + Imds.Mock.Aws.InstanceId = "i-abc123"; + Imds.Mock.Aws.AvailabilityZone = "us-west-2b"; + Imds.Mock.Aws.LifeCycle = "on-demand"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::AWS); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "i-abc123"); + CHECK(Info.AvailabilityZone == "us-west-2b"); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("spot instance") + { + Imds.Mock.Aws.LifeCycle = "spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("autoscaling instance") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsAutoscaling == true); + } + + SUBCASE("spot termination") + { + Imds.Mock.Aws.LifeCycle = "spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + // Simulate a spot interruption notice appearing + Imds.Mock.Aws.SpotAction 
= R"({"action":"terminate","time":"2025-01-01T00:00:00Z"})"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("spot interruption") != std::string::npos); + } + + SUBCASE("autoscaling termination") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + // Simulate ASG cycling the instance out + Imds.Mock.Aws.AutoscalingState = "Terminated:Wait"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("autoscaling") != std::string::npos); + } + + SUBCASE("no termination when InService") + { + Imds.Mock.Aws.AutoscalingState = "InService"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.azure") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::Azure; + + SUBCASE("detection basics") + { + Imds.Mock.Azure.VmId = "vm-test-1234"; + Imds.Mock.Azure.Location = "westeurope"; + Imds.Mock.Azure.Priority = "Regular"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::Azure); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "vm-test-1234"); + CHECK(Info.AvailabilityZone == "westeurope"); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); + } + + SUBCASE("spot instance") + { + Imds.Mock.Azure.Priority = "Spot"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("vmss instance") + { + 
Imds.Mock.Azure.VmScaleSetName = "my-vmss"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsAutoscaling == true); + } + + SUBCASE("preempt termination") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Azure.ScheduledEventType = "Preempt"; + Imds.Mock.Azure.ScheduledEventStatus = "Scheduled"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("Preempt") != std::string::npos); + } + + SUBCASE("terminate event") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Azure.ScheduledEventType = "Terminate"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("Terminate") != std::string::npos); + } + + SUBCASE("no termination when events empty") + { + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.gcp") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::GCP; + + SUBCASE("detection basics") + { + Imds.Mock.Gcp.InstanceId = "9876543210"; + Imds.Mock.Gcp.Zone = "projects/123/zones/europe-west1-b"; + Imds.Mock.Gcp.Preemptible = "FALSE"; + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::GCP); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId == "9876543210"); + CHECK(Info.AvailabilityZone == "europe-west1-b"); // zone prefix stripped + CHECK(Info.IsSpot == false); + CHECK(Cloud->IsTerminationPending() == 
false); + } + + SUBCASE("preemptible instance") + { + Imds.Mock.Gcp.Preemptible = "TRUE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.IsSpot == true); + } + + SUBCASE("maintenance event during detection") + { + Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + // GCP sets termination pending immediately during detection if a + // maintenance event is active + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos); + } + + SUBCASE("maintenance event during polling") + { + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + CHECK(Cloud->IsTerminationPending() == false); + + Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE"; + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == true); + CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos); + } + + SUBCASE("no termination when NONE") + { + Imds.Mock.Gcp.MaintenanceEvent = "NONE"; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + Cloud->PollTermination(); + + CHECK(Cloud->IsTerminationPending() == false); + } +} + +// --------------------------------------------------------------------------- +// No provider +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.no_provider") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::None; + Imds.Start(); + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::None); + + CloudInstanceInfo Info = Cloud->GetInstanceInfo(); + CHECK(Info.InstanceId.empty()); + CHECK(Info.AvailabilityZone.empty()); + CHECK(Info.IsSpot == false); + CHECK(Info.IsAutoscaling == false); + CHECK(Cloud->IsTerminationPending() == false); +} + +// 
--------------------------------------------------------------------------- +// Sentinel file management +// --------------------------------------------------------------------------- + +TEST_CASE("cloudmetadata.sentinel_files") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = CloudProvider::None; + Imds.Start(); + + auto DataDir = Imds.DataDir(); + + SUBCASE("sentinels are written on failed detection") + { + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::None); + CHECK(zen::IsFile(DataDir / ".isNotAWS")); + CHECK(zen::IsFile(DataDir / ".isNotAzure")); + CHECK(zen::IsFile(DataDir / ".isNotGCP")); + } + + SUBCASE("ClearSentinelFiles removes sentinels") + { + auto Cloud = Imds.CreateCloud(); + + CHECK(zen::IsFile(DataDir / ".isNotAWS")); + CHECK(zen::IsFile(DataDir / ".isNotAzure")); + CHECK(zen::IsFile(DataDir / ".isNotGCP")); + + Cloud->ClearSentinelFiles(); + + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP")); + } + + SUBCASE("only failed providers get sentinels") + { + // Switch to AWS — Azure and GCP never probed, so no sentinels for them + Imds.Mock.ActiveProvider = CloudProvider::AWS; + + auto Cloud = Imds.CreateCloud(); + + CHECK(Cloud->GetProvider() == CloudProvider::AWS); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure")); + CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP")); + } +} + +void +cloudmetadata_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp new file mode 100644 index 000000000..838d741b6 --- /dev/null +++ b/src/zencompute/computeservice.cpp @@ -0,0 +1,2236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 

#include "zencompute/computeservice.h"

#if ZEN_WITH_COMPUTE_SERVICES

# include "runners/functionrunner.h"
# include "recording/actionrecorder.h"
# include "runners/localrunner.h"
# include "runners/remotehttprunner.h"
// Platform-specific process runner implementations
# if ZEN_PLATFORM_LINUX
#  include "runners/linuxrunner.h"
# elif ZEN_PLATFORM_WINDOWS
#  include "runners/windowsrunner.h"
# elif ZEN_PLATFORM_MAC
#  include "runners/macrunner.h"
# endif

# include <zencompute/recordingreader.h>
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinarypackage.h>
# include <zencore/compress.h>
# include <zencore/except.h>
# include <zencore/filesystem.h>
# include <zencore/fmtutils.h>
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
# include <zencore/logging.h>
# include <zencore/scopeguard.h>
# include <zencore/trace.h>
# include <zencore/workthreadpool.h>
# include <zenutil/workerpools.h>
# include <zentelemetry/stats.h>
# include <zenhttp/httpclient.h>

# include <set>
# include <deque>
# include <map>
# include <thread>
# include <unordered_map>
# include <unordered_set>

ZEN_THIRD_PARTY_INCLUDES_START
# include <EASTL/hash_set.h>
ZEN_THIRD_PARTY_INCLUDES_END

using namespace std::literals;

namespace zen {

// Maps a session lifecycle state to a stable display string (used in logs
// and the stats payload). Falls through to "Unknown" for out-of-range
// values so callers never receive a null pointer.
const char*
ToString(compute::ComputeServiceSession::SessionState State)
{
    using enum compute::ComputeServiceSession::SessionState;
    switch (State)
    {
    case Created:
        return "Created";
    case Ready:
        return "Ready";
    case Draining:
        return "Draining";
    case Paused:
        return "Paused";
    case Abandoned:
        return "Abandoned";
    case Sunset:
        return "Sunset";
    }
    return "Unknown";
}

// Maps a queue lifecycle state to its lowercase wire/display string.
// Note the casing differs from SessionState deliberately (these appear in
// API payloads); do not "normalize" them.
const char*
ToString(compute::ComputeServiceSession::QueueState State)
{
    using enum compute::ComputeServiceSession::QueueState;
    switch (State)
    {
    case Active:
        return "active";
    case Draining:
        return "draining";
    case Cancelled:
        return "cancelled";
    }
    return "unknown";
}

} // 
namespace zen + +namespace zen::compute { + +using SessionState = ComputeServiceSession::SessionState; + +static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count)); + +////////////////////////////////////////////////////////////////////////// + +struct ComputeServiceSession::Impl +{ + ComputeServiceSession* m_ComputeServiceSession; + ChunkResolver& m_ChunkResolver; + LoggerRef m_Log{logging::Get("compute")}; + + Impl(ComputeServiceSession* InComputeServiceSession, ChunkResolver& InChunkResolver) + : m_ComputeServiceSession(InComputeServiceSession) + , m_ChunkResolver(InChunkResolver) + , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + { + // Create a non-expiring, non-deletable implicit queue for legacy endpoints + auto Result = CreateQueue("implicit"sv, {}, {}); + m_ImplicitQueueId = Result.QueueId; + m_QueueLock.WithSharedLock([&] { m_Queues[m_ImplicitQueueId]->Implicit = true; }); + + m_SchedulingThread = std::thread{&Impl::SchedulerThreadFunction, this}; + } + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + bool RequestStateTransition(SessionState NewState); + void AbandonAllActions(); + + LoggerRef Log() { return m_Log; } + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + std::string m_OrchestratorEndpoint; + std::filesystem::path m_OrchestratorBasePath; + Stopwatch m_OrchestratorQueryTimer; + std::unordered_set<std::string> m_KnownWorkerUris; + + void UpdateCoordinatorState(); + + // Worker registration and discovery + + struct FunctionDefinition + { + std::string FunctionName; + Guid FunctionVersion; + Guid BuildSystemVersion; + IoHash WorkerId; + }; + + void RegisterWorker(CbPackage Worker); + WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + + // Action scheduling and 
tracking + + std::atomic<SessionState> m_SessionState{SessionState::Created}; + std::atomic<int32_t> m_ActionsCounter = 0; // sequence number + metrics::Meter m_ArrivalRate; + + RwLock m_PendingLock; + std::map<int, Ref<RunnerAction>> m_PendingActions; + + RwLock m_RunningLock; + std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; + + RwLock m_ResultsLock; + std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; + metrics::Meter m_ResultRate; + std::atomic<uint64_t> m_RetiredCount{0}; + + EnqueueResult EnqueueAction(int QueueId, CbObject ActionObject, int Priority); + EnqueueResult EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority); + + void GetCompleted(CbWriter& Cbo); + + HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + std::thread m_SchedulingThread; + std::atomic<bool> m_SchedulingThreadEnabled{true}; + Event m_SchedulingThreadEvent; + + void SchedulerThreadFunction(); + void SchedulePendingActions(); + + // Workers + + RwLock m_WorkerLock; + std::unordered_map<IoHash, CbPackage> m_WorkerMap; + std::vector<FunctionDefinition> m_FunctionList; + std::vector<IoHash> GetKnownWorkerIds(); + void SyncWorkersToRunner(FunctionRunner& Runner); + + // Runners + + DeferredDirectoryDeleter m_DeferredDeleter; + WorkerThreadPool& m_LocalSubmitPool; + WorkerThreadPool& m_RemoteSubmitPool; + RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup; + RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup; + + void ShutdownRunners(); + + // Recording + + void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + void StopRecording(); + + std::unique_ptr<ActionRecorder> m_Recorder; + + // History tracking + + RwLock m_ActionHistoryLock; + std::deque<ComputeServiceSession::ActionHistoryEntry> m_ActionHistory; + size_t m_HistoryLimit = 1000; + + // 
Queue tracking + + using QueueState = ComputeServiceSession::QueueState; + + struct QueueEntry : RefCounted + { + int QueueId; + bool Implicit{false}; + std::atomic<QueueState> State{QueueState::Active}; + std::atomic<int> ActiveCount{0}; // pending + running + std::atomic<int> CompletedCount{0}; // successfully completed + std::atomic<int> FailedCount{0}; // failed + std::atomic<int> AbandonedCount{0}; // abandoned + std::atomic<int> CancelledCount{0}; // cancelled + std::atomic<uint64_t> IdleSince{0}; // hifreq tick when queue became idle; 0 = has active work + + RwLock m_Lock; + std::unordered_set<int> ActiveLsns; // for cancellation lookup + std::unordered_set<int> FinishedLsns; // completed/failed/cancelled LSNs + + std::string Tag; + CbObject Metadata; + CbObject Config; + }; + + int m_ImplicitQueueId{0}; + std::atomic<int> m_QueueCounter{0}; + RwLock m_QueueLock; + std::unordered_map<int, Ref<QueueEntry>> m_Queues; + + Ref<QueueEntry> FindQueue(int QueueId) + { + Ref<QueueEntry> Queue; + m_QueueLock.WithSharedLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + Queue = It->second; + } + }); + return Queue; + } + + ComputeServiceSession::CreateQueueResult CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config); + std::vector<int> GetQueueIds(); + ComputeServiceSession::QueueStatus GetQueueStatus(int QueueId); + CbObject GetQueueMetadata(int QueueId); + CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DeleteQueue(int QueueId); + void DrainQueue(int QueueId); + ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState); + void ExpireCompletedQueues(); + + Stopwatch 
m_QueueExpiryTimer; + + std::vector<ComputeServiceSession::RunningActionInfo> GetRunningActions(); + std::vector<ComputeServiceSession::ActionHistoryEntry> GetActionHistory(int Limit); + std::vector<ComputeServiceSession::ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit); + + // Action submission + + [[nodiscard]] size_t QueryCapacity(); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action); + [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + [[nodiscard]] size_t GetSubmittedActionCount(); + + // Updates + + RwLock m_UpdatedActionsLock; + std::vector<Ref<RunnerAction>> m_UpdatedActions; + + void HandleActionUpdates(); + void PostUpdate(RunnerAction* Action); + + static constexpr int kDefaultMaxRetries = 3; + int GetMaxRetriesForQueue(int QueueId); + + ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + + ActionCounts GetActionCounts() + { + ActionCounts Counts; + Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { + size_t Count = 0; + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + ++Count; + } + } + return Count; + }); + return Counts; + } + + void EmitStats(CbObjectWriter& Cbo) + { + Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); + m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); + m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); + m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + Cbo << "actions_submitted"sv << GetSubmittedActionCount(); + 
EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); + EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); + } +}; + +bool +ComputeServiceSession::Impl::IsHealthy() +{ + return m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned; +} + +bool +ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) +{ + SessionState Current = m_SessionState.load(std::memory_order_relaxed); + + for (;;) + { + if (Current == NewState) + { + return true; + } + + // Validate the transition + bool Valid = false; + + switch (Current) + { + case SessionState::Created: + Valid = (NewState == SessionState::Ready); + break; + case SessionState::Ready: + Valid = (NewState == SessionState::Draining); + break; + case SessionState::Draining: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Paused); + break; + case SessionState::Paused: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Sunset); + break; + case SessionState::Abandoned: + Valid = (NewState == SessionState::Sunset); + break; + case SessionState::Sunset: + Valid = false; + break; + } + + // Allow jumping directly to Abandoned or Sunset from any non-terminal state + if (NewState == SessionState::Abandoned && Current < SessionState::Abandoned) + { + Valid = true; + } + if (NewState == SessionState::Sunset && Current != SessionState::Sunset) + { + Valid = true; + } + + if (!Valid) + { + ZEN_WARN("invalid session state transition {} -> {}", ToString(Current), ToString(NewState)); + return false; + } + + if (m_SessionState.compare_exchange_strong(Current, NewState, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: {} -> {}", ToString(Current), ToString(NewState)); + + if (NewState == SessionState::Abandoned) + { + AbandonAllActions(); + } + + return true; + } + + // CAS failed, Current was updated — retry with the new value + } +} + +void +ComputeServiceSession::Impl::AbandonAllActions() +{ + // Collect all pending actions and mark them as 
Abandoned + std::vector<Ref<RunnerAction>> PendingToAbandon; + + m_PendingLock.WithSharedLock([&] { + PendingToAbandon.reserve(m_PendingActions.size()); + for (auto& [Lsn, Action] : m_PendingActions) + { + PendingToAbandon.push_back(Action); + } + }); + + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + + // Collect all running actions and mark them as Abandoned, then + // best-effort cancel via the local runner group + std::vector<Ref<RunnerAction>> RunningToAbandon; + + m_RunningLock.WithSharedLock([&] { + RunningToAbandon.reserve(m_RunningMap.size()); + for (auto& [Lsn, Action] : m_RunningMap) + { + RunningToAbandon.push_back(Action); + } + }); + + for (auto& Action : RunningToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + m_LocalRunnerGroup.CancelAction(Action->ActionLsn); + } + + ZEN_INFO("abandoned all actions: {} pending, {} running", PendingToAbandon.size(), RunningToAbandon.size()); +} + +void +ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_OrchestratorEndpoint = Endpoint; +} + +void +ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_OrchestratorBasePath = std::move(BasePath); +} + +void +ComputeServiceSession::Impl::UpdateCoordinatorState() +{ + ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); + if (m_OrchestratorEndpoint.empty()) + { + return; + } + + // Poll faster when we have no discovered workers yet so remote runners come online quickly + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 
500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } + + m_OrchestratorQueryTimer.Reset(); + + try + { + HttpClient Client(m_OrchestratorEndpoint); + + HttpClient::Response Response = Client.Get("/orch/agents"); + + if (!Response.IsSuccess()) + { + ZEN_WARN("orchestrator query failed with status {}", static_cast<int>(Response.StatusCode)); + return; + } + + CbObject WorkerList = Response.AsObject(); + + std::unordered_set<std::string> ValidWorkerUris; + + for (auto& Item : WorkerList["workers"sv]) + { + CbObjectView Worker = Item.AsObjectView(); + + uint64_t Dt = Worker["dt"sv].AsUInt64(); + bool Reachable = Worker["reachable"sv].AsBool(); + std::string_view Uri = Worker["uri"sv].AsString(); + + // Skip stale workers (not seen in over 30 seconds) + if (Dt > 30000) + { + continue; + } + + // Skip workers that are not confirmed reachable + if (!Reachable) + { + continue; + } + + std::string UriStr{Uri}; + ValidWorkerUris.insert(UriStr); + + // Skip workers we already know about + if (m_KnownWorkerUris.contains(UriStr)) + { + continue; + } + + ZEN_INFO("discovered new worker at {}", UriStr); + + m_KnownWorkerUris.insert(UriStr); + + auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + SyncWorkersToRunner(*NewRunner); + m_RemoteRunnerGroup.AddRunner(NewRunner); + } + + // Remove workers that are no longer valid (stale or unreachable) + for (auto It = m_KnownWorkerUris.begin(); It != m_KnownWorkerUris.end();) + { + if (!ValidWorkerUris.contains(*It)) + { + const std::string& ExpiredUri = *It; + ZEN_INFO("removing expired worker at {}", ExpiredUri); + + m_RemoteRunnerGroup.RemoveRunnerIf([&](const RemoteHttpRunner& Runner) { return Runner.GetHostName() == ExpiredUri; }); + + It = m_KnownWorkerUris.erase(It); + } + else + { + ++It; + } + } + } + catch (const HttpClientError& Ex) + { + ZEN_WARN("orchestrator query error: {}", Ex.what()); + } + catch (const std::exception& 
Ex) + { + ZEN_WARN("orchestrator query unexpected error: {}", Ex.what()); + } +} + +void +ComputeServiceSession::Impl::WaitUntilReady() +{ + if (m_RemoteRunnerGroup.GetRunnerCount() || !m_OrchestratorEndpoint.empty()) + { + ZEN_INFO("waiting for remote runners..."); + + constexpr int MaxWaitSeconds = 120; + + for (int Elapsed = 0; Elapsed < MaxWaitSeconds; Elapsed++) + { + if (!m_SchedulingThreadEnabled.load(std::memory_order_relaxed)) + { + ZEN_WARN("shutdown requested while waiting for remote runners"); + return; + } + + const size_t Capacity = m_RemoteRunnerGroup.QueryCapacity(); + + if (Capacity > 0) + { + ZEN_INFO("found {} remote runners (capacity: {})", m_RemoteRunnerGroup.GetRunnerCount(), Capacity); + break; + } + + zen::Sleep(1000); + } + } + else + { + ZEN_ASSERT(m_LocalRunnerGroup.GetRunnerCount(), "no runners available"); + } + + RequestStateTransition(SessionState::Ready); +} + +void +ComputeServiceSession::Impl::Shutdown() +{ + RequestStateTransition(SessionState::Sunset); + + m_SchedulingThreadEnabled = false; + m_SchedulingThreadEvent.Set(); + if (m_SchedulingThread.joinable()) + { + m_SchedulingThread.join(); + } + + ShutdownRunners(); + + m_DeferredDeleter.Shutdown(); +} + +void +ComputeServiceSession::Impl::ShutdownRunners() +{ + m_LocalRunnerGroup.Shutdown(); + m_RemoteRunnerGroup.Shutdown(); +} + +void +ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) +{ + ZEN_INFO("starting recording to '{}'", RecordingPath); + + m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); + + ZEN_INFO("started recording to '{}'", RecordingPath); +} + +void +ComputeServiceSession::Impl::StopRecording() +{ + ZEN_INFO("stopping recording"); + + m_Recorder = nullptr; + + ZEN_INFO("stopped recording"); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::Impl::GetRunningActions() +{ + std::vector<ComputeServiceSession::RunningActionInfo> Result; + 
m_RunningLock.WithSharedLock([&] { + Result.reserve(m_RunningMap.size()); + for (const auto& [Lsn, Action] : m_RunningMap) + { + Result.push_back({.Lsn = Lsn, + .QueueId = Action->QueueId, + .ActionId = Action->ActionId, + .CpuUsagePercent = Action->CpuUsagePercent.load(std::memory_order_relaxed), + .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed)}); + } + }); + return Result; +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::Impl::GetActionHistory(int Limit) +{ + RwLock::SharedLockScope _(m_ActionHistoryLock); + + if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size()) + { + return std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end()); + } + + return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end()); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::Impl::GetQueueHistory(int QueueId, int Limit) +{ + // Resolve the queue and snapshot its finished LSN set + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + std::unordered_set<int> FinishedLsns; + + Queue->m_Lock.WithSharedLock([&] { FinishedLsns = Queue->FinishedLsns; }); + + // Filter the global history to entries belonging to this queue. + // m_ActionHistory is ordered oldest-first, so the filtered result keeps the same ordering. 
+ std::vector<ActionHistoryEntry> Result; + + m_ActionHistoryLock.WithSharedLock([&] { + for (const auto& Entry : m_ActionHistory) + { + if (FinishedLsns.contains(Entry.Lsn)) + { + Result.push_back(Entry); + } + } + }); + + if (Limit > 0 && static_cast<size_t>(Limit) < Result.size()) + { + Result.erase(Result.begin(), Result.end() - Limit); + } + + return Result; +} + +void +ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) +{ + ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + RwLock::ExclusiveLockScope _(m_WorkerLock); + + const IoHash& WorkerId = Worker.GetObject().GetHash(); + + if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) + { + // Note that since the convention currently is that WorkerId is equal to the hash + // of the worker descriptor there is no chance that we get a second write with a + // different descriptor. Thus we only need to call this the first time, when the + // worker is added + + m_LocalRunnerGroup.RegisterWorker(Worker); + m_RemoteRunnerGroup.RegisterWorker(Worker); + + if (m_Recorder) + { + m_Recorder->RegisterWorker(Worker); + } + + CbObject WorkerObj = Worker.GetObject(); + + // Populate worker database + + const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerObj["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } + } +} + +void +ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) +{ + ZEN_TRACE_CPU("SyncWorkersToRunner"); + + std::vector<CbPackage> Workers; + + { + RwLock::SharedLockScope _(m_WorkerLock); + Workers.reserve(m_WorkerMap.size()); + for (const auto& [Id, Pkg] : 
m_WorkerMap) + { + Workers.push_back(Pkg); + } + } + + for (const CbPackage& Worker : Workers) + { + Runner.RegisterWorker(Worker); + } +} + +WorkerDesc +ComputeServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) +{ + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + const CbPackage& Desc = It->second; + return {Desc, WorkerId}; + } + + return {}; +} + +std::vector<IoHash> +ComputeServiceSession::Impl::GetKnownWorkerIds() +{ + std::vector<IoHash> WorkerIds; + + m_WorkerLock.WithSharedLock([&] { + WorkerIds.reserve(m_WorkerMap.size()); + for (const auto& [WorkerId, _] : m_WorkerMap) + { + WorkerIds.push_back(WorkerId); + } + }); + + return WorkerIds; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueAction(int QueueId, CbObject ActionObject, int Priority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueAction"); + + // Resolve function to worker + + IoHash WorkerId{IoHash::Zero}; + CbPackage WorkerPackage; + + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + m_WorkerLock.WithSharedLock([&] { + for (const FunctionDefinition& FuncDef : m_FunctionList) + { + if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion && + FuncDef.BuildSystemVersion == BuildSystemVersion) + { + WorkerId = FuncDef.WorkerId; + + break; + } + } + + if (WorkerId != IoHash::Zero) + { + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + WorkerPackage = It->second; + } + } + }); + + if (WorkerId == IoHash::Zero) + { + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker matches the action specification"; + + return {0, 
Writer.Save()}; + } + + if (WorkerPackage) + { + return EnqueueResolvedAction(QueueId, WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); + } + + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker found despite match"; + + return {0, Writer.Save()}; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueResolvedAction"); + + if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) + { + CbObjectWriter Writer; + Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); + return {0, Writer.Save()}; + } + + const int ActionLsn = ++m_ActionsCounter; + + m_ArrivalRate.Mark(); + + Ref<RunnerAction> Pending{new RunnerAction(m_ComputeServiceSession)}; + + Pending->ActionLsn = ActionLsn; + Pending->QueueId = QueueId; + Pending->Worker = Worker; + Pending->ActionId = ActionObj.GetHash(); + Pending->ActionObj = ActionObj; + Pending->Priority = RequestPriority; + + // For now simply put action into pending state, so we can do batch scheduling + + ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + + Pending->SetActionState(RunnerAction::State::Pending); + + if (m_Recorder) + { + m_Recorder->RecordAction(Pending); + } + + CbObjectWriter Writer; + Writer << "lsn" << Pending->ActionLsn; + Writer << "worker" << Pending->Worker.WorkerId; + Writer << "action" << Pending->ActionId; + + return {Pending->ActionLsn, Writer.Save()}; +} + +SubmitResult +ComputeServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action) +{ + // Loosely round-robin scheduling of actions across runners. 
+ // + // It's not entirely clear what this means given that submits + // can come in across multiple threads, but it's probably better + // than always starting with the first runner. + // + // Longer term we should track the state of the individual + // runners and make decisions accordingly. + + SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); + if (Result.IsAccepted) + { + return Result; + } + + return m_RemoteRunnerGroup.SubmitAction(Action); +} + +size_t +ComputeServiceSession::Impl::GetSubmittedActionCount() +{ + return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); +} + +HttpResponseCode +ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) + { + return HttpResponseCode::Accepted; + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) + { + return HttpResponseCode::Accepted; + } + } + + return HttpResponseCode::NotFound; +} + +HttpResponseCode +ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // 
different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) + { + if (It->second->ActionId == ActionId) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + for (const auto& [K, Pending] : m_PendingActions) + { + if (Pending->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + for (const auto& [K, v] : m_RunningMap) + { + if (v->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + return HttpResponseCode::NotFound; +} + +void +ComputeServiceSession::Impl::RetireActionResult(int ActionLsn) +{ + m_DeferredDeleter.MarkReady(ActionLsn); +} + +void +ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) +{ + Cbo.BeginArray("completed"); + + m_ResultsLock.WithSharedLock([&] { + for (auto& [Lsn, Action] : m_ResultsMap) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Lsn; + Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); +} + +////////////////////////////////////////////////////////////////////////// +// Queue management + +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::Impl::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + const int QueueId = ++m_QueueCounter; + + Ref<QueueEntry> Queue{new QueueEntry()}; + Queue->QueueId = QueueId; + Queue->Tag = Tag; + Queue->Metadata = std::move(Metadata); + Queue->Config = std::move(Config); + Queue->IdleSince = GetHifreqTimerValue(); + + m_QueueLock.WithExclusiveLock([&] { m_Queues[QueueId] = Queue; }); + + ZEN_DEBUG("created queue {}", QueueId); + + return 
{.QueueId = QueueId}; +} + +std::vector<int> +ComputeServiceSession::Impl::GetQueueIds() +{ + std::vector<int> Ids; + + m_QueueLock.WithSharedLock([&] { + Ids.reserve(m_Queues.size()); + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + Ids.push_back(Id); + } + } + }); + + return Ids; +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::Impl::GetQueueStatus(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + const int Active = Queue->ActiveCount.load(std::memory_order_relaxed); + const int Completed = Queue->CompletedCount.load(std::memory_order_relaxed); + const int Failed = Queue->FailedCount.load(std::memory_order_relaxed); + const int AbandonedN = Queue->AbandonedCount.load(std::memory_order_relaxed); + const int CancelledN = Queue->CancelledCount.load(std::memory_order_relaxed); + const QueueState QState = Queue->State.load(); + + return { + .IsValid = true, + .QueueId = QueueId, + .ActiveCount = Active, + .CompletedCount = Completed, + .FailedCount = Failed, + .AbandonedCount = AbandonedN, + .CancelledCount = CancelledN, + .State = QState, + .IsComplete = (Active == 0), + }; +} + +CbObject +ComputeServiceSession::Impl::GetQueueMetadata(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Metadata; +} + +CbObject +ComputeServiceSession::Impl::GetQueueConfig(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Config; +} + +void +ComputeServiceSession::Impl::CancelQueue(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + Queue->State.store(QueueState::Cancelled); + + // Collect active LSNs snapshot for cancellation + std::vector<int> LsnsToCancel; + + Queue->m_Lock.WithSharedLock([&] { LsnsToCancel.assign(Queue->ActiveLsns.begin(), Queue->ActiveLsns.end()); }); + + // Identify which 
LSNs are still pending (not yet dispatched to a runner) + std::vector<Ref<RunnerAction>> PendingActionsToCancel; + std::vector<int> RunningLsnsToCancel; + + m_PendingLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) + { + PendingActionsToCancel.push_back(It->second); + } + } + }); + + m_RunningLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + { + RunningLsnsToCancel.push_back(Lsn); + } + } + }); + + // Cancel pending actions by marking them as Cancelled; they will flow through + // HandleActionUpdates and eventually be removed from the pending map. + for (auto& Action : PendingActionsToCancel) + { + Action->SetActionState(RunnerAction::State::Cancelled); + } + + // Best-effort cancellation of running actions via the local runner group. + // Also set the action state to Cancelled directly so a subsequent Failed + // transition from the runner is blocked (Cancelled > Failed in the enum). 
+ for (int Lsn : RunningLsnsToCancel) + { + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) + { + It->second->SetActionState(RunnerAction::State::Cancelled); + } + }); + m_LocalRunnerGroup.CancelAction(Lsn); + } + + m_RemoteRunnerGroup.CancelRemoteQueue(QueueId); + + ZEN_INFO("cancelled queue {}: {} pending, {} running actions cancelled", + QueueId, + PendingActionsToCancel.size(), + RunningLsnsToCancel.size()); + + // Wake up the scheduler to process the cancelled actions + m_SchedulingThreadEvent.Set(); +} + +void +ComputeServiceSession::Impl::DeleteQueue(int QueueId) +{ + // Never delete the implicit queue + { + Ref<QueueEntry> Queue = FindQueue(QueueId); + if (Queue && Queue->Implicit) + { + return; + } + } + + // Cancel any active work first + CancelQueue(QueueId); + + m_QueueLock.WithExclusiveLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + m_Queues.erase(It); + } + }); +} + +void +ComputeServiceSession::Impl::DrainQueue(int QueueId) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + QueueState Expected = QueueState::Active; + Queue->State.compare_exchange_strong(Expected, QueueState::Draining); + ZEN_INFO("draining queue {}", QueueId); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = 
EnqueueAction(QueueId, ActionObject, Priority); + + if (Result.Lsn != 0) + { + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + } + + return Result; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority); + + if (Result.Lsn != 0) + { + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + } + + return Result; +} + +void +ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + Ref<QueueEntry> Queue = FindQueue(QueueId); + + Cbo.BeginArray("completed"); + + if (Queue) + { + Queue->m_Lock.WithSharedLock([&] { + m_ResultsLock.WithSharedLock([&] { + for (int Lsn : Queue->FinishedLsns) + { + if (m_ResultsMap.contains(Lsn)) + { + Cbo << Lsn; + } + } + }); + }); + } + + Cbo.EndArray(); +} + +void +ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState) +{ + if (QueueId == 0) + { + return; + } + + Ref<QueueEntry> Queue = FindQueue(QueueId); + + if (!Queue) + { + return; + } + + 
Queue->m_Lock.WithExclusiveLock([&] { + Queue->ActiveLsns.erase(Lsn); + Queue->FinishedLsns.insert(Lsn); + }); + + const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); + if (PreviousActive == 1) + { + Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + } + + switch (ActionState) + { + case RunnerAction::State::Completed: + Queue->CompletedCount.fetch_add(1, std::memory_order_relaxed); + break; + case RunnerAction::State::Abandoned: + Queue->AbandonedCount.fetch_add(1, std::memory_order_relaxed); + break; + case RunnerAction::State::Cancelled: + Queue->CancelledCount.fetch_add(1, std::memory_order_relaxed); + break; + default: + Queue->FailedCount.fetch_add(1, std::memory_order_relaxed); + break; + } +} + +void +ComputeServiceSession::Impl::ExpireCompletedQueues() +{ + static constexpr uint64_t ExpiryTimeMs = 15 * 60 * 1000; + + std::vector<int> ExpiredQueueIds; + + m_QueueLock.WithSharedLock([&] { + for (const auto& [Id, Queue] : m_Queues) + { + if (Queue->Implicit) + { + continue; + } + const uint64_t Idle = Queue->IdleSince.load(std::memory_order_relaxed); + if (Idle != 0 && Queue->ActiveCount.load(std::memory_order_relaxed) == 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(GetHifreqTimerValue() - Idle); + if (ElapsedMs >= ExpiryTimeMs) + { + ExpiredQueueIds.push_back(Id); + } + } + } + }); + + for (int QueueId : ExpiredQueueIds) + { + ZEN_INFO("expiring idle queue {}", QueueId); + DeleteQueue(QueueId); + } +} + +void +ComputeServiceSession::Impl::SchedulePendingActions() +{ + ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); + int ScheduledCount = 0; + size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + + static Stopwatch DumpRunningTimer; + + auto _ = 
MakeGuard([&] { + ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); + + if (DumpRunningTimer.GetElapsedTimeMs() > 30000) + { + DumpRunningTimer.Reset(); + + std::set<int> RunningList; + m_RunningLock.WithSharedLock([&] { + for (auto& [K, V] : m_RunningMap) + { + RunningList.insert(K); + } + }); + + ExtendableStringBuilder<1024> RunningString; + for (int i : RunningList) + { + if (RunningString.Size()) + { + RunningString << ", "; + } + + RunningString.Append(IntNum(i)); + } + + ZEN_INFO("running: {}", RunningString); + } + }); + + size_t Capacity = QueryCapacity(); + + if (!Capacity) + { + _.Dismiss(); + + return; + } + + std::vector<Ref<RunnerAction>> ActionsToSchedule; + + // Pull actions to schedule from the pending queue, we will + // try to submit these to the runner outside of the lock. Note + // that because of how the state transitions work it's not + // actually the case that all of these actions will still be + // pending by the time we try to submit them, but that's fine. 
+	//
+	// Also note that the m_PendingActions list is not maintained
+	// here, that's done periodically in HandleActionUpdates()
+
+	m_PendingLock.WithExclusiveLock([&] {
+		if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
+		{
+			return;
+		}
+
+		if (m_PendingActions.empty())
+		{
+			return;
+		}
+
+		for (auto& [Lsn, Pending] : m_PendingActions)
+		{
+			switch (Pending->ActionState())
+			{
+				case RunnerAction::State::Pending:
+					ActionsToSchedule.push_back(Pending);
+					break;
+
+				case RunnerAction::State::Submitting:
+					break; // already claimed by async submission
+
+				case RunnerAction::State::Running:
+				case RunnerAction::State::Completed:
+				case RunnerAction::State::Failed:
+				case RunnerAction::State::Abandoned:
+				case RunnerAction::State::Cancelled:
+					break;
+
+				default:
+				case RunnerAction::State::New:
+					ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn);
+					break;
+			}
+		}
+
+		// Sort by priority descending, then by LSN ascending (FIFO within same priority)
+		std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
+			if (A->Priority != B->Priority)
+			{
+				return A->Priority > B->Priority;
+			}
+			return A->ActionLsn < B->ActionLsn;
+		});
+
+		if (ActionsToSchedule.size() > Capacity)
+		{
+			ActionsToSchedule.resize(Capacity);
+		}
+
+		PendingCount = m_PendingActions.size();
+	});
+
+	if (ActionsToSchedule.empty())
+	{
+		_.Dismiss();
+		return;
+	}
+
+	ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());
+
+	Stopwatch                 SubmitTimer;
+	std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);
+
+	int NotAcceptedCount = 0;
+	int ScheduledActionCount = 0;
+
+	for (const SubmitResult& SubResult : SubmitResults)
+	{
+		if (SubResult.IsAccepted)
+		{
+			++ScheduledActionCount;
+		}
+		else
+		{
+			++NotAcceptedCount;
+		}
+	}
+
+	ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
ScheduledActionCount, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); + + ScheduledCount += ScheduledActionCount; + PendingCount -= ScheduledActionCount; +} + +void +ComputeServiceSession::Impl::SchedulerThreadFunction() +{ + SetCurrentThreadName("Function_Scheduler"); + + auto _ = MakeGuard([&] { ZEN_INFO("scheduler thread exiting"); }); + + do + { + int TimeoutMs = 500; + + auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + + if (PendingCount) + { + TimeoutMs = 100; + } + + const bool WasSignaled = m_SchedulingThreadEvent.Wait(TimeoutMs); + + if (m_SchedulingThreadEnabled == false) + { + return; + } + + if (WasSignaled) + { + m_SchedulingThreadEvent.Reset(); + } + + ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", + m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), + PendingCount, + m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), + m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), + TimeoutMs); + + HandleActionUpdates(); + + // Auto-transition Draining → Paused when all work is done + if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) + { + size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + + if (Pending == 0 && Running == 0) + { + SessionState Expected = SessionState::Draining; + if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: Draining -> Paused (all work completed)"); + } + } + } + + UpdateCoordinatorState(); + SchedulePendingActions(); + + static constexpr uint64_t QueueExpirySweepIntervalMs = 30000; + if (m_QueueExpiryTimer.GetElapsedTimeMs() >= QueueExpirySweepIntervalMs) + { + m_QueueExpiryTimer.Reset(); + ExpireCompletedQueues(); + } + } while 
(m_SchedulingThreadEnabled); +} + +void +ComputeServiceSession::Impl::PostUpdate(RunnerAction* Action) +{ + m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); + m_SchedulingThreadEvent.Set(); +} + +int +ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) +{ + if (QueueId == 0) + { + return kDefaultMaxRetries; + } + + CbObject Config = GetQueueConfig(QueueId); + + if (Config) + { + int Value = Config["max_retries"].AsInt32(0); + + if (Value > 0) + { + return Value; + } + } + + return kDefaultMaxRetries; +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) +{ + Ref<RunnerAction> Action; + RunnerAction::State State; + RescheduleResult ValidationError; + bool Removed = false; + + // Find, validate, and remove atomically under a single lock scope to prevent + // concurrent RescheduleAction calls from double-removing the same action. + m_ResultsLock.WithExclusiveLock([&] { + auto It = m_ResultsMap.find(ActionLsn); + if (It == m_ResultsMap.end()) + { + ValidationError = {.Success = false, .Error = "Action not found in results"}; + return; + } + + Action = It->second; + State = Action->ActionState(); + + if (State != RunnerAction::State::Failed && State != RunnerAction::State::Abandoned) + { + ValidationError = {.Success = false, .Error = "Action is not in a failed or abandoned state"}; + return; + } + + int MaxRetries = GetMaxRetriesForQueue(Action->QueueId); + if (Action->RetryCount.load(std::memory_order_relaxed) >= MaxRetries) + { + ValidationError = {.Success = false, .Error = "Retry limit reached"}; + return; + } + + m_ResultsMap.erase(It); + Removed = true; + }); + + if (!Removed) + { + return ValidationError; + } + + if (Action->QueueId != 0) + { + Ref<QueueEntry> Queue = FindQueue(Action->QueueId); + + if (Queue) + { + Queue->m_Lock.WithExclusiveLock([&] { + Queue->FinishedLsns.erase(ActionLsn); + Queue->ActiveLsns.insert(ActionLsn); + }); + + 
Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + + if (State == RunnerAction::State::Failed) + { + Queue->FailedCount.fetch_sub(1, std::memory_order_relaxed); + } + else + { + Queue->AbandonedCount.fetch_sub(1, std::memory_order_relaxed); + } + } + } + + // Reset action state — this calls PostUpdate() internally + Action->ResetActionStateToPending(); + + int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); + ZEN_INFO("action {} ({}) manually rescheduled (retry {})", Action->ActionId, ActionLsn, NewRetryCount); + + return {.Success = true, .RetryCount = NewRetryCount}; +} + +void +ComputeServiceSession::Impl::HandleActionUpdates() +{ + ZEN_TRACE_CPU("ComputeServiceSession::HandleActionUpdates"); + + // Drain the update queue atomically + std::vector<Ref<RunnerAction>> UpdatedActions; + m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); + + std::unordered_set<int> SeenLsn; + + // Process each action's latest state, deduplicating by LSN. + // + // This is safe because state transitions are monotonically increasing by enum + // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so + // SetActionState rejects any transition to a lower-ranked state. By the time + // we read ActionState() here, it reflects the highest state reached — making + // the first occurrence per LSN authoritative and duplicates redundant. 
+ for (Ref<RunnerAction>& Action : UpdatedActions) + { + const int ActionLsn = Action->ActionLsn; + + if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) + { + switch (Action->ActionState()) + { + // Newly enqueued — add to pending map for scheduling + case RunnerAction::State::Pending: + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + break; + + // Async submission in progress — remains in pending map + case RunnerAction::State::Submitting: + break; + + // Dispatched to a runner — move from pending to running + case RunnerAction::State::Running: + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); + }); + }); + ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); + break; + + // Terminal states — move to results, record history, notify queue + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + case RunnerAction::State::Abandoned: + case RunnerAction::State::Cancelled: + { + auto TerminalState = Action->ActionState(); + + // Automatic retry for Failed/Abandoned actions with retries remaining. + // Skip retries when the session itself is abandoned — those actions + // were intentionally abandoned and should not be rescheduled. 
+ if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) && + m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned) + { + int MaxRetries = GetMaxRetriesForQueue(Action->QueueId); + + if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) + { + // Remove from whichever active map the action is in before resetting + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + // Reset triggers PostUpdate() which re-enters the action as Pending + Action->ResetActionStateToPending(); + int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); + + ZEN_INFO("action {} ({}) auto-rescheduled (retry {}/{})", + Action->ActionId, + ActionLsn, + NewRetryCount, + MaxRetries); + break; + } + } + + // Remove from whichever active map the action is in + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + m_ResultsLock.WithExclusiveLock([&] { + m_ResultsMap[ActionLsn] = Action; + + // Append to bounded action history ring + m_ActionHistoryLock.WithExclusiveLock([&] { + ActionHistoryEntry Entry{.Lsn = ActionLsn, + .QueueId = Action->QueueId, + .ActionId = Action->ActionId, + .WorkerId = Action->Worker.WorkerId, + .ActionDescriptor = Action->ActionObj, + .ExecutionLocation = std::move(Action->ExecutionLocation), + .Succeeded = TerminalState == RunnerAction::State::Completed, + .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed), + .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; + + std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), 
std::begin(Entry.Timestamps)); + + m_ActionHistory.push_back(std::move(Entry)); + + if (m_ActionHistory.size() > m_HistoryLimit) + { + m_ActionHistory.pop_front(); + } + }); + }); + m_RetiredCount.fetch_add(1); + m_ResultRate.Mark(1); + ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", + Action->ActionId, + ActionLsn, + TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + break; + } + } + } + } +} + +size_t +ComputeServiceSession::Impl::QueryCapacity() +{ + return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); +} + +std::vector<SubmitResult> +ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); + std::vector<SubmitResult> Results(Actions.size()); + + // First try submitting the batch to local runners in parallel + + std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); + std::vector<size_t> RemoteIndices; + std::vector<Ref<RunnerAction>> RemoteActions; + + for (size_t i = 0; i < Actions.size(); ++i) + { + if (LocalResults[i].IsAccepted) + { + Results[i] = std::move(LocalResults[i]); + } + else + { + RemoteIndices.push_back(i); + RemoteActions.push_back(Actions[i]); + } + } + + // Submit remaining actions to remote runners in parallel + if (!RemoteActions.empty()) + { + std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteIndices.size(); ++j) + { + Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + } + } + + return Results; +} + +////////////////////////////////////////////////////////////////////////// + +ComputeServiceSession::ComputeServiceSession(ChunkResolver& InChunkResolver) +{ + m_Impl = std::make_unique<Impl>(this, InChunkResolver); +} + +ComputeServiceSession::~ComputeServiceSession() +{ + Shutdown(); +} + +bool 
+ComputeServiceSession::IsHealthy() +{ + return m_Impl->IsHealthy(); +} + +void +ComputeServiceSession::WaitUntilReady() +{ + m_Impl->WaitUntilReady(); +} + +void +ComputeServiceSession::Shutdown() +{ + m_Impl->Shutdown(); +} + +ComputeServiceSession::SessionState +ComputeServiceSession::GetSessionState() const +{ + return m_Impl->m_SessionState.load(std::memory_order_relaxed); +} + +bool +ComputeServiceSession::RequestStateTransition(SessionState NewState) +{ + return m_Impl->RequestStateTransition(NewState); +} + +void +ComputeServiceSession::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_Impl->SetOrchestratorEndpoint(Endpoint); +} + +void +ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_Impl->SetOrchestratorBasePath(std::move(BasePath)); +} + +void +ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) +{ + m_Impl->StartRecording(InResolver, RecordingPath); +} + +void +ComputeServiceSession::StopRecording() +{ + m_Impl->StopRecording(); +} + +ComputeServiceSession::ActionCounts +ComputeServiceSession::GetActionCounts() +{ + return m_Impl->GetActionCounts(); +} + +void +ComputeServiceSession::EmitStats(CbObjectWriter& Cbo) +{ + m_Impl->EmitStats(Cbo); +} + +std::vector<IoHash> +ComputeServiceSession::GetKnownWorkerIds() +{ + return m_Impl->GetKnownWorkerIds(); +} + +WorkerDesc +ComputeServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) +{ + return m_Impl->GetWorkerDescriptor(WorkerId); +} + +void +ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddLocalRunner"); + +# if ZEN_PLATFORM_LINUX + auto* NewRunner = new LinuxProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_WINDOWS + auto* NewRunner = new 
WindowsProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_MAC + auto* NewRunner = + new MacProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, false, MaxConcurrentActions); +# endif + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void +ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); + + auto* NewRunner = new RemoteHttpRunner(InChunkResolver, BasePath, HostName, m_Impl->m_RemoteSubmitPool); + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_RemoteRunnerGroup.AddRunner(NewRunner); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueAction(CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(m_Impl->m_ImplicitQueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(m_Impl->m_ImplicitQueueId, Worker, ActionObj, RequestPriority); +} +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + return m_Impl->CreateQueue(Tag, std::move(Metadata), std::move(Config)); +} + +CbObject +ComputeServiceSession::GetQueueMetadata(int QueueId) +{ + return m_Impl->GetQueueMetadata(QueueId); +} + +CbObject +ComputeServiceSession::GetQueueConfig(int QueueId) +{ + return m_Impl->GetQueueConfig(QueueId); +} + +std::vector<int> +ComputeServiceSession::GetQueueIds() +{ + return m_Impl->GetQueueIds(); +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::GetQueueStatus(int QueueId) +{ + return m_Impl->GetQueueStatus(QueueId); +} + +void 
+ComputeServiceSession::CancelQueue(int QueueId) +{ + m_Impl->CancelQueue(QueueId); +} + +void +ComputeServiceSession::DrainQueue(int QueueId) +{ + m_Impl->DrainQueue(QueueId); +} + +void +ComputeServiceSession::DeleteQueue(int QueueId) +{ + m_Impl->DeleteQueue(QueueId); +} + +void +ComputeServiceSession::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + m_Impl->GetQueueCompleted(QueueId, Cbo); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(QueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority); +} + +void +ComputeServiceSession::RegisterWorker(CbPackage Worker) +{ + m_Impl->RegisterWorker(Worker); +} + +HttpResponseCode +ComputeServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + return m_Impl->GetActionResult(ActionLsn, OutResultPackage); +} + +HttpResponseCode +ComputeServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + return m_Impl->FindActionResult(ActionId, OutResultPackage); +} + +void +ComputeServiceSession::RetireActionResult(int ActionLsn) +{ + m_Impl->RetireActionResult(ActionLsn); +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RescheduleAction(int ActionLsn) +{ + return m_Impl->RescheduleAction(ActionLsn); +} + +std::vector<ComputeServiceSession::RunningActionInfo> +ComputeServiceSession::GetRunningActions() +{ + return m_Impl->GetRunningActions(); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> +ComputeServiceSession::GetActionHistory(int Limit) +{ + return m_Impl->GetActionHistory(Limit); +} + +std::vector<ComputeServiceSession::ActionHistoryEntry> 
+ComputeServiceSession::GetQueueHistory(int QueueId, int Limit) +{ + return m_Impl->GetQueueHistory(QueueId, Limit); +} + +void +ComputeServiceSession::GetCompleted(CbWriter& Cbo) +{ + m_Impl->GetCompleted(Cbo); +} + +void +ComputeServiceSession::PostUpdate(RunnerAction* Action) +{ + m_Impl->PostUpdate(Action); +} + +////////////////////////////////////////////////////////////////////////// + +void +computeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/functionrunner.cpp b/src/zencompute/functionrunner.cpp deleted file mode 100644 index 8e7c12b2b..000000000 --- a/src/zencompute/functionrunner.cpp +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "functionrunner.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <zencore/compactbinary.h> -# include <zencore/filesystem.h> - -# include <fmt/format.h> -# include <vector> - -namespace zen::compute { - -FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") -{ -} - -FunctionRunner::~FunctionRunner() = default; - -size_t -FunctionRunner::QueryCapacity() -{ - return 1; -} - -std::vector<SubmitResult> -FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) -{ - std::vector<SubmitResult> Results; - Results.reserve(Actions.size()); - - for (const Ref<RunnerAction>& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -void -FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) -{ - if (m_DumpActions) - { - std::string UniqueId = fmt::format("{}.ddb", ActionLsn); - std::filesystem::path Path = m_ActionsPath / UniqueId; - - zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); - } -} - -////////////////////////////////////////////////////////////////////////// - -RunnerAction::RunnerAction(FunctionServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) -{ - 
this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks(); -} - -RunnerAction::~RunnerAction() -{ -} - -void -RunnerAction::SetActionState(State NewState) -{ - ZEN_ASSERT(NewState < State::_Count); - this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks(); - - do - { - if (State CurrentState = m_ActionState.load(); CurrentState == NewState) - { - // No state change - return; - } - else - { - if (NewState <= CurrentState) - { - // Cannot transition to an earlier or same state - return; - } - - if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) - { - // Successful state change - - m_OwnerSession->PostUpdate(this); - - return; - } - } - } while (true); -} - -void -RunnerAction::SetResult(CbPackage&& Result) -{ - m_Result = std::move(Result); -} - -CbPackage& -RunnerAction::GetResult() -{ - ZEN_ASSERT(IsCompleted()); - return m_Result; -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/functionrunner.h b/src/zencompute/functionrunner.h deleted file mode 100644 index 6fd0d84cc..000000000 --- a/src/zencompute/functionrunner.h +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencompute/functionservice.h> - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <filesystem> -# include <vector> - -namespace zen::compute { - -struct SubmitResult -{ - bool IsAccepted = false; - std::string Reason; -}; - -/** Base interface for classes implementing a remote execution "runner" - */ -class FunctionRunner : public RefCounted -{ - FunctionRunner(FunctionRunner&&) = delete; - FunctionRunner& operator=(FunctionRunner&&) = delete; - -public: - FunctionRunner(std::filesystem::path BasePath); - virtual ~FunctionRunner() = 0; - - virtual void Shutdown() = 0; - virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; - - [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; - [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; - [[nodiscard]] virtual bool IsHealthy() = 0; - [[nodiscard]] virtual size_t QueryCapacity(); - [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); - -protected: - std::filesystem::path m_ActionsPath; - bool m_DumpActions = false; - void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); -}; - -template<typename RunnerType> -struct RunnerGroup -{ - void AddRunner(RunnerType* Runner) - { - m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); - } - size_t QueryCapacity() - { - size_t TotalCapacity = 0; - m_RunnersLock.WithSharedLock([&] { - for (const auto& Runner : m_Runners) - { - TotalCapacity += Runner->QueryCapacity(); - } - }); - return TotalCapacity; - } - - SubmitResult SubmitAction(Ref<RunnerAction> Action) - { - RwLock::SharedLockScope _(m_RunnersLock); - - const int InitialIndex = 
m_NextSubmitIndex.load(std::memory_order_acquire); - int Index = InitialIndex; - const int RunnerCount = gsl::narrow<int>(m_Runners.size()); - - if (RunnerCount == 0) - { - return {.IsAccepted = false, .Reason = "No runners available"}; - } - - do - { - while (Index >= RunnerCount) - { - Index -= RunnerCount; - } - - auto& Runner = m_Runners[Index++]; - - SubmitResult Result = Runner->SubmitAction(Action); - - if (Result.IsAccepted == true) - { - m_NextSubmitIndex = Index % RunnerCount; - - return Result; - } - - while (Index >= RunnerCount) - { - Index -= RunnerCount; - } - } while (Index != InitialIndex); - - return {.IsAccepted = false}; - } - - size_t GetSubmittedActionCount() - { - RwLock::SharedLockScope _(m_RunnersLock); - - size_t TotalCount = 0; - - for (const auto& Runner : m_Runners) - { - TotalCount += Runner->GetSubmittedActionCount(); - } - - return TotalCount; - } - - void RegisterWorker(CbPackage Worker) - { - RwLock::SharedLockScope _(m_RunnersLock); - - for (auto& Runner : m_Runners) - { - Runner->RegisterWorker(Worker); - } - } - - void Shutdown() - { - RwLock::SharedLockScope _(m_RunnersLock); - - for (auto& Runner : m_Runners) - { - Runner->Shutdown(); - } - } - -private: - RwLock m_RunnersLock; - std::vector<Ref<RunnerType>> m_Runners; - std::atomic<int> m_NextSubmitIndex{0}; -}; - -/** - * This represents an action going through different stages of scheduling and execution. 
- */ -struct RunnerAction : public RefCounted -{ - explicit RunnerAction(FunctionServiceSession* OwnerSession); - ~RunnerAction(); - - int ActionLsn = 0; - WorkerDesc Worker; - IoHash ActionId; - CbObject ActionObj; - int Priority = 0; - - enum class State - { - New, - Pending, - Running, - Completed, - Failed, - _Count - }; - - static const char* ToString(State _) - { - switch (_) - { - case State::New: - return "New"; - case State::Pending: - return "Pending"; - case State::Running: - return "Running"; - case State::Completed: - return "Completed"; - case State::Failed: - return "Failed"; - default: - return "Unknown"; - } - } - - uint64_t Timestamps[static_cast<int>(State::_Count)] = {}; - - State ActionState() const { return m_ActionState; } - void SetActionState(State NewState); - - bool IsSuccess() const { return ActionState() == State::Completed; } - bool IsCompleted() const { return ActionState() == State::Completed || ActionState() == State::Failed; } - - void SetResult(CbPackage&& Result); - CbPackage& GetResult(); - -private: - std::atomic<State> m_ActionState = State::New; - FunctionServiceSession* m_OwnerSession = nullptr; - CbPackage m_Result; -}; - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/functionservice.cpp b/src/zencompute/functionservice.cpp deleted file mode 100644 index 0698449e9..000000000 --- a/src/zencompute/functionservice.cpp +++ /dev/null @@ -1,957 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "zencompute/functionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" -# include "actionrecorder.h" -# include "localrunner.h" -# include "remotehttprunner.h" - -# include <zencompute/recordingreader.h> -# include <zencore/compactbinary.h> -# include <zencore/compactbinarybuilder.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/compress.h> -# include <zencore/except.h> -# include <zencore/filesystem.h> -# include <zencore/fmtutils.h> -# include <zencore/iobuffer.h> -# include <zencore/iohash.h> -# include <zencore/logging.h> -# include <zencore/scopeguard.h> -# include <zentelemetry/stats.h> - -# include <set> -# include <deque> -# include <map> -# include <thread> -# include <unordered_map> - -ZEN_THIRD_PARTY_INCLUDES_START -# include <EASTL/hash_set.h> -ZEN_THIRD_PARTY_INCLUDES_END - -using namespace std::literals; - -namespace zen::compute { - -////////////////////////////////////////////////////////////////////////// - -struct FunctionServiceSession::Impl -{ - FunctionServiceSession* m_FunctionServiceSession; - ChunkResolver& m_ChunkResolver; - LoggerRef m_Log{logging::Get("apply")}; - - Impl(FunctionServiceSession* InFunctionServiceSession, ChunkResolver& InChunkResolver) - : m_FunctionServiceSession(InFunctionServiceSession) - , m_ChunkResolver(InChunkResolver) - { - m_SchedulingThread = std::thread{&Impl::MonitorThreadFunction, this}; - } - - void Shutdown(); - bool IsHealthy(); - - LoggerRef Log() { return m_Log; } - - std::atomic_bool m_AcceptActions = true; - - struct FunctionDefinition - { - std::string FunctionName; - Guid FunctionVersion; - Guid BuildSystemVersion; - IoHash WorkerId; - }; - - void 
EmitStats(CbObjectWriter& Cbo) - { - m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); - Cbo << "actions_submitted"sv << GetSubmittedActionCount(); - EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); - } - - void RegisterWorker(CbPackage Worker); - WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - - std::atomic<int32_t> m_ActionsCounter = 0; // sequence number - - RwLock m_PendingLock; - std::map<int, Ref<RunnerAction>> m_PendingActions; - - RwLock m_RunningLock; - std::unordered_map<int, Ref<RunnerAction>> m_RunningMap; - - RwLock m_ResultsLock; - std::unordered_map<int, Ref<RunnerAction>> m_ResultsMap; - metrics::Meter m_ResultRate; - std::atomic<uint64_t> m_RetiredCount{0}; - - HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - std::atomic<bool> m_ShutdownRequested{false}; - - std::thread m_SchedulingThread; - std::atomic<bool> m_SchedulingThreadEnabled{true}; - Event m_SchedulingThreadEvent; - - void MonitorThreadFunction(); - void SchedulePendingActions(); - - // Workers - - RwLock m_WorkerLock; - std::unordered_map<IoHash, CbPackage> m_WorkerMap; - std::vector<FunctionDefinition> m_FunctionList; - std::vector<IoHash> GetKnownWorkerIds(); - - // Runners - - RunnerGroup<LocalProcessRunner> m_LocalRunnerGroup; - RunnerGroup<RemoteHttpRunner> m_RemoteRunnerGroup; - - EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority); - - void GetCompleted(CbWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); - - 
std::unique_ptr<ActionRecorder> m_Recorder; - - // History tracking - - RwLock m_ActionHistoryLock; - std::deque<FunctionServiceSession::ActionHistoryEntry> m_ActionHistory; - size_t m_HistoryLimit = 1000; - - std::vector<FunctionServiceSession::ActionHistoryEntry> GetActionHistory(int Limit); - - // - - [[nodiscard]] size_t QueryCapacity(); - - [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action); - [[nodiscard]] std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); - [[nodiscard]] size_t GetSubmittedActionCount(); - - // Updates - - RwLock m_UpdatedActionsLock; - std::vector<Ref<RunnerAction>> m_UpdatedActions; - - void HandleActionUpdates(); - void PostUpdate(RunnerAction* Action); - - void ShutdownRunners(); -}; - -bool -FunctionServiceSession::Impl::IsHealthy() -{ - return true; -} - -void -FunctionServiceSession::Impl::Shutdown() -{ - m_AcceptActions = false; - m_ShutdownRequested = true; - - m_SchedulingThreadEnabled = false; - m_SchedulingThreadEvent.Set(); - if (m_SchedulingThread.joinable()) - { - m_SchedulingThread.join(); - } - - ShutdownRunners(); -} - -void -FunctionServiceSession::Impl::ShutdownRunners() -{ - m_LocalRunnerGroup.Shutdown(); - m_RemoteRunnerGroup.Shutdown(); -} - -void -FunctionServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) -{ - ZEN_INFO("starting recording to '{}'", RecordingPath); - - m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); - - ZEN_INFO("started recording to '{}'", RecordingPath); -} - -void -FunctionServiceSession::Impl::StopRecording() -{ - ZEN_INFO("stopping recording"); - - m_Recorder = nullptr; - - ZEN_INFO("stopped recording"); -} - -std::vector<FunctionServiceSession::ActionHistoryEntry> -FunctionServiceSession::Impl::GetActionHistory(int Limit) -{ - RwLock::SharedLockScope _(m_ActionHistoryLock); - - if (Limit > 0 && static_cast<size_t>(Limit) < m_ActionHistory.size()) - { - return 
std::vector<ActionHistoryEntry>(m_ActionHistory.end() - Limit, m_ActionHistory.end()); - } - - return std::vector<ActionHistoryEntry>(m_ActionHistory.begin(), m_ActionHistory.end()); -} - -void -FunctionServiceSession::Impl::RegisterWorker(CbPackage Worker) -{ - RwLock::ExclusiveLockScope _(m_WorkerLock); - - const IoHash& WorkerId = Worker.GetObject().GetHash(); - - if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) - { - // Note that since the convention currently is that WorkerId is equal to the hash - // of the worker descriptor there is no chance that we get a second write with a - // different descriptor. Thus we only need to call this the first time, when the - // worker is added - - m_LocalRunnerGroup.RegisterWorker(Worker); - m_RemoteRunnerGroup.RegisterWorker(Worker); - - if (m_Recorder) - { - m_Recorder->RegisterWorker(Worker); - } - - CbObject WorkerObj = Worker.GetObject(); - - // Populate worker database - - const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); - - for (auto& Item : WorkerObj["functions"sv]) - { - CbObjectView Function = Item.AsObjectView(); - - std::string_view FunctionName = Function["name"sv].AsString(); - const Guid FunctionVersion = Function["version"sv].AsUuid(); - - m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, - .FunctionVersion = FunctionVersion, - .BuildSystemVersion = WorkerBuildSystemVersion, - .WorkerId = WorkerId}); - } - } -} - -WorkerDesc -FunctionServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) -{ - RwLock::SharedLockScope _(m_WorkerLock); - - if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) - { - const CbPackage& Desc = It->second; - return {Desc, WorkerId}; - } - - return {}; -} - -std::vector<IoHash> -FunctionServiceSession::Impl::GetKnownWorkerIds() -{ - std::vector<IoHash> WorkerIds; - WorkerIds.reserve(m_WorkerMap.size()); - - m_WorkerLock.WithSharedLock([&] { - for (const auto& [WorkerId, _] : 
m_WorkerMap) - { - WorkerIds.push_back(WorkerId); - } - }); - - return WorkerIds; -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::Impl::EnqueueAction(CbObject ActionObject, int Priority) -{ - // Resolve function to worker - - IoHash WorkerId{IoHash::Zero}; - - std::string_view FunctionName = ActionObject["Function"sv].AsString(); - const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); - const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); - - for (const FunctionDefinition& FuncDef : m_FunctionList) - { - if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion && - FuncDef.BuildSystemVersion == BuildSystemVersion) - { - WorkerId = FuncDef.WorkerId; - - break; - } - } - - if (WorkerId == IoHash::Zero) - { - CbObjectWriter Writer; - - Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; - Writer << "error" - << "no worker matches the action specification"; - - return {0, Writer.Save()}; - } - - if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) - { - CbPackage WorkerPackage = It->second; - - return EnqueueResolvedAction(WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); - } - - CbObjectWriter Writer; - - Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; - Writer << "error" - << "no worker found despite match"; - - return {0, Writer.Save()}; -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::Impl::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) -{ - const int ActionLsn = ++m_ActionsCounter; - - Ref<RunnerAction> Pending{new RunnerAction(m_FunctionServiceSession)}; - - Pending->ActionLsn = ActionLsn; - Pending->Worker = Worker; - Pending->ActionId = ActionObj.GetHash(); - Pending->ActionObj = ActionObj; - Pending->Priority = RequestPriority; - - 
SubmitResult SubResult = SubmitAction(Pending); - - if (SubResult.IsAccepted) - { - // Great, the job is being taken care of by the runner - ZEN_DEBUG("direct schedule LSN {}", Pending->ActionLsn); - } - else - { - ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); - - Pending->SetActionState(RunnerAction::State::Pending); - } - - if (m_Recorder) - { - m_Recorder->RecordAction(Pending); - } - - CbObjectWriter Writer; - Writer << "lsn" << Pending->ActionLsn; - Writer << "worker" << Pending->Worker.WorkerId; - Writer << "action" << Pending->ActionId; - - return {Pending->ActionLsn, Writer.Save()}; -} - -SubmitResult -FunctionServiceSession::Impl::SubmitAction(Ref<RunnerAction> Action) -{ - // Loosely round-robin scheduling of actions across runners. - // - // It's not entirely clear what this means given that submits - // can come in across multiple threads, but it's probably better - // than always starting with the first runner. - // - // Longer term we should track the state of the individual - // runners and make decisions accordingly. 
- - SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); - if (Result.IsAccepted) - { - return Result; - } - - return m_RemoteRunnerGroup.SubmitAction(Action); -} - -size_t -FunctionServiceSession::Impl::GetSubmittedActionCount() -{ - return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); -} - -HttpResponseCode -FunctionServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) -{ - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); - - if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) - { - OutResultPackage = std::move(It->second->GetResult()); - - m_ResultsMap.erase(It); - - return HttpResponseCode::OK; - } - - { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } - } - - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - - { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } - } - - return HttpResponseCode::NotFound; -} - -HttpResponseCode -FunctionServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) -{ - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); - - for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) - { - if (It->second->ActionId == ActionId) - { - OutResultPackage = std::move(It->second->GetResult()); - - m_ResultsMap.erase(It); - - 
return HttpResponseCode::OK; - } - } - - { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) - { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } - } - } - - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - - { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) - { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } - } - } - - return HttpResponseCode::NotFound; -} - -void -FunctionServiceSession::Impl::GetCompleted(CbWriter& Cbo) -{ - Cbo.BeginArray("completed"); - - m_ResultsLock.WithSharedLock([&] { - for (auto& Kv : m_ResultsMap) - { - Cbo << Kv.first; - } - }); - - Cbo.EndArray(); -} - -# define ZEN_BATCH_SCHEDULER 1 - -void -FunctionServiceSession::Impl::SchedulePendingActions() -{ - int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); - - static Stopwatch DumpRunningTimer; - - auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. 
{} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); - - if (DumpRunningTimer.GetElapsedTimeMs() > 30000) - { - DumpRunningTimer.Reset(); - - std::set<int> RunningList; - m_RunningLock.WithSharedLock([&] { - for (auto& [K, V] : m_RunningMap) - { - RunningList.insert(K); - } - }); - - ExtendableStringBuilder<1024> RunningString; - for (int i : RunningList) - { - if (RunningString.Size()) - { - RunningString << ", "; - } - - RunningString.Append(IntNum(i)); - } - - ZEN_INFO("running: {}", RunningString); - } - }); - -# if ZEN_BATCH_SCHEDULER - size_t Capacity = QueryCapacity(); - - if (!Capacity) - { - _.Dismiss(); - - return; - } - - std::vector<Ref<RunnerAction>> ActionsToSchedule; - - // Pull actions to schedule from the pending queue, we will try to submit these to the runner outside of the lock - - m_PendingLock.WithExclusiveLock([&] { - if (m_ShutdownRequested) - { - return; - } - - if (m_PendingActions.empty()) - { - return; - } - - size_t NumActionsToSchedule = std::min(Capacity, m_PendingActions.size()); - - auto PendingIt = m_PendingActions.begin(); - const auto PendingEnd = m_PendingActions.end(); - - while (NumActionsToSchedule && PendingIt != PendingEnd) - { - const Ref<RunnerAction>& Pending = PendingIt->second; - - switch (Pending->ActionState()) - { - case RunnerAction::State::Pending: - ActionsToSchedule.push_back(Pending); - break; - - case RunnerAction::State::Running: - case RunnerAction::State::Completed: - case RunnerAction::State::Failed: - break; - - default: - case RunnerAction::State::New: - ZEN_WARN("unexpected state {} for pending action {}", static_cast<int>(Pending->ActionState()), Pending->ActionLsn); - break; - } - - ++PendingIt; - --NumActionsToSchedule; - } - - PendingCount = m_PendingActions.size(); - }); - - if (ActionsToSchedule.empty()) - { - _.Dismiss(); - return; - } - - ZEN_INFO("attempting schedule of {} pending actions", 
ActionsToSchedule.size()); - - auto SubmitResults = SubmitActions(ActionsToSchedule); - - // Move successfully scheduled actions to the running map and remove - // from pending queue. It's actually possible that by the time we get - // to this stage some of the actions may have already completed, so - // they should not always be added to the running map - - eastl::hash_set<int> ScheduledActions; - - for (size_t i = 0; i < ActionsToSchedule.size(); ++i) - { - const Ref<RunnerAction>& Pending = ActionsToSchedule[i]; - const SubmitResult& SubResult = SubmitResults[i]; - - if (SubResult.IsAccepted) - { - ScheduledActions.insert(Pending->ActionLsn); - } - } - - ScheduledCount += (int)ActionsToSchedule.size(); - -# else - m_PendingLock.WithExclusiveLock([&] { - while (!m_PendingActions.empty()) - { - if (m_ShutdownRequested) - { - return; - } - - // Here it would be good if we could decide to pop immediately to avoid - // holding the lock while creating processes etc - const Ref<RunnerAction>& Pending = m_PendingActions.begin()->second; - FunctionRunner::SubmitResult SubResult = SubmitAction(Pending); - - if (SubResult.IsAccepted) - { - // Great, the job is being taken care of by the runner - - ZEN_DEBUG("action {} ({}) PENDING -> RUNNING", Pending->ActionId, Pending->ActionLsn); - - m_RunningLock.WithExclusiveLock([&] { - m_RunningMap.insert({Pending->ActionLsn, Pending}); - - RunningCount = m_RunningMap.size(); - }); - - m_PendingActions.pop_front(); - - PendingCount = m_PendingActions.size(); - ++ScheduledCount; - } - else - { - // Runner could not accept the job, leave it on the pending queue - - return; - } - } - }); -# endif -} - -void -FunctionServiceSession::Impl::MonitorThreadFunction() -{ - SetCurrentThreadName("FunctionServiceSession_Monitor"); - - auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); - - do - { - int TimeoutMs = 1000; - - if (m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); })) - { - TimeoutMs = 100; - } - - 
const bool Timedout = m_SchedulingThreadEvent.Wait(TimeoutMs); - - if (m_SchedulingThreadEnabled == false) - { - return; - } - - HandleActionUpdates(); - - // Schedule pending actions - - SchedulePendingActions(); - - if (!Timedout) - { - m_SchedulingThreadEvent.Reset(); - } - } while (m_SchedulingThreadEnabled); -} - -void -FunctionServiceSession::Impl::PostUpdate(RunnerAction* Action) -{ - m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); -} - -void -FunctionServiceSession::Impl::HandleActionUpdates() -{ - std::vector<Ref<RunnerAction>> UpdatedActions; - - m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); - - std::unordered_set<int> SeenLsn; - std::unordered_set<int> RunningLsn; - - for (Ref<RunnerAction>& Action : UpdatedActions) - { - const int ActionLsn = Action->ActionLsn; - - if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) - { - switch (Action->ActionState()) - { - case RunnerAction::State::Pending: - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); - break; - - case RunnerAction::State::Running: - m_PendingLock.WithExclusiveLock([&] { - m_RunningLock.WithExclusiveLock([&] { - m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); - }); - ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); - break; - - case RunnerAction::State::Completed: - case RunnerAction::State::Failed: - m_ResultsLock.WithExclusiveLock([&] { - m_ResultsMap[ActionLsn] = Action; - - m_PendingLock.WithExclusiveLock([&] { - m_RunningLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); - - m_ActionHistoryLock.WithExclusiveLock([&] { - ActionHistoryEntry Entry{.Lsn = ActionLsn, - .ActionId = Action->ActionId, - .WorkerId = Action->Worker.WorkerId, - 
.ActionDescriptor = Action->ActionObj, - .Succeeded = Action->ActionState() == RunnerAction::State::Completed}; - - std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps)); - - m_ActionHistory.push_back(std::move(Entry)); - - if (m_ActionHistory.size() > m_HistoryLimit) - { - m_ActionHistory.pop_front(); - } - }); - }); - m_RetiredCount.fetch_add(1); - m_ResultRate.Mark(1); - ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", - Action->ActionId, - ActionLsn, - Action->ActionState() == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); - break; - } - } - } -} - -size_t -FunctionServiceSession::Impl::QueryCapacity() -{ - return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); -} - -std::vector<SubmitResult> -FunctionServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) -{ - std::vector<SubmitResult> Results; - - for (const Ref<RunnerAction>& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -////////////////////////////////////////////////////////////////////////// - -FunctionServiceSession::FunctionServiceSession(ChunkResolver& InChunkResolver) -{ - m_Impl = std::make_unique<Impl>(this, InChunkResolver); -} - -FunctionServiceSession::~FunctionServiceSession() -{ - Shutdown(); -} - -bool -FunctionServiceSession::IsHealthy() -{ - return m_Impl->IsHealthy(); -} - -void -FunctionServiceSession::Shutdown() -{ - m_Impl->Shutdown(); -} - -void -FunctionServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) -{ - m_Impl->StartRecording(InResolver, RecordingPath); -} - -void -FunctionServiceSession::StopRecording() -{ - m_Impl->StopRecording(); -} - -void -FunctionServiceSession::EmitStats(CbObjectWriter& Cbo) -{ - m_Impl->EmitStats(Cbo); -} - -std::vector<IoHash> -FunctionServiceSession::GetKnownWorkerIds() -{ - return m_Impl->GetKnownWorkerIds(); -} - -WorkerDesc 
-FunctionServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) -{ - return m_Impl->GetWorkerDescriptor(WorkerId); -} - -void -FunctionServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath) -{ - m_Impl->m_LocalRunnerGroup.AddRunner(new LocalProcessRunner(InChunkResolver, BasePath)); -} - -void -FunctionServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) -{ - m_Impl->m_RemoteRunnerGroup.AddRunner(new RemoteHttpRunner(InChunkResolver, BasePath, HostName)); -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::EnqueueAction(CbObject ActionObject, int Priority) -{ - return m_Impl->EnqueueAction(ActionObject, Priority); -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) -{ - return m_Impl->EnqueueResolvedAction(Worker, ActionObj, RequestPriority); -} - -void -FunctionServiceSession::RegisterWorker(CbPackage Worker) -{ - m_Impl->RegisterWorker(Worker); -} - -HttpResponseCode -FunctionServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) -{ - return m_Impl->GetActionResult(ActionLsn, OutResultPackage); -} - -HttpResponseCode -FunctionServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) -{ - return m_Impl->FindActionResult(ActionId, OutResultPackage); -} - -std::vector<FunctionServiceSession::ActionHistoryEntry> -FunctionServiceSession::GetActionHistory(int Limit) -{ - return m_Impl->GetActionHistory(Limit); -} - -void -FunctionServiceSession::GetCompleted(CbWriter& Cbo) -{ - m_Impl->GetCompleted(Cbo); -} - -void -FunctionServiceSession::PostUpdate(RunnerAction* Action) -{ - m_Impl->PostUpdate(Action); -} - -////////////////////////////////////////////////////////////////////////// - -void -function_forcelink() -{ -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff 
--git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp new file mode 100644 index 000000000..e82a40781 --- /dev/null +++ b/src/zencompute/httpcomputeservice.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/httpcomputeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/system.h> +# include <zencore/thread.h> +# include <zencore/trace.h> +# include <zencore/uid.h> +# include <zenstore/cidstore.h> +# include <zentelemetry/stats.h> + +# include <span> +# include <unordered_map> + +using namespace std::literals; + +namespace zen::compute { + +constinit AsciiSet g_DecimalSet("0123456789"); +constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); + +auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; +auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; +auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSet::HasOnly(Str, g_HexSet); }; + +////////////////////////////////////////////////////////////////////////// + +struct HttpComputeService::Impl +{ + HttpComputeService* m_Self; + CidStore& m_CidStore; + IHttpStatsService& m_StatsService; + LoggerRef m_Log; + std::filesystem::path m_BaseDir; + HttpRequestRouter m_Router; + ComputeServiceSession m_ComputeService; + SystemMetricsTracker m_MetricsTracker; + + // Metrics + + metrics::OperationTiming m_HttpRequests; + + // Per-remote-queue metadata, shared across all lookup maps below. 
+ + struct RemoteQueueInfo : RefCounted + { + int QueueId = 0; + Oid Token; + std::string IdempotencyKey; // empty if no idempotency key was provided + std::string ClientHostname; // empty if no hostname was provided + }; + + // Remote queue registry — all three maps share the same RemoteQueueInfo objects. + // All maps are guarded by m_RemoteQueueLock. + + RwLock m_RemoteQueueLock; + std::unordered_map<Oid, Ref<RemoteQueueInfo>, Oid::Hasher> m_RemoteQueuesByToken; // Token → info + std::unordered_map<int, Ref<RemoteQueueInfo>> m_RemoteQueuesByQueueId; // QueueId → info + std::unordered_map<std::string, Ref<RemoteQueueInfo>> m_RemoteQueuesByTag; // idempotency key → info + + LoggerRef Log() { return m_Log; } + + int ResolveQueueToken(const Oid& Token); + int ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture); + + struct IngestStats + { + int Count = 0; + int NewCount = 0; + uint64_t Bytes = 0; + uint64_t NewBytes = 0; + }; + + IngestStats IngestPackageAttachments(const CbPackage& Package); + bool CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + + void RegisterRoutes(); + + Impl(HttpComputeService* Self, + CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) + : m_Self(Self) + , m_CidStore(InCidStore) + , m_StatsService(StatsService) + , m_Log(logging::Get("compute")) + , m_BaseDir(BaseDir) + , m_ComputeService(InCidStore) + { + m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.WaitUntilReady(); + m_StatsService.RegisterHandler("compute", *m_Self); + RegisterRoutes(); + } +}; + 
+////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::RegisterRoutes() +{ + m_Router.AddMatcher("lsn", DecimalMatcher); + m_Router.AddMatcher("worker", IoHashMatcher); + m_Router.AddMatcher("action", IoHashMatcher); + m_Router.AddMatcher("queue", DecimalMatcher); + m_Router.AddMatcher("oidtoken", OidMatcher); + m_Router.AddMatcher("queueref", [](std::string_view Str) { return DecimalMatcher(Str) || OidMatcher(Str); }); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.IsHealthy()) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + + return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "abandon", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + + if (Success) + { + CbObjectWriter Cbo; + Cbo << "state"sv + << "Abandoned"sv; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Abandoned from current state"sv; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers", + [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { HandleWorkerRequest(Req.ServerRequest(), IoHash::FromHexString(Req.GetCapture(1))); }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = 
Req.ServerRequest(); + + CbObjectWriter Cbo; + m_ComputeService.GetCompleted(Cbo); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + Cbo.BeginObject("metrics"); + Describe(Sm, Cbo); + Cbo.EndObject(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetActionHistory(QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Running = m_ComputeService.GetRunningActions(); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + 
if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + // Once we've initiated the response we can mark the result + // as retired, allowing the service to free any associated + // resources. Note that there still needs to be a delay + // to allow the transmission to complete, it would be better + // if we could issue this once the response is fully sent... + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. 
The preferred path is the + // one which uses the scheduled action lsn for lookups + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + CbPackage Output; + if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); + ResponseCode != HttpResponseCode::OK) + { + ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + + if (ResponseCode == HttpResponseCode::NotFound) + { + return HttpReq.WriteResponse(ResponseCode); + } + + return HttpReq.WriteResponse(ResponseCode); + } + + ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. 
This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if 
(ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto QueryParams = HttpReq.GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + // Resolve worker + + // + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + return; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers/all", + [this](HttpRouterRequest& Req) { HandleWorkersAllGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/all", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersAllGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkerRequest(HttpReq, IoHash::FromHexString(Req.GetCapture(2))); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "sysinfo", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + CbObjectWriter Cbo; + Describe(Sm, Cbo); + + Cbo << "cpu_usage" << Sm.CpuUsagePercent; + Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + Cbo << "disk_used" << 100 * 1024; + Cbo << "disk_total" << 100 * 1024 * 1024; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "record/start", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return 
HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "record/stop", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StopRecording(); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + // Local-only queue listing and creation + + m_Router.RegisterRoute( + "queues", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObjectWriter Cbo; + Cbo.BeginArray("queues"sv); + + for (const int QueueId : m_ComputeService.GetQueueIds()) + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + continue; + } + + Cbo.BeginObject(); + WriteQueueDescription(Cbo, QueueId, Status); + Cbo.EndObject(); + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kPost: + { + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + ComputeServiceSession::CreateQueueResult Result = + m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + + CbObjectWriter Cbo; + Cbo << "queue_id"sv << Result.QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + // Queue creation routes — these remain separate since local creates a 
plain queue + // while remote additionally generates an OID token for external access. + + m_Router.RegisterRoute( + "queues/remote", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // Extract optional fields from the request body. + // idempotency_key: when present, we return the existing remote queue token for this + // key rather than creating a new queue, making the endpoint safe to call concurrently. + // hostname: human-readable origin context stored alongside the queue for diagnostics. + // metadata: arbitrary CbObject metadata propagated from the originating queue. + // config: arbitrary CbObject config propagated from the originating queue. + std::string IdempotencyKey; + std::string ClientHostname; + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + IdempotencyKey = std::string(Body["idempotency_key"sv].AsString()); + ClientHostname = std::string(Body["hostname"sv].AsString()); + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + // Stamp the forwarding node's hostname into the metadata so that the + // remote side knows which node originated the queue. 
+ if (!ClientHostname.empty()) + { + CbObjectWriter MetaWriter; + for (auto Field : Metadata) + { + MetaWriter.AddField(Field.GetName(), Field); + } + MetaWriter << "via"sv << ClientHostname; + Metadata = MetaWriter.Save(); + } + + RwLock::ExclusiveLockScope _(m_RemoteQueueLock); + + if (!IdempotencyKey.empty()) + { + if (auto It = m_RemoteQueuesByTag.find(IdempotencyKey); It != m_RemoteQueuesByTag.end()) + { + Ref<RemoteQueueInfo> Existing = It->second; + if (m_ComputeService.GetQueueStatus(Existing->QueueId).IsValid) + { + CbObjectWriter Cbo; + Cbo << "queue_token"sv << Existing->Token.ToString(); + Cbo << "queue_id"sv << Existing->QueueId; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // Queue has since expired — clean up stale entries and fall through to create a new one + m_RemoteQueuesByToken.erase(Existing->Token); + m_RemoteQueuesByQueueId.erase(Existing->QueueId); + m_RemoteQueuesByTag.erase(It); + } + } + + ComputeServiceSession::CreateQueueResult Result = m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + Ref<RemoteQueueInfo> InfoRef(new RemoteQueueInfo()); + InfoRef->QueueId = Result.QueueId; + InfoRef->Token = Oid::NewOid(); + InfoRef->IdempotencyKey = std::move(IdempotencyKey); + InfoRef->ClientHostname = std::move(ClientHostname); + + m_RemoteQueuesByToken[InfoRef->Token] = InfoRef; + m_RemoteQueuesByQueueId[InfoRef->QueueId] = InfoRef; + if (!InfoRef->IdempotencyKey.empty()) + { + m_RemoteQueuesByTag[InfoRef->IdempotencyKey] = InfoRef; + } + + CbObjectWriter Cbo; + Cbo << "queue_token"sv << InfoRef->Token.ToString(); + Cbo << "queue_id"sv << InfoRef->QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. 
+ + m_Router.RegisterRoute( + "queues/{queueref}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kDelete: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.CancelQueue(QueueId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "queues/{queueref}/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.DrainQueue(QueueId); + + // Return updated queue status + Status = m_ComputeService.GetQueueStatus(QueueId); + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + 
return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + m_ComputeService.GetQueueCompleted(QueueId, Cbo); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt<int>(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetQueueHistory(QueueId, QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/running", + [this](HttpRouterRequest& Req) { + 
HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + if (QueueId == 0) + { + return; + } + // Filter global running list to this queue + auto AllRunning = m_ComputeService.GetRunningActions(); + std::vector<ComputeServiceSession::RunningActionInfo> Running; + for (auto& Info : AllRunning) + if (Info.QueueId == QueueId) + Running.push_back(Info); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(2)); + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + 
Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. 
{} new ({} attachments)", + QueueId, + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = 
IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ZEN_UNUSED(QueueId); + + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); +} + +////////////////////////////////////////////////////////////////////////// + 
// Construct the HTTP facade; all state lives in the pimpl so the public
// header stays free of service internals.
HttpComputeService::HttpComputeService(CidStore&                    InCidStore,
                                       IHttpStatsService&           StatsService,
                                       const std::filesystem::path& BaseDir,
                                       int32_t                      MaxConcurrentActions)
: m_Impl(std::make_unique<Impl>(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions))
{
}

HttpComputeService::~HttpComputeService()
{
    // Mirror of the RegisterHandler("compute", ...) done during Impl setup
    m_Impl->m_StatsService.UnregisterHandler("compute", *this);
}

// Forwards shutdown to the underlying compute service session.
void
HttpComputeService::Shutdown()
{
    m_Impl->m_ComputeService.Shutdown();
}

// Snapshot of queued/running/completed action counters.
ComputeServiceSession::ActionCounts
HttpComputeService::GetActionCounts()
{
    return m_Impl->m_ComputeService.GetActionCounts();
}

// All routes served by this handler hang off /compute/.
const char*
HttpComputeService::BaseUri() const
{
    return "/compute/";
}

// Entry point for every /compute/* request: times the request and hands it
// to the router; unmatched URIs are logged but otherwise ignored here.
void
HttpComputeService::HandleRequest(HttpServerRequest& Request)
{
    ZEN_TRACE_CPU("HttpComputeService::HandleRequest");
    metrics::OperationTiming::Scope $(m_Impl->m_HttpRequests);

    if (m_Impl->m_Router.HandleRequest(Request) == false)
    {
        ZEN_WARN("No route found for {0}", Request.RelativeUri());
    }
}

// Stats endpoint callback: emits the compute session stats as Compact Binary.
void
HttpComputeService::HandleStatsRequest(HttpServerRequest& Request)
{
    CbObjectWriter Cbo;
    m_Impl->m_ComputeService.EmitStats(Cbo);

    Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}

//////////////////////////////////////////////////////////////////////////

// Serializes a queue's status into Cbo. Counter/state fields use snake_case
// keys; "cancelled"/"draining" are convenience booleans derived from State.
void
HttpComputeService::Impl::WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status)
{
    Cbo << "queue_id"sv << Status.QueueId;
    Cbo << "active_count"sv << Status.ActiveCount;
    Cbo << "completed_count"sv << Status.CompletedCount;
    Cbo << "failed_count"sv << Status.FailedCount;
    Cbo << "abandoned_count"sv << Status.AbandonedCount;
    Cbo << "cancelled_count"sv << Status.CancelledCount;
    Cbo << "state"sv << ToString(Status.State);
    Cbo << "cancelled"sv << (Status.State == ComputeServiceSession::QueueState::Cancelled);
    Cbo << "draining"sv << (Status.State == ComputeServiceSession::QueueState::Draining);
    Cbo << "is_complete"sv <<
Status.IsComplete;

    // Optional queue metadata / config blobs, included only when present
    if (CbObject Meta = m_ComputeService.GetQueueMetadata(QueueId))
    {
        Cbo << "metadata"sv << Meta;
    }

    if (CbObject Cfg = m_ComputeService.GetQueueConfig(QueueId))
    {
        Cbo << "config"sv << Cfg;
    }

    {
        // For remotely created queues, expose the OID token (and client
        // hostname when known) under a shared read lock.
        RwLock::SharedLockScope $(m_RemoteQueueLock);
        if (auto It = m_RemoteQueuesByQueueId.find(QueueId); It != m_RemoteQueuesByQueueId.end())
        {
            Cbo << "queue_token"sv << It->second->Token.ToString();
            if (!It->second->ClientHostname.empty())
            {
                Cbo << "hostname"sv << It->second->ClientHostname;
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////

// Maps a remote queue OID token to its local queue id; 0 when unknown.
int
HttpComputeService::Impl::ResolveQueueToken(const Oid& Token)
{
    RwLock::SharedLockScope $(m_RemoteQueueLock);

    auto It = m_RemoteQueuesByToken.find(Token);

    if (It != m_RemoteQueuesByToken.end())
    {
        return It->second->QueueId;
    }

    return 0;
}

// Resolves a {queueref} capture to a queue id. Returns 0 after having
// written the error response itself (NotFound for an unknown token,
// Forbidden for a non-local integer reference) -- callers must return
// without writing when they get 0 back.
int
HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture)
{
    if (OidMatcher(Capture))
    {
        // Remote OID token -- accessible from any client
        const Oid Token   = Oid::FromHexString(Capture);
        const int QueueId = ResolveQueueToken(Token);

        if (QueueId == 0)
        {
            HttpReq.WriteResponse(HttpResponseCode::NotFound);
        }

        return QueueId;
    }

    // Local integer queue ID -- restricted to local machine requests
    if (!HttpReq.IsLocalMachineRequest())
    {
        HttpReq.WriteResponse(HttpResponseCode::Forbidden);
        return 0;
    }

    return ParseInt<int>(Capture).value_or(0);
}

// Stores every compressed attachment of the package into the CID store,
// returning byte/count totals split into "seen" and "new to this node".
HttpComputeService::Impl::IngestStats
HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package)
{
    IngestStats Stats;

    for (const CbAttachment& Attachment : Package.GetAttachments())
    {
        ZEN_ASSERT(Attachment.IsCompressedBinary());

        const IoHash     DataHash = Attachment.GetHash();
        CompressedBuffer DataView = Attachment.AsCompressedBinary();

        ZEN_UNUSED(DataHash);

        // Accounting uses the compressed (wire) size
        const uint64_t CompressedSize = DataView.GetCompressedSize();

Stats.Bytes += CompressedSize;
        ++Stats.Count;

        // AddChunk() reports whether the chunk was previously unknown here
        const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);

        if (InsertResult.New)
        {
            Stats.NewBytes += CompressedSize;
            ++Stats.NewCount;
        }
    }

    return Stats;
}

// Fills NeedList with every attachment hash missing from the CID store;
// returns true when the action's attachments are all available locally.
bool
HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vector<IoHash>& NeedList)
{
    ActionObj.IterateAttachments([&](CbFieldView Field) {
        const IoHash FileHash = Field.AsHash();

        if (!m_CidStore.ContainsChunk(FileHash))
        {
            NeedList.push_back(FileHash);
        }
    });

    return NeedList.empty();
}

// GET workers -- flat array of known worker ids.
void
HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq)
{
    CbObjectWriter Cbo;
    Cbo.BeginArray("workers"sv);
    for (const IoHash& WorkerId : m_ComputeService.GetKnownWorkerIds())
    {
        Cbo << WorkerId;
    }
    Cbo.EndArray();
    HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}

// GET workers/all -- id + full descriptor object for every known worker.
void
HttpComputeService::Impl::HandleWorkersAllGet(HttpServerRequest& HttpReq)
{
    std::vector<IoHash> WorkerIds = m_ComputeService.GetKnownWorkerIds();

    CbObjectWriter Cbo;
    Cbo.BeginArray("workers");

    for (const IoHash& WorkerId : WorkerIds)
    {
        Cbo.BeginObject();
        Cbo << "id" << WorkerId;
        // NOTE(review): unlike HandleWorkerRequest, the descriptor validity
        // is not checked before use -- confirm a worker cannot be removed
        // between GetKnownWorkerIds() and GetWorkerDescriptor().
        Cbo << "descriptor" << m_ComputeService.GetWorkerDescriptor(WorkerId).Descriptor.GetObject();
        Cbo.EndObject();
    }

    Cbo.EndArray();
    HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
}

// GET/POST workers/{worker}: GET returns the registered descriptor; POST
// registers a worker, using the same two-step CbObject ("need" list of
// missing chunks) / CbPackage (descriptor plus chunks) dance as job posts.
void
HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId)
{
    switch (HttpReq.RequestVerb())
    {
        case HttpVerb::kGet:
            if (WorkerDesc Desc = m_ComputeService.GetWorkerDescriptor(WorkerId))
            {
                return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject());
            }
            return HttpReq.WriteResponse(HttpResponseCode::NotFound);

        case HttpVerb::kPost:
        {
            switch (HttpReq.RequestContentType())
            {
                case HttpContentType::kCbObject:
                {
                    CbObject WorkerSpec =
HttpReq.ReadPayloadObject();

                    // Collect every attachment hash referenced by the spec so
                    // we can determine which chunks are still missing locally
                    HashKeySet ChunkSet;
                    WorkerSpec.IterateAttachments([&](CbFieldView Field) {
                        const IoHash Hash = Field.AsHash();
                        ChunkSet.AddHashToSet(Hash);
                    });

                    CbPackage WorkerPackage;
                    WorkerPackage.SetObject(WorkerSpec);

                    // Drops hashes that are already present in the store
                    m_CidStore.FilterChunks(ChunkSet);

                    if (ChunkSet.IsEmpty())
                    {
                        // Everything is available -- register immediately
                        ZEN_DEBUG("worker {}: all attachments already available", WorkerId);
                        m_ComputeService.RegisterWorker(WorkerPackage);
                        return HttpReq.WriteResponse(HttpResponseCode::NoContent);
                    }

                    // Otherwise answer with the "need" list; the client is
                    // expected to re-POST as a CbPackage carrying the chunks
                    CbObjectWriter ResponseWriter;
                    ResponseWriter.BeginArray("need");
                    ChunkSet.IterateHashes([&](const IoHash& Hash) {
                        ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash);
                        ResponseWriter.AddHash(Hash);
                    });
                    ResponseWriter.EndArray();

                    ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize());
                    return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save());
                }
                break;

                case HttpContentType::kCbPackage:
                {
                    CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage();
                    CbObject  WorkerSpec        = WorkerSpecPackage.GetObject();

                    std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments();

                    int      AttachmentCount      = 0;
                    int      NewAttachmentCount   = 0;
                    uint64_t TotalAttachmentBytes = 0;
                    uint64_t TotalNewBytes        = 0;

                    // Ingest every attachment into the CID store, tracking how
                    // much of the payload was actually new to this node
                    for (const CbAttachment& Attachment : Attachments)
                    {
                        ZEN_ASSERT(Attachment.IsCompressedBinary());

                        const IoHash     DataHash = Attachment.GetHash();
                        CompressedBuffer Buffer   = Attachment.AsCompressedBinary();

                        ZEN_UNUSED(DataHash);
                        TotalAttachmentBytes += Buffer.GetCompressedSize();
                        ++AttachmentCount;

                        const CidStore::InsertResult InsertResult =
                            m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);

                        if (InsertResult.New)
                        {
                            TotalNewBytes += Buffer.GetCompressedSize();
                            ++NewAttachmentCount;
                        }
                    }

                    ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments",
                              WorkerId,
                              zen::NiceBytes(TotalAttachmentBytes),
AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + m_ComputeService.RegisterWorker(WorkerSpecPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } +} + +////////////////////////////////////////////////////////////////////////// + +void +httpcomputeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpfunctionservice.cpp b/src/zencompute/httpfunctionservice.cpp deleted file mode 100644 index 09a9684a7..000000000 --- a/src/zencompute/httpfunctionservice.cpp +++ /dev/null @@ -1,709 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "zencompute/httpfunctionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarybuilder.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/compress.h> -# include <zencore/except.h> -# include <zencore/filesystem.h> -# include <zencore/fmtutils.h> -# include <zencore/iobuffer.h> -# include <zencore/iohash.h> -# include <zencore/system.h> -# include <zenstore/cidstore.h> - -# include <span> - -using namespace std::literals; - -namespace zen::compute { - -constinit AsciiSet g_DecimalSet("0123456789"); -auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; - -constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); -auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; - -HttpFunctionService::HttpFunctionService(CidStore& InCidStore, - IHttpStatsService& StatsService, - [[maybe_unused]] const std::filesystem::path& BaseDir) -: m_CidStore(InCidStore) -, m_StatsService(StatsService) -, m_Log(logging::Get("apply")) -, m_BaseDir(BaseDir) -, m_FunctionService(InCidStore) -{ - m_FunctionService.AddLocalRunner(InCidStore, m_BaseDir / 
"local"); - - m_StatsService.RegisterHandler("apply", *this); - - m_Router.AddMatcher("lsn", DecimalMatcher); - m_Router.AddMatcher("worker", IoHashMatcher); - m_Router.AddMatcher("action", IoHashMatcher); - - m_Router.RegisterRoute( - "ready", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - if (m_FunctionService.IsHealthy()) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); - } - - return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "workers", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - CbObjectWriter Cbo; - Cbo.BeginArray("workers"sv); - for (const IoHash& WorkerId : m_FunctionService.GetKnownWorkerIds()) - { - Cbo << WorkerId; - } - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "workers/{worker}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - if (WorkerDesc Desc = m_FunctionService.GetWorkerDescriptor(WorkerId)) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); - } - return HttpReq.WriteResponse(HttpResponseCode::NotFound); - - case HttpVerb::kPost: - { - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - CbObject WorkerSpec = HttpReq.ReadPayloadObject(); - - // Determine which pieces are missing and need to be transmitted - - HashKeySet ChunkSet; - - WorkerSpec.IterateAttachments([&](CbFieldView Field) { - const IoHash Hash = Field.AsHash(); - ChunkSet.AddHashToSet(Hash); - }); - - CbPackage WorkerPackage; - WorkerPackage.SetObject(WorkerSpec); - - m_CidStore.FilterChunks(ChunkSet); - - if (ChunkSet.IsEmpty()) - { - ZEN_DEBUG("worker {}: 
all attachments already available", WorkerId); - m_FunctionService.RegisterWorker(WorkerPackage); - return HttpReq.WriteResponse(HttpResponseCode::NoContent); - } - - CbObjectWriter ResponseWriter; - ResponseWriter.BeginArray("need"); - - ChunkSet.IterateHashes([&](const IoHash& Hash) { - ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); - ResponseWriter.AddHash(Hash); - }); - - ResponseWriter.EndArray(); - - ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); - } - break; - - case HttpContentType::kCbPackage: - { - CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); - CbObject WorkerSpec = WorkerSpecPackage.GetObject(); - - std::span<const CbAttachment> Attachments = WorkerSpecPackage.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - TotalAttachmentBytes += Buffer.GetCompressedSize(); - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += Buffer.GetCompressedSize(); - ++NewAttachmentCount; - } - } - - ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", - WorkerId, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - m_FunctionService.RegisterWorker(WorkerSpecPackage); - - return HttpReq.WriteResponse(HttpResponseCode::NoContent); - } - break; - - default: - break; - } - } - break; - - default: - break; - } - }, - HttpVerb::kGet | HttpVerb::kPost); - - m_Router.RegisterRoute( - 
"jobs/completed", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - CbObjectWriter Cbo; - m_FunctionService.GetCompleted(Cbo); - - SystemMetrics Sm = GetSystemMetricsForReporting(); - Cbo.BeginObject("metrics"); - Describe(Sm, Cbo); - Cbo.EndObject(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/history", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const auto QueryParams = HttpReq.GetQueryParams(); - - int QueryLimit = 50; - - if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) - { - QueryLimit = ParseInt<int>(LimitParam).value_or(50); - } - - CbObjectWriter Cbo; - Cbo.BeginArray("history"); - for (const auto& Entry : m_FunctionService.GetActionHistory(QueryLimit)) - { - Cbo.BeginObject(); - Cbo << "lsn"sv << Entry.Lsn; - Cbo << "actionId"sv << Entry.ActionId; - Cbo << "workerId"sv << Entry.WorkerId; - Cbo << "succeeded"sv << Entry.Succeeded; - Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; - - for (const auto& Timestamp : Entry.Timestamps) - { - Cbo.AddInteger( - fmt::format("time_{}"sv, RunnerAction::ToString(static_cast<RunnerAction::State>(&Timestamp - Entry.Timestamps))), - Timestamp); - } - Cbo.EndObject(); - } - Cbo.EndArray(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/{lsn}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const int ActionLsn = std::stoi(std::string{Req.GetCapture(1)}); - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - { - CbPackage Output; - HttpResponseCode ResponseCode = m_FunctionService.GetActionResult(ActionLsn, Output); - - if (ResponseCode == HttpResponseCode::OK) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, Output); - } - - return HttpReq.WriteResponse(ResponseCode); - } - break; - - case 
HttpVerb::kPost: - { - // Add support for cancellation, priority changes - } - break; - - default: - break; - } - }, - HttpVerb::kGet | HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the - // one which uses the scheduled action lsn for lookups - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); - - CbPackage Output; - if (HttpResponseCode ResponseCode = m_FunctionService.FindActionResult(ActionId, /* out */ Output); - ResponseCode != HttpResponseCode::OK) - { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) - - if (ResponseCode == HttpResponseCode::NotFound) - { - return HttpReq.WriteResponse(ResponseCode); - } - - return HttpReq.WriteResponse(ResponseCode); - } - - ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) - - return HttpReq.WriteResponse(HttpResponseCode::OK, Output); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/{worker}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); - - WorkerDesc Worker = m_FunctionService.GetWorkerDescriptor(WorkerId); - - if (!Worker) - { - return HttpReq.WriteResponse(HttpResponseCode::NotFound); - } - - const auto QueryParams = Req.ServerRequest().GetQueryParams(); - - int RequestPriority = -1; - - if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) - { - RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); - } - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - // TODO: return status of all pending or executing jobs - break; - - case HttpVerb::kPost: - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed 
job spec and identifies which - // chunks are not present on this server. This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (FunctionServiceSession::EnqueueResult Result = - m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - break; - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (FunctionServiceSession::EnqueueResult Result = - m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - break; - - default: - break; - } - break; - - default: - break; - } - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - const auto QueryParams = HttpReq.GetQueryParams(); - - int RequestPriority = -1; - - if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) - { - RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); - } - - // Resolve worker - - // - - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. 
This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - return; - } - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "workers/all", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - std::vector<IoHash> WorkerIds = m_FunctionService.GetKnownWorkerIds(); - - CbObjectWriter Cbo; - Cbo.BeginArray("workers"); - - for (const IoHash& WorkerId : WorkerIds) - { - Cbo.BeginObject(); - - Cbo << "id" << WorkerId; - - const auto& Descriptor = m_FunctionService.GetWorkerDescriptor(WorkerId); - - Cbo << "descriptor" << Descriptor.Descriptor.GetObject(); - - Cbo.EndObject(); - } - - Cbo.EndArray(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "sysinfo", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - SystemMetrics Sm = GetSystemMetricsForReporting(); - - CbObjectWriter Cbo; - Describe(Sm, Cbo); - - Cbo << "cpu_usage" << Sm.CpuUsagePercent; - Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; - Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; - Cbo << "disk_used" << 100 * 1024; - Cbo << "disk_total" << 100 * 1024 * 1024; - - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "record/start", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = 
Req.ServerRequest(); - - m_FunctionService.StartRecording(m_CidStore, m_BaseDir / "recording"); - - return HttpReq.WriteResponse(HttpResponseCode::OK); - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "record/stop", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - m_FunctionService.StopRecording(); - - return HttpReq.WriteResponse(HttpResponseCode::OK); - }, - HttpVerb::kPost); -} - -HttpFunctionService::~HttpFunctionService() -{ - m_StatsService.UnregisterHandler("apply", *this); -} - -void -HttpFunctionService::Shutdown() -{ - m_FunctionService.Shutdown(); -} - -const char* -HttpFunctionService::BaseUri() const -{ - return "/apply/"; -} - -void -HttpFunctionService::HandleRequest(HttpServerRequest& Request) -{ - metrics::OperationTiming::Scope $(m_HttpRequests); - - if (m_Router.HandleRequest(Request) == false) - { - ZEN_WARN("No route found for {0}", Request.RelativeUri()); - } -} - -void -HttpFunctionService::HandleStatsRequest(HttpServerRequest& Request) -{ - CbObjectWriter Cbo; - m_FunctionService.EmitStats(Cbo); - - Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); -} - -////////////////////////////////////////////////////////////////////////// - -void -httpfunction_forcelink() -{ -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index 39e7e60d7..6cbe01e04 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -2,65 +2,398 @@ #include "zencompute/httporchestrator.h" -#include <zencore/compactbinarybuilder.h> -#include <zencore/logging.h> +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencompute/orchestratorservice.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/system.h> namespace zen::compute { -HttpOrchestratorService::HttpOrchestratorService() : m_Log(logging::Get("orch")) +// 
Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes +static bool +IsValidWorkerId(std::string_view Id) +{ + if (Id.size() < 3 || Id.size() > 64) + { + return false; + } + for (char c : Id) + { + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') + { + continue; + } + return false; + } + return true; +} + +// Shared announce payload parser used by both the HTTP POST route and the +// WebSocket message handler. Returns the worker ID on success (empty on +// validation failure). The returned WorkerAnnouncement has string_view +// fields that reference the supplied CbObjectView, so the CbObject must +// outlive the returned announcement. +static std::string_view +ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann) { + Ann.Id = Data["id"].AsString(""); + Ann.Uri = Data["uri"].AsString(""); + + if (!IsValidWorkerId(Ann.Id)) + { + return {}; + } + + if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://")) + { + return {}; + } + + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Platform = Data["platform"].AsString(""); + Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f); + Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0); + Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0); + Ann.BytesReceived = Data["bytes_received"].AsUInt64(0); + Ann.BytesSent = Data["bytes_sent"].AsUInt64(0); + Ann.ActionsPending = Data["actions_pending"].AsInt32(0); + Ann.ActionsRunning = Data["actions_running"].AsInt32(0); + Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0); + Ann.ActiveQueues = Data["active_queues"].AsInt32(0); + Ann.Provisioner = Data["provisioner"].AsString(""); + + if (auto Metrics = Data["metrics"].AsObjectView()) + { + Ann.Cpus = Metrics["lp_count"].AsInt32(0); + if (Ann.Cpus <= 0) + { + Ann.Cpus = 1; + } + } + + return Ann.Id; +} + 
+HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) +, m_Hostname(GetMachineName()) +{ + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + + // dummy endpoint for websocket clients + m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + Cbo << "hostname" << std::string_view(m_Hostname); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + m_Router.RegisterRoute( "provision", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "announce", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); - CbObjectWriter Cbo; - Cbo.BeginArray("workers"); + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); - m_KnownWorkersLock.WithSharedLock([&] { - for (const auto& [WorkerId, Worker] : m_KnownWorkers) + if (WorkerId.empty()) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash " + "characters and uri must start with http:// or https://"); + } + + m_Service->AnnounceWorker(Ann); + + HttpReq.WriteResponse(HttpResponseCode::OK); + +# if ZEN_WITH_WEBSOCKETS + // Notify push thread that state may have changed + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + 
m_Router.RegisterRoute( + "agents", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline/{workerid}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + std::string_view WorkerId = Req.GetCapture(1); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + auto LimitStr = Params.GetValue("limit"); + + std::optional<DateTime> From; + std::optional<DateTime> To; + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) { - Cbo.BeginObject(); - Cbo << "uri" << Worker.BaseUri; - Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); - Cbo.EndObject(); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } - }); + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + int Limit = !LimitStr.empty() ? 
zen::ParseInt<int>(LimitStr).value_or(0) : 0; - Cbo.EndArray(); + CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit); - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + if (!Result) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); }, - HttpVerb::kPost); + HttpVerb::kGet); m_Router.RegisterRoute( - "announce", + "timeline", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + + DateTime From = DateTime(0); + DateTime To = DateTime::Now(); + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt<uint64_t>(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + CbObject Result = m_Service->GetAllTimelines(From, To); + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + // Client tracking endpoints + + m_Router.RegisterRoute( + "clients", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); CbObject Data = HttpReq.ReadPayloadObject(); - std::string_view WorkerId = Data["id"].AsString(""); - std::string_view WorkerUri = Data["uri"].AsString(""); + OrchestratorService::ClientAnnouncement Ann; + Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero); + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Address = HttpReq.GetRemoteAddress(); - if (WorkerId.empty() || WorkerUri.empty()) + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) { - return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + Ann.Metadata = CbObject::Clone(MetadataView); } - 
m_KnownWorkersLock.WithExclusiveLock([&] { - auto& Worker = m_KnownWorkers[std::string(WorkerId)]; - Worker.BaseUri = WorkerUri; - Worker.LastSeen.Reset(); - }); + std::string ClientId = m_Service->AnnounceClient(Ann); - HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter ResponseObj; + ResponseObj << "id" << std::string_view(ClientId); + HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save()); + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif }, HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/update", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + CbObject MetadataObj; + CbObject Data = HttpReq.ReadPayloadObject(); + if (Data) + { + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + MetadataObj = CbObject::Clone(MetadataView); + } + } + + if (m_Service->UpdateClient(ClientId, std::move(MetadataObj))) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/complete", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + if (m_Service->CompleteClient(ClientId)) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "clients/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = 
Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit)); + }, + HttpVerb::kGet); + +# if ZEN_WITH_WEBSOCKETS + + // Start the WebSocket push thread + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); +# endif } HttpOrchestratorService::~HttpOrchestratorService() { + Shutdown(); +} + +void +HttpOrchestratorService::Shutdown() +{ +# if ZEN_WITH_WEBSOCKETS + if (!m_PushEnabled.exchange(false)) + { + return; + } + + // Stop the push thread first, before touching connections. This ensures + // the push thread is no longer reading m_WsConnections or calling into + // m_Service when we start tearing things down. + m_PushEvent.Set(); + if (m_PushThread.joinable()) + { + m_PushThread.join(); + } + + // Clean up worker WebSocket connections — collect IDs under lock, then + // notify the service outside the lock to avoid lock-order inversions. + std::vector<std::string> WorkerIds; + m_WorkerWsLock.WithExclusiveLock([&] { + WorkerIds.reserve(m_WorkerWsMap.size()); + for (const auto& [Conn, Id] : m_WorkerWsMap) + { + WorkerIds.push_back(Id); + } + m_WorkerWsMap.clear(); + }); + for (const auto& Id : WorkerIds) + { + m_Service->SetWorkerWebSocketConnected(Id, false); + } + + // Now that the push thread is gone, release all dashboard connections. 
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); +# endif } const char* @@ -78,4 +411,240 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + if (!m_PushEnabled.load()) + { + return; + } + + ZEN_INFO("WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Wake push thread to send initial state immediately + m_PushEvent.Set(); +} + +void +HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + // Only handle binary messages from workers when the feature is enabled. + if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary) + { + return; + } + + std::string WorkerId = HandleWorkerWebSocketMessage(Msg); + if (WorkerId.empty()) + { + return; + } + + // Check if this is a new worker WebSocket connection + bool IsNewWorkerWs = false; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It == m_WorkerWsMap.end()) + { + m_WorkerWsMap[&Conn] = WorkerId; + IsNewWorkerWs = true; + } + }); + + if (IsNewWorkerWs) + { + m_Service->SetWorkerWebSocketConnected(WorkerId, true); + } + + m_PushEvent.Set(); +} + +std::string +HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) +{ + // Workers send CbObject in native binary format over the WebSocket to + // avoid the lossy CbObject↔JSON round-trip. 
+ CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); + if (!Data) + { + ZEN_WARN("worker WebSocket message is not a valid CbObject"); + return {}; + } + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + if (WorkerId.empty()) + { + ZEN_WARN("invalid worker announcement via WebSocket"); + return {}; + } + + m_Service->AnnounceWorker(Ann); + return std::string(WorkerId); +} + +void +HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, + [[maybe_unused]] uint16_t Code, + [[maybe_unused]] std::string_view Reason) +{ + ZEN_INFO("WebSocket client disconnected (code {})", Code); + + // Check if this was a worker WebSocket connection; collect the ID under + // the worker lock, then notify the service outside the lock. + std::string DisconnectedWorkerId; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It != m_WorkerWsMap.end()) + { + DisconnectedWorkerId = std::move(It->second); + m_WorkerWsMap.erase(It); + } + }); + + if (!DisconnectedWorkerId.empty()) + { + m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false); + m_PushEvent.Set(); + } + + if (!m_PushEnabled.load()) + { + return; + } + + // Remove from dashboard connections + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} +# endif + +////////////////////////////////////////////////////////////////////////// +// +// Push thread +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::PushThreadFunction() +{ + SetCurrentThreadName("orch_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(2000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + // Snapshot current connections + std::vector<Ref<WebSocketConnection>> Connections; + 
m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + continue; + } + + // Build combined JSON with worker list, provisioning history, clients, and client history + CbObject WorkerList = m_Service->GetWorkerList(); + CbObject History = m_Service->GetProvisioningHistory(50); + CbObject ClientList = m_Service->GetClientList(); + CbObject ClientHistory = m_Service->GetClientHistory(50); + + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname)); + + // Emit workers array from worker list + ExtendableStringBuilder<2048> WorkerJson; + WorkerList.ToJson(WorkerJson); + std::string_view WorkerJsonView = WorkerJson.ToView(); + // Strip outer braces: {"workers":[...]} -> "workers":[...] + if (WorkerJsonView.size() >= 2) + { + JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit events array from history + ExtendableStringBuilder<2048> HistoryJson; + History.ToJson(HistoryJson); + std::string_view HistoryJsonView = HistoryJson.ToView(); + if (HistoryJsonView.size() >= 2) + { + JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit clients array from client list + ExtendableStringBuilder<2048> ClientJson; + ClientList.ToJson(ClientJson); + std::string_view ClientJsonView = ClientJson.ToView(); + if (ClientJsonView.size() >= 2) + { + JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit client_events array from client history + ExtendableStringBuilder<2048> ClientHistoryJson; + ClientHistory.ToJson(ClientHistoryJson); + std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView(); + if (ClientHistoryJsonView.size() >= 2) + { + JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); + } + + 
JsonBuilder.Append("}"); + std::string_view Json = JsonBuilder.ToView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } + } +} +# endif + } // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h new file mode 100644 index 000000000..a5bc5a34d --- /dev/null +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> +#include <zencore/thread.h> + +#include <atomic> +#include <filesystem> +#include <string> +#include <thread> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +/** Snapshot of detected cloud instance properties. */ +struct CloudInstanceInfo +{ + CloudProvider Provider = CloudProvider::None; + std::string InstanceId; + std::string AvailabilityZone; + bool IsSpot = false; + bool IsAutoscaling = false; +}; + +/** + * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP) + * and monitors for impending termination signals. + * + * Detection works by querying the Instance Metadata Service (IMDS) at the + * well-known link-local address 169.254.169.254, which is only routable from + * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP); + * the first successful response wins. 
+ * + * To avoid a ~200ms connect timeout penalty on every startup when running on + * bare-metal or non-cloud machines, failed probes write sentinel files + * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have + * a sentinel present. Delete the sentinel files to force re-detection. + * + * When a provider is detected, a background thread polls for termination + * signals every 5 seconds (spot interruption, autoscaling lifecycle changes, + * scheduled maintenance). The termination state is exposed as an atomic bool + * so the compute server can include it in coordinator announcements and react + * to imminent shutdown. + * + * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a + * shared RwLock; the background monitor thread acquires the exclusive lock + * only when writing the termination reason (a one-time transition). The + * termination-pending flag itself is a lock-free atomic. + * + * Usage: + * auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud"); + * if (Cloud->IsTerminationPending()) { ... } + * Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB + */ +class CloudMetadata +{ +public: + /** Synchronously probes cloud providers and starts the termination monitor + * if a provider is detected. Creates DataDir if it does not exist. + */ + explicit CloudMetadata(std::filesystem::path DataDir); + + /** Synchronously probes cloud providers at the given IMDS endpoint. + * Intended for testing — allows redirecting all IMDS queries to a local + * mock HTTP server instead of the real 169.254.169.254 endpoint. + */ + CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); + + /** Stops the termination monitor thread and joins it. 
*/ + ~CloudMetadata(); + + CloudMetadata(const CloudMetadata&) = delete; + CloudMetadata& operator=(const CloudMetadata&) = delete; + + CloudProvider GetProvider() const; + CloudInstanceInfo GetInstanceInfo() const; + bool IsTerminationPending() const; + std::string GetTerminationReason() const; + + /** Writes a "cloud" sub-object into the compact binary writer if a provider + * was detected. No-op when running on bare metal. + */ + void Describe(CbWriter& Writer) const; + + /** Executes a single termination-poll cycle for the detected provider. + * Public so tests can drive poll cycles synchronously without relying on + * the background thread's 5-second timer. + */ + void PollTermination(); + + /** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure, + * .isNotGCP) from DataDir so subsequent detection probes are not skipped. + * Primarily intended for tests that need to reset state between sub-cases. + */ + void ClearSentinelFiles(); + +private: + /** Tries each provider in order, stops on first successful detection. */ + void DetectProvider(); + bool TryDetectAWS(); + bool TryDetectAzure(); + bool TryDetectGCP(); + + void WriteSentinelFile(const std::filesystem::path& Path); + bool HasSentinelFile(const std::filesystem::path& Path) const; + + void StartTerminationMonitor(); + void TerminationMonitorThread(); + void PollAWSTermination(); + void PollAzureTermination(); + void PollGCPTermination(); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + std::filesystem::path m_DataDir; + std::string m_ImdsEndpoint; + + mutable RwLock m_InfoLock; + CloudInstanceInfo m_Info; + + std::atomic<bool> m_TerminationPending{false}; + + mutable RwLock m_ReasonLock; + std::string m_TerminationReason; + + // IMDSv2 session token, acquired during AWS detection and reused for + // subsequent termination polling. 
Has a 300s TTL on the AWS side; if it + // expires mid-run the poll requests will get 401s which we treat as + // non-terminal (the monitor simply retries next cycle). + std::string m_AwsToken; + + std::thread m_MonitorThread; + std::atomic<bool> m_MonitorEnabled{true}; + Event m_MonitorEvent; +}; + +void cloudmetadata_forcelink(); // internal + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h new file mode 100644 index 000000000..65ec5f9ee --- /dev/null +++ b/src/zencompute/include/zencompute/computeservice.h @@ -0,0 +1,262 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/iohash.h> +# include <zenstore/zenstore.h> +# include <zenhttp/httpcommon.h> + +# include <filesystem> + +namespace zen { +class ChunkResolver; +class CbObjectWriter; +} // namespace zen + +namespace zen::compute { + +class ActionRecorder; +class ComputeServiceSession; +class IActionResultHandler; +class LocalProcessRunner; +class RemoteHttpRunner; +struct RunnerAction; +struct SubmitResult; + +struct WorkerDesc +{ + CbPackage Descriptor; + IoHash WorkerId{IoHash::Zero}; + + inline operator bool() const { return WorkerId != IoHash::Zero; } +}; + +/** + * Lambda style compute function service + * + * The responsibility of this class is to accept function execution requests, and + * schedule them using one or more FunctionRunner instances. It will basically always + * accept requests, queueing them if necessary, and then hand them off to runners + * as they become available. + * + * This is typically fronted by an API service that handles communication with clients. + */ +class ComputeServiceSession final +{ +public: + /** + * Session lifecycle state machine. 
+ * + * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset + * Backward transitions: Draining -> Ready, Paused -> Ready + * Automatic transition: Draining -> Paused (when pending + running reaches 0) + * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset + * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out) + * + * | State | Accept new actions | Schedule pending | Finish running | + * |-----------|-------------------|-----------------|----------------| + * | Created | No | No | N/A | + * | Ready | Yes | Yes | Yes | + * | Draining | No | Yes | Yes | + * | Paused | No | No | No | + * | Abandoned | No | No | No (all abandoned) | + * | Sunset | No | No | No | + */ + enum class SessionState + { + Created, // Initial state before WaitUntilReady completes + Ready, // Normal operating state; accepts and schedules work + Draining, // Stops accepting new work; finishes existing; auto-transitions to Paused when empty + Paused, // Idle; no work accepted or scheduled; can resume to Ready + Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out + Sunset // Terminal; triggers full shutdown + }; + + ComputeServiceSession(ChunkResolver& InChunkResolver); + ~ComputeServiceSession(); + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + SessionState GetSessionState() const; + + // Request a state transition. Returns false if the transition is invalid. + // Sunset can be reached from any non-Sunset state. 
+ bool RequestStateTransition(SessionState NewState); + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + // Worker registration and discovery + + void RegisterWorker(CbPackage Worker); + [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); + + // Action runners + + void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); + + // Action submission + + struct EnqueueResult + { + int Lsn; + CbObject ResponseMessage; + + inline operator bool() const { return Lsn != 0; } + }; + + [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); + [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); + + // Queue management + // + // Queues group actions submitted by a single client session. They allow + // cancelling or polling completion of all actions in the group. 
+ + struct CreateQueueResult + { + int QueueId = 0; // 0 if creation failed + }; + + enum class QueueState + { + Active, + Draining, + Cancelled, + }; + + struct QueueStatus + { + bool IsValid = false; + int QueueId = 0; + int ActiveCount = 0; // pending + running (not yet completed) + int CompletedCount = 0; // successfully completed + int FailedCount = 0; // failed + int AbandonedCount = 0; // abandoned + int CancelledCount = 0; // cancelled + QueueState State = QueueState::Active; + bool IsComplete = false; // ActiveCount == 0 + }; + + [[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {}); + [[nodiscard]] std::vector<int> GetQueueIds(); + [[nodiscard]] QueueStatus GetQueueStatus(int QueueId); + [[nodiscard]] CbObject GetQueueMetadata(int QueueId); + [[nodiscard]] CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DrainQueue(int QueueId); + void DeleteQueue(int QueueId); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + + // Queue-scoped action submission. Actions submitted via these methods are + // tracked under the given queue in addition to the global LSN-based tracking. 
+ + [[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + [[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + + // Completed action tracking + + [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + // Action rescheduling + + struct RescheduleResult + { + bool Success = false; + std::string Error; + int RetryCount = 0; + }; + + [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + + void GetCompleted(CbWriter&); + + // Running action tracking + + struct RunningActionInfo + { + int Lsn; + int QueueId; + IoHash ActionId; + float CpuUsagePercent; // -1.0 if not yet sampled + float CpuSeconds; // 0.0 if not yet sampled + }; + + [[nodiscard]] std::vector<RunningActionInfo> GetRunningActions(); + + // Action history tracking (note that this is separate from completed action tracking, and + // will include actions which have been retired and no longer have their results available) + + struct ActionHistoryEntry + { + int Lsn; + int QueueId = 0; + IoHash ActionId; + IoHash WorkerId; + CbObject ActionDescriptor; + std::string ExecutionLocation; + bool Succeeded; + float CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled + int RetryCount = 0; // number of times this action was rescheduled + // sized to match RunnerAction::State::_Count but we can't use the enum here + // for dependency reasons, so just use a fixed size array and static assert in + // the implementation file + uint64_t Timestamps[8] = {}; + }; + + [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); + [[nodiscard]] std::vector<ActionHistoryEntry> GetQueueHistory(int QueueId, int Limit = 100); + + // Stats reporting + + struct ActionCounts + { + int 
Pending = 0; + int Running = 0; + int Completed = 0; + int ActiveQueues = 0; + }; + + [[nodiscard]] ActionCounts GetActionCounts(); + + void EmitStats(CbObjectWriter& Cbo); + + // Recording + + void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + void StopRecording(); + +private: + void PostUpdate(RunnerAction* Action); + + friend class FunctionRunner; + friend struct RunnerAction; + + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void computeservice_forcelink(); + +} // namespace zen::compute + +namespace zen { +const char* ToString(compute::ComputeServiceSession::SessionState State); +const char* ToString(compute::ComputeServiceSession::QueueState State); +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/functionservice.h b/src/zencompute/include/zencompute/functionservice.h deleted file mode 100644 index 1deb99fd5..000000000 --- a/src/zencompute/include/zencompute/functionservice.h +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/zencore.h> - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/iohash.h> -# include <zenstore/zenstore.h> -# include <zenhttp/httpcommon.h> - -# include <filesystem> - -namespace zen { -class ChunkResolver; -class CbObjectWriter; -} // namespace zen - -namespace zen::compute { - -class ActionRecorder; -class FunctionServiceSession; -class IActionResultHandler; -class LocalProcessRunner; -class RemoteHttpRunner; -struct RunnerAction; -struct SubmitResult; - -struct WorkerDesc -{ - CbPackage Descriptor; - IoHash WorkerId{IoHash::Zero}; - - inline operator bool() const { return WorkerId != IoHash::Zero; } -}; - -/** - * Lambda style compute function service - * - * The responsibility of this class is to accept function execution requests, and - * schedule them using one or more FunctionRunner instances. It will basically always - * accept requests, queueing them if necessary, and then hand them off to runners - * as they become available. - * - * This is typically fronted by an API service that handles communication with clients. 
- */ -class FunctionServiceSession final -{ -public: - FunctionServiceSession(ChunkResolver& InChunkResolver); - ~FunctionServiceSession(); - - void Shutdown(); - bool IsHealthy(); - - // Worker registration and discovery - - void RegisterWorker(CbPackage Worker); - [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - [[nodiscard]] std::vector<IoHash> GetKnownWorkerIds(); - - // Action runners - - void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath); - void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); - - // Action submission - - struct EnqueueResult - { - int Lsn; - CbObject ResponseMessage; - - inline operator bool() const { return Lsn != 0; } - }; - - [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); - [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - - // Completed action tracking - - [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - void GetCompleted(CbWriter&); - - // Action history tracking (note that this is separate from completed action tracking, and - // will include actions which have been retired and no longer have their results available) - - struct ActionHistoryEntry - { - int Lsn; - IoHash ActionId; - IoHash WorkerId; - CbObject ActionDescriptor; - bool Succeeded; - uint64_t Timestamps[5] = {}; - }; - - [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); - - // Stats reporting - - void EmitStats(CbObjectWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); - -private: - void PostUpdate(RunnerAction* Action); - - friend class FunctionRunner; - friend struct RunnerAction; - - struct Impl; 
- std::unique_ptr<Impl> m_Impl; -}; - -void function_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h new file mode 100644 index 000000000..ee1cd2614 --- /dev/null +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "zencompute/computeservice.h" + +# include <zenhttp/httpserver.h> + +# include <filesystem> +# include <memory> + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** + * HTTP interface for compute service + */ +class HttpComputeService : public HttpService, public IHttpStatsProvider +{ +public: + HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions = 0); + ~HttpComputeService(); + + void Shutdown(); + + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + + void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void httpcomputeservice_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpfunctionservice.h b/src/zencompute/include/zencompute/httpfunctionservice.h deleted file mode 100644 index 6e2344ae6..000000000 --- a/src/zencompute/include/zencompute/httpfunctionservice.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include <zencore/zencore.h> - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "zencompute/functionservice.h" - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarypackage.h> -# include <zencore/iohash.h> -# include <zencore/logging.h> -# include <zentelemetry/stats.h> -# include <zenhttp/httpserver.h> - -# include <deque> -# include <filesystem> -# include <unordered_map> - -namespace zen { -class CidStore; -} - -namespace zen::compute { - -class HttpFunctionService; -class FunctionService; - -/** - * HTTP interface for compute function service - */ -class HttpFunctionService : public HttpService, public IHttpStatsProvider -{ -public: - HttpFunctionService(CidStore& InCidStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir); - ~HttpFunctionService(); - - void Shutdown(); - - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; - - // IHttpStatsProvider - - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - -protected: - CidStore& m_CidStore; - IHttpStatsService& m_StatsService; - LoggerRef Log() { return m_Log; } - -private: - LoggerRef m_Log; - std::filesystem ::path m_BaseDir; - HttpRequestRouter m_Router; - FunctionServiceSession m_FunctionService; - - // Metrics - - metrics::OperationTiming m_HttpRequests; -}; - -void httpfunction_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index 168c6d7fe..da5c5dfc3 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,43 +2,100 @@ #pragma once +#include <zencompute/zencompute.h> + #include <zencore/logging.h> #include <zencore/thread.h> -#include <zencore/timer.h> 
#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <atomic> +#include <filesystem> +#include <memory> +#include <string> +#include <thread> #include <unordered_map> +#include <vector> + +#define ZEN_WITH_WEBSOCKETS 1 namespace zen::compute { +class OrchestratorService; + +// Experimental helper, to see if we can get rid of some error-prone +// boilerplate when declaring loggers as class members. + +class LoggerHelper +{ +public: + LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {} + + LoggerRef operator()() { return m_Log; } + +private: + LoggerRef m_Log; +}; + /** - * Mock orchestrator service, for testing dynamic provisioning + * Orchestrator HTTP service with WebSocket push support + * + * Normal HTTP requests are routed through the HttpRequestRouter as before. + * WebSocket clients connecting to /orch/ws receive periodic state broadcasts + * from a dedicated push thread, eliminating the need for polling. */ class HttpOrchestratorService : public HttpService +#if ZEN_WITH_WEBSOCKETS +, + public IWebSocketHandler +#endif { public: - HttpOrchestratorService(); + explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); ~HttpOrchestratorService(); HttpOrchestratorService(const HttpOrchestratorService&) = delete; HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete; + /** + * Gracefully shut down the WebSocket push thread and release connections. + * Must be called while the ASIO io_context is still alive. The destructor + * also calls this, so it is safe (but not ideal) to omit the explicit call. 
+ */ + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; + // IWebSocketHandler +#if ZEN_WITH_WEBSOCKETS + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; +#endif + private: - HttpRequestRouter m_Router; - LoggerRef m_Log; + HttpRequestRouter m_Router; + LoggerHelper Log{"orch"}; + std::unique_ptr<OrchestratorService> m_Service; + std::string m_Hostname; + + // WebSocket push - struct KnownWorker - { - std::string_view BaseUri; - Stopwatch LastSeen; - }; +#if ZEN_WITH_WEBSOCKETS + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::thread m_PushThread; + std::atomic<bool> m_PushEnabled{false}; + Event m_PushEvent; + void PushThreadFunction(); - RwLock m_KnownWorkersLock; - std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + // Worker WebSocket connections (worker→orchestrator persistent links) + RwLock m_WorkerWsLock; + std::unordered_map<WebSocketConnection*, std::string> m_WorkerWsMap; // connection ptr → worker ID + std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); +#endif }; } // namespace zen::compute diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h new file mode 100644 index 000000000..521722e63 --- /dev/null +++ b/src/zencompute/include/zencompute/mockimds.h @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/cloudmetadata.h> +#include <zenhttp/httpserver.h> + +#include <string> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +/** + * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. 
+ * + * Implements an HttpService that responds to the same URL paths as the real + * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). + * Tests configure which provider is "active" and set the desired response + * values, then pass the mock server's address as the ImdsEndpoint to the + * CloudMetadata constructor. + * + * When a request arrives for a provider that is not the ActiveProvider, the + * mock returns 404, causing CloudMetadata to write a sentinel file and move + * on to the next provider — exactly like a failed probe on bare metal. + * + * All config fields are public and can be mutated between poll cycles to + * simulate state changes (e.g. a spot interruption appearing mid-run). + * + * Usage: + * MockImdsService Mock; + * Mock.ActiveProvider = CloudProvider::AWS; + * Mock.Aws.InstanceId = "i-test"; + * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint + */ +class MockImdsService : public HttpService +{ +public: + /** AWS IMDSv2 response configuration. */ + struct AwsConfig + { + std::string Token = "mock-aws-token-v2"; + std::string InstanceId = "i-0123456789abcdef0"; + std::string AvailabilityZone = "us-east-1a"; + std::string LifeCycle = "on-demand"; // "spot" or "on-demand" + + // Empty string → endpoint returns 404 (instance not in an ASG). + // Non-empty → returned as the response body. "InService" means healthy; + // anything else (e.g. "Terminated:Wait") triggers termination detection. + std::string AutoscalingState; + + // Empty string → endpoint returns 404 (no spot interruption). + // Non-empty → returned as the response body, signalling a spot reclaim. + std::string SpotAction; + }; + + /** Azure IMDS response configuration. */ + struct AzureConfig + { + std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; + std::string Location = "eastus"; + std::string Priority = "Regular"; // "Spot" or "Regular" + + // Empty → instance is not in a VM Scale Set (no autoscaling). 
+ std::string VmScaleSetName; + + // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // "Reboot" to simulate a termination-class event. + std::string ScheduledEventType; + std::string ScheduledEventStatus = "Scheduled"; + }; + + /** GCP metadata response configuration. */ + struct GcpConfig + { + std::string InstanceId = "1234567890123456789"; + std::string Zone = "projects/123456/zones/us-central1-a"; + std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" + std::string MaintenanceEvent = "NONE"; // "NONE" or event description + }; + + /** Which provider's endpoints respond successfully. + * Requests targeting other providers receive 404. + */ + CloudProvider ActiveProvider = CloudProvider::None; + + AwsConfig Aws; + AzureConfig Azure; + GcpConfig Gcp; + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + +private: + void HandleAwsRequest(HttpServerRequest& Request); + void HandleAzureRequest(HttpServerRequest& Request); + void HandleGcpRequest(HttpServerRequest& Request); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h new file mode 100644 index 000000000..071e902b3 --- /dev/null +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -0,0 +1,177 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencompute/zencompute.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/uid.h> + +# include <deque> +# include <optional> +# include <filesystem> +# include <memory> +# include <string> +# include <string_view> +# include <thread> +# include <unordered_map> + +namespace zen::compute { + +class WorkerTimelineStore; + +class OrchestratorService +{ +public: + explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~OrchestratorService(); + + OrchestratorService(const OrchestratorService&) = delete; + OrchestratorService& operator=(const OrchestratorService&) = delete; + + struct WorkerAnnouncement + { + std::string_view Id; + std::string_view Uri; + std::string_view Hostname; + std::string_view Platform; // e.g. "windows", "wine", "linux", "macos" + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string_view Provisioner; // e.g. 
"horde", "nomad", or empty + }; + + struct ProvisioningEvent + { + enum class Type + { + Joined, + Left, + Returned + }; + Type EventType; + DateTime Timestamp; + std::string WorkerId; + std::string Hostname; + }; + + struct ClientAnnouncement + { + Oid SessionId; + std::string_view Hostname; + std::string_view Address; + CbObject Metadata; + }; + + struct ClientEvent + { + enum class Type + { + Connected, + Disconnected, + Updated + }; + Type EventType; + DateTime Timestamp; + std::string ClientId; + std::string Hostname; + }; + + CbObject GetWorkerList(); + void AnnounceWorker(const WorkerAnnouncement& Announcement); + + bool IsWorkerWebSocketEnabled() const; + void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); + + CbObject GetProvisioningHistory(int Limit = 100); + + CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit); + + CbObject GetAllTimelines(DateTime From, DateTime To); + + std::string AnnounceClient(const ClientAnnouncement& Announcement); + bool UpdateClient(std::string_view ClientId, CbObject Metadata = {}); + bool CompleteClient(std::string_view ClientId); + CbObject GetClientList(); + CbObject GetClientHistory(int Limit = 100); + +private: + enum class ReachableState + { + Unknown, + Reachable, + Unreachable, + }; + + struct KnownWorker + { + std::string BaseUri; + Stopwatch LastSeen; + std::string Hostname; + std::string Platform; + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string Provisioner; + ReachableState Reachable = ReachableState::Unknown; + bool WsConnected = false; + Stopwatch LastProbed; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map<std::string, KnownWorker> m_KnownWorkers; + 
std::unique_ptr<WorkerTimelineStore> m_TimelineStore; + + RwLock m_ProvisioningLogLock; + std::deque<ProvisioningEvent> m_ProvisioningLog; + static constexpr size_t kMaxProvisioningEvents = 1000; + + void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname); + + struct KnownClient + { + Oid SessionId; + std::string Hostname; + std::string Address; + Stopwatch LastSeen; + CbObject Metadata; + }; + + RwLock m_KnownClientsLock; + std::unordered_map<std::string, KnownClient> m_KnownClients; + + RwLock m_ClientLogLock; + std::deque<ClientEvent> m_ClientLog; + static constexpr size_t kMaxClientEvents = 1000; + + void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); + + bool m_EnableWorkerWebSocket = false; + + std::thread m_ProbeThread; + std::atomic<bool> m_ProbeThreadEnabled{true}; + Event m_ProbeThreadEvent; + void ProbeThreadFunction(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h index bf1aff125..3f233fae0 100644 --- a/src/zencompute/include/zencompute/recordingreader.h +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -2,7 +2,9 @@ #pragma once -#include <zencompute/functionservice.h> +#include <zencompute/zencompute.h> + +#include <zencompute/computeservice.h> #include <zencompute/zencompute.h> #include <zencore/basicfile.h> #include <zencore/compactbinarybuilder.h> diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h index 6dc32eeea..00be4d4a0 100644 --- a/src/zencompute/include/zencompute/zencompute.h +++ b/src/zencompute/include/zencompute/zencompute.h @@ -4,6 +4,10 @@ #include <zencore/zencore.h> +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + namespace zen { void zencompute_forcelinktests(); diff --git 
a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp new file mode 100644 index 000000000..9ea695305 --- /dev/null +++ b/src/zencompute/orchestratorservice.cpp @@ -0,0 +1,710 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/orchestratorservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarybuilder.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zenhttp/httpclient.h> + +# include "timeline/workertimeline.h" + +namespace zen::compute { + +OrchestratorService::OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_TimelineStore(std::make_unique<WorkerTimelineStore>(DataDir / "timelines")) +, m_EnableWorkerWebSocket(EnableWorkerWebSocket) +{ + m_ProbeThread = std::thread{&OrchestratorService::ProbeThreadFunction, this}; +} + +OrchestratorService::~OrchestratorService() +{ + m_ProbeThreadEnabled = false; + m_ProbeThreadEvent.Set(); + if (m_ProbeThread.joinable()) + { + m_ProbeThread.join(); + } +} + +CbObject +OrchestratorService::GetWorkerList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + m_KnownWorkersLock.WithSharedLock([&] { + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "uri" << Worker.BaseUri; + Cbo << "hostname" << Worker.Hostname; + if (!Worker.Platform.empty()) + { + Cbo << "platform" << std::string_view(Worker.Platform); + } + Cbo << "cpus" << Worker.Cpus; + Cbo << "cpu_usage" << Worker.CpuUsagePercent; + Cbo << "memory_total" << Worker.MemoryTotalBytes; + Cbo << "memory_used" << Worker.MemoryUsedBytes; + Cbo << "bytes_received" << Worker.BytesReceived; + Cbo << "bytes_sent" << Worker.BytesSent; + Cbo << "actions_pending" << Worker.ActionsPending; + Cbo << "actions_running" << Worker.ActionsRunning; + Cbo << "actions_completed" << Worker.ActionsCompleted; + Cbo << "active_queues" << 
Worker.ActiveQueues; + if (!Worker.Provisioner.empty()) + { + Cbo << "provisioner" << std::string_view(Worker.Provisioner); + } + if (Worker.Reachable != ReachableState::Unknown) + { + Cbo << "reachable" << (Worker.Reachable == ReachableState::Reachable); + } + if (Worker.WsConnected) + { + Cbo << "ws_connected" << true; + } + Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceWorker"); + + bool IsNew = false; + std::string EvictedId; + std::string EvictedHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + IsNew = (m_KnownWorkers.find(std::string(Ann.Id)) == m_KnownWorkers.end()); + + // If a different worker ID already maps to the same URI, the old entry + // is stale (e.g. a previous Horde lease on the same machine). Remove it + // so the dashboard doesn't show duplicates. + if (IsNew) + { + for (auto It = m_KnownWorkers.begin(); It != m_KnownWorkers.end(); ++It) + { + if (It->second.BaseUri == Ann.Uri && It->first != Ann.Id) + { + EvictedId = It->first; + EvictedHostname = It->second.Hostname; + m_KnownWorkers.erase(It); + break; + } + } + } + + auto& Worker = m_KnownWorkers[std::string(Ann.Id)]; + Worker.BaseUri = Ann.Uri; + Worker.Hostname = Ann.Hostname; + if (!Ann.Platform.empty()) + { + Worker.Platform = Ann.Platform; + } + Worker.Cpus = Ann.Cpus; + Worker.CpuUsagePercent = Ann.CpuUsagePercent; + Worker.MemoryTotalBytes = Ann.MemoryTotalBytes; + Worker.MemoryUsedBytes = Ann.MemoryUsedBytes; + Worker.BytesReceived = Ann.BytesReceived; + Worker.BytesSent = Ann.BytesSent; + Worker.ActionsPending = Ann.ActionsPending; + Worker.ActionsRunning = Ann.ActionsRunning; + Worker.ActionsCompleted = Ann.ActionsCompleted; + Worker.ActiveQueues = Ann.ActiveQueues; + if (!Ann.Provisioner.empty()) + { + Worker.Provisioner = Ann.Provisioner; + } + Worker.LastSeen.Reset(); + 
}); + + if (!EvictedId.empty()) + { + ZEN_INFO("worker {} superseded by {} (same endpoint)", EvictedId, Ann.Id); + RecordProvisioningEvent(ProvisioningEvent::Type::Left, EvictedId, EvictedHostname); + } + + if (IsNew) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Joined, Ann.Id, Ann.Hostname); + } +} + +bool +OrchestratorService::IsWorkerWebSocketEnabled() const +{ + return m_EnableWorkerWebSocket; +} + +void +OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected) +{ + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(std::string(WorkerId)); + if (It == m_KnownWorkers.end()) + { + return; + } + + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.WsConnected = Connected; + It->second.Reachable = Connected ? ReachableState::Reachable : ReachableState::Unreachable; + + if (Connected) + { + ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + } + else + { + ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId); + } + }); + + // Record provisioning events for state transitions outside the lock + if (Connected && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, WorkerId, WorkerHostname); + } + else if (!Connected && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, WorkerId, WorkerHostname); + } +} + +CbObject +OrchestratorService::GetWorkerTimeline(std::string_view WorkerId, std::optional<DateTime> From, std::optional<DateTime> To, int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerTimeline"); + + Ref<WorkerTimeline> Timeline = m_TimelineStore->Find(WorkerId); + if (!Timeline) + { + return {}; + } + + std::vector<WorkerTimeline::Event> Events; + + if (From || To) + { + DateTime StartTime = From.value_or(DateTime(0)); + 
DateTime EndTime = To.value_or(DateTime::Now()); + Events = Timeline->QueryTimeline(StartTime, EndTime); + } + else if (Limit > 0) + { + Events = Timeline->QueryRecent(Limit); + } + else + { + Events = Timeline->QueryRecent(); + } + + WorkerTimeline::TimeRange Range = Timeline->GetTimeRange(); + + CbObjectWriter Cbo; + Cbo << "worker_id" << WorkerId; + Cbo << "event_count" << static_cast<int32_t>(Timeline->GetEventCount()); + + if (Range) + { + Cbo.AddDateTime("time_first", Range.First); + Cbo.AddDateTime("time_last", Range.Last); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : Events) + { + Cbo.BeginObject(); + Cbo << "type" << WorkerTimeline::ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == WorkerTimeline::EventType::ActionStateChanged) + { + Cbo << "prev_state" << RunnerAction::ToString(Evt.PreviousState); + Cbo << "state" << RunnerAction::ToString(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetAllTimelines(DateTime From, DateTime To) +{ + ZEN_TRACE_CPU("OrchestratorService::GetAllTimelines"); + + DateTime StartTime = From; + DateTime EndTime = To; + + auto AllInfo = m_TimelineStore->GetAllWorkerInfo(); + + CbObjectWriter Cbo; + Cbo.AddDateTime("from", StartTime); + Cbo.AddDateTime("to", EndTime); + + Cbo.BeginArray("workers"); + for (const auto& Info : AllInfo) + { + if (!Info.Range || Info.Range.Last < StartTime || Info.Range.First > EndTime) + { + continue; + } + + Cbo.BeginObject(); + Cbo << "worker_id" << Info.WorkerId; + Cbo.AddDateTime("time_first", Info.Range.First); + Cbo.AddDateTime("time_last", Info.Range.Last); + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +void 
+OrchestratorService::RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname) +{ + ProvisioningEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .WorkerId = std::string(WorkerId), + .Hostname = std::string(Hostname), + }; + + m_ProvisioningLogLock.WithExclusiveLock([&] { + m_ProvisioningLog.push_back(std::move(Evt)); + while (m_ProvisioningLog.size() > kMaxProvisioningEvents) + { + m_ProvisioningLog.pop_front(); + } + }); +} + +CbObject +OrchestratorService::GetProvisioningHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetProvisioningHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("events"); + + m_ProvisioningLogLock.WithSharedLock([&] { + // Return last N events, newest first + int Count = 0; + for (auto It = m_ProvisioningLog.rbegin(); It != m_ProvisioningLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ProvisioningEvent::Type::Joined: + Cbo << "type" + << "joined"; + break; + case ProvisioningEvent::Type::Left: + Cbo << "type" + << "left"; + break; + case ProvisioningEvent::Type::Returned: + Cbo << "type" + << "returned"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "worker_id" << std::string_view(Evt.WorkerId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +std::string +OrchestratorService::AnnounceClient(const ClientAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceClient"); + + std::string ClientId = fmt::format("client-{}", Oid::NewOid().ToString()); + + bool IsNew = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(ClientId); + IsNew = (It == m_KnownClients.end()); + + auto& Client = m_KnownClients[ClientId]; + Client.SessionId = Ann.SessionId; + Client.Hostname = Ann.Hostname; + if 
(!Ann.Address.empty()) + { + Client.Address = Ann.Address; + } + if (Ann.Metadata) + { + Client.Metadata = Ann.Metadata; + } + Client.LastSeen.Reset(); + }); + + if (IsNew) + { + RecordClientEvent(ClientEvent::Type::Connected, ClientId, Ann.Hostname); + } + else + { + RecordClientEvent(ClientEvent::Type::Updated, ClientId, Ann.Hostname); + } + + return ClientId; +} + +bool +OrchestratorService::UpdateClient(std::string_view ClientId, CbObject Metadata) +{ + ZEN_TRACE_CPU("OrchestratorService::UpdateClient"); + + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + if (Metadata) + { + It->second.Metadata = std::move(Metadata); + } + It->second.LastSeen.Reset(); + } + }); + + return Found; +} + +bool +OrchestratorService::CompleteClient(std::string_view ClientId) +{ + ZEN_TRACE_CPU("OrchestratorService::CompleteClient"); + + std::string Hostname; + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + Hostname = It->second.Hostname; + m_KnownClients.erase(It); + } + }); + + if (Found) + { + RecordClientEvent(ClientEvent::Type::Disconnected, ClientId, Hostname); + } + + return Found; +} + +CbObject +OrchestratorService::GetClientList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientList"); + CbObjectWriter Cbo; + Cbo.BeginArray("clients"); + + m_KnownClientsLock.WithSharedLock([&] { + for (const auto& [ClientId, Client] : m_KnownClients) + { + Cbo.BeginObject(); + Cbo << "id" << ClientId; + if (Client.SessionId) + { + Cbo << "session_id" << Client.SessionId; + } + Cbo << "hostname" << std::string_view(Client.Hostname); + if (!Client.Address.empty()) + { + Cbo << "address" << std::string_view(Client.Address); + } + Cbo << "dt" << Client.LastSeen.GetElapsedTimeMs(); + if (Client.Metadata) + { + Cbo << "metadata" << 
Client.Metadata; + } + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetClientHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("client_events"); + + m_ClientLogLock.WithSharedLock([&] { + int Count = 0; + for (auto It = m_ClientLog.rbegin(); It != m_ClientLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ClientEvent::Type::Connected: + Cbo << "type" + << "connected"; + break; + case ClientEvent::Type::Disconnected: + Cbo << "type" + << "disconnected"; + break; + case ClientEvent::Type::Updated: + Cbo << "type" + << "updated"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "client_id" << std::string_view(Evt.ClientId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname) +{ + ClientEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .ClientId = std::string(ClientId), + .Hostname = std::string(Hostname), + }; + + m_ClientLogLock.WithExclusiveLock([&] { + m_ClientLog.push_back(std::move(Evt)); + while (m_ClientLog.size() > kMaxClientEvents) + { + m_ClientLog.pop_front(); + } + }); +} + +void +OrchestratorService::ProbeThreadFunction() +{ + ZEN_TRACE_CPU("OrchestratorService::ProbeThreadFunction"); + SetCurrentThreadName("orch_probe"); + + bool IsFirstProbe = true; + + do + { + if (!IsFirstProbe) + { + m_ProbeThreadEvent.Wait(5'000); + m_ProbeThreadEvent.Reset(); + } + else + { + IsFirstProbe = false; + } + + if (m_ProbeThreadEnabled == false) + { + return; + } + + m_ProbeThreadEvent.Reset(); + + // Snapshot worker IDs and URIs under shared lock + struct WorkerSnapshot + { + std::string Id; 
+ std::string Uri; + bool WsConnected = false; + }; + std::vector<WorkerSnapshot> Snapshots; + + m_KnownWorkersLock.WithSharedLock([&] { + Snapshots.reserve(m_KnownWorkers.size()); + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Snapshots.push_back({WorkerId, Worker.BaseUri, Worker.WsConnected}); + } + }); + + // Probe each worker outside the lock + for (const auto& Snap : Snapshots) + { + if (m_ProbeThreadEnabled == false) + { + return; + } + + // Workers with an active WebSocket connection are known-reachable; + // skip the HTTP health probe for them. + if (Snap.WsConnected) + { + continue; + } + + ReachableState NewState = ReachableState::Unreachable; + + try + { + HttpClient Client(Snap.Uri, + {.ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{5000}}); + HttpClient::Response Response = Client.Get("/health/"); + if (Response.IsSuccess()) + { + NewState = ReachableState::Reachable; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } + + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(Snap.Id); + if (It != m_KnownWorkers.end()) + { + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.Reachable = NewState; + It->second.LastProbed.Reset(); + + if (PrevState != NewState) + { + if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + ZEN_INFO("worker {} ({}) is reachable again", Snap.Id, Snap.Uri); + } + else if (NewState == ReachableState::Reachable) + { + ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); + } + else if (PrevState == ReachableState::Reachable) + { + ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); + } + else + { + ZEN_WARN("worker {} ({}) is not reachable", Snap.Id, Snap.Uri); + } + } + } + }); + + // 
Record provisioning events for state transitions outside the lock + if (PrevState != NewState) + { + if (NewState == ReachableState::Unreachable && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, Snap.Id, WorkerHostname); + } + else if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, Snap.Id, WorkerHostname); + } + } + } + + // Sweep expired clients (5-minute timeout) + static constexpr int64_t kClientTimeoutMs = 5 * 60 * 1000; + + struct ExpiredClient + { + std::string Id; + std::string Hostname; + }; + std::vector<ExpiredClient> ExpiredClients; + + m_KnownClientsLock.WithExclusiveLock([&] { + for (auto It = m_KnownClients.begin(); It != m_KnownClients.end();) + { + if (It->second.LastSeen.GetElapsedTimeMs() > kClientTimeoutMs) + { + ExpiredClients.push_back({It->first, It->second.Hostname}); + It = m_KnownClients.erase(It); + } + else + { + ++It; + } + } + }); + + for (const auto& Expired : ExpiredClients) + { + ZEN_INFO("client {} timed out (no announcement for >5 minutes)", Expired.Id); + RecordClientEvent(ClientEvent::Type::Disconnected, Expired.Id, Expired.Hostname); + } + } while (m_ProbeThreadEnabled); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/actionrecorder.cpp b/src/zencompute/recording/actionrecorder.cpp index 04c4b5141..90141ca55 100644 --- a/src/zencompute/actionrecorder.cpp +++ b/src/zencompute/recording/actionrecorder.cpp @@ -2,7 +2,7 @@ #include "actionrecorder.h" -#include "functionrunner.h" +#include "../runners/functionrunner.h" #include <zencore/compactbinary.h> #include <zencore/compactbinaryfile.h> diff --git a/src/zencompute/actionrecorder.h b/src/zencompute/recording/actionrecorder.h index 9cc2b44a2..2827b6ac7 100644 --- a/src/zencompute/actionrecorder.h +++ b/src/zencompute/recording/actionrecorder.h @@ -2,7 +2,7 @@ #pragma once -#include 
<zencompute/functionservice.h> +#include <zencompute/computeservice.h> #include <zencompute/zencompute.h> #include <zencore/basicfile.h> #include <zencore/compactbinarybuilder.h> diff --git a/src/zencompute/recordingreader.cpp b/src/zencompute/recording/recordingreader.cpp index 1c1a119cf..1c1a119cf 100644 --- a/src/zencompute/recordingreader.cpp +++ b/src/zencompute/recording/recordingreader.cpp diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp new file mode 100644 index 000000000..00977d9fa --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.cpp @@ -0,0 +1,336 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "deferreddeleter.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/thread.h> + +# include <algorithm> +# include <chrono> + +namespace zen::compute { + +using namespace std::chrono_literals; + +using Clock = std::chrono::steady_clock; + +// Default deferral: how long to wait before attempting deletion. +// This gives memory-mapped file handles time to close naturally. +static constexpr auto DeferralPeriod = 60s; + +// Shortened deferral after MarkReady(): the client has collected results +// so handles should be released soon, but we still wait briefly. +static constexpr auto ReadyGracePeriod = 5s; + +// Interval between retry attempts for directories that failed deletion. 
+static constexpr auto RetryInterval = 5s; + +static constexpr int MaxRetries = 10; + +DeferredDirectoryDeleter::DeferredDirectoryDeleter() : m_Thread(&DeferredDirectoryDeleter::ThreadFunction, this) +{ +} + +DeferredDirectoryDeleter::~DeferredDirectoryDeleter() +{ + Shutdown(); +} + +void +DeferredDirectoryDeleter::Enqueue(int ActionLsn, std::filesystem::path Path) +{ + { + std::lock_guard Lock(m_Mutex); + m_Queue.push_back({ActionLsn, std::move(Path)}); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::MarkReady(int ActionLsn) +{ + { + std::lock_guard Lock(m_Mutex); + m_ReadyLsns.push_back(ActionLsn); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::Shutdown() +{ + { + std::lock_guard Lock(m_Mutex); + m_Done = true; + } + m_Cv.notify_one(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } +} + +void +DeferredDirectoryDeleter::ThreadFunction() +{ + SetCurrentThreadName("ZenDirCleanup"); + + struct PendingEntry + { + int ActionLsn; + std::filesystem::path Path; + Clock::time_point ReadyTime; + int Attempts = 0; + }; + + std::vector<PendingEntry> PendingList; + + auto TryDelete = [](PendingEntry& Entry) -> bool { + std::error_code Ec; + std::filesystem::remove_all(Entry.Path, Ec); + return !Ec; + }; + + for (;;) + { + bool Shutting = false; + + // Drain the incoming queue and process MarkReady signals + + { + std::unique_lock Lock(m_Mutex); + + if (m_Queue.empty() && m_ReadyLsns.empty() && !m_Done) + { + if (PendingList.empty()) + { + m_Cv.wait(Lock, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + else + { + auto NextReady = PendingList.front().ReadyTime; + for (const auto& Entry : PendingList) + { + if (Entry.ReadyTime < NextReady) + { + NextReady = Entry.ReadyTime; + } + } + + m_Cv.wait_until(Lock, NextReady, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + } + + // Move new items into PendingList with the full deferral deadline + auto Now = Clock::now(); + for (auto& 
Entry : m_Queue) + { + PendingList.push_back({Entry.ActionLsn, std::move(Entry.Path), Now + DeferralPeriod, 0}); + } + m_Queue.clear(); + + // Apply MarkReady: shorten ReadyTime for matching entries + for (int Lsn : m_ReadyLsns) + { + for (auto& Entry : PendingList) + { + if (Entry.ActionLsn == Lsn) + { + auto NewReady = Now + ReadyGracePeriod; + if (NewReady < Entry.ReadyTime) + { + Entry.ReadyTime = NewReady; + } + } + } + } + m_ReadyLsns.clear(); + + Shutting = m_Done; + } + + // Process items whose deferral period has elapsed (or all items on shutdown) + + auto Now = Clock::now(); + + for (size_t i = 0; i < PendingList.size();) + { + auto& Entry = PendingList[i]; + + if (!Shutting && Now < Entry.ReadyTime) + { + ++i; + continue; + } + + if (TryDelete(Entry)) + { + if (Entry.Attempts > 0) + { + ZEN_INFO("Retry succeeded for directory '{}'", Entry.Path); + } + + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ++Entry.Attempts; + + if (Entry.Attempts >= MaxRetries) + { + ZEN_WARN("Giving up on deleting '{}' after {} attempts", Entry.Path, Entry.Attempts); + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ZEN_WARN("Unable to delete directory '{}' (attempt {}), will retry", Entry.Path, Entry.Attempts); + Entry.ReadyTime = Now + RetryInterval; + ++i; + } + } + } + + // Exit once shutdown is requested and nothing remains + + if (Shutting && PendingList.empty()) + { + return; + } + } +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS + +# include <zencore/testing.h> + +namespace zen::compute { + +void +deferreddeleter_forcelink() +{ +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/testutils.h> + +namespace zen::compute { + +TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path DirToDelete = TempDir.Path() / "subdir"; + 
CreateDirectories(DirToDelete / "nested"); + + CHECK(std::filesystem::exists(DirToDelete)); + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(1, DirToDelete); + } + + CHECK(!std::filesystem::exists(DirToDelete)); +} + +TEST_CASE("DeferredDirectoryDeleter.DeletesMultipleDirectories") +{ + ScopedTemporaryDirectory TempDir; + + constexpr int NumDirs = 10; + std::vector<std::filesystem::path> Dirs; + + for (int i = 0; i < NumDirs; ++i) + { + auto Dir = TempDir.Path() / std::to_string(i); + CreateDirectories(Dir / "child"); + Dirs.push_back(std::move(Dir)); + } + + { + DeferredDirectoryDeleter Deleter; + for (int i = 0; i < NumDirs; ++i) + { + CHECK(std::filesystem::exists(Dirs[i])); + Deleter.Enqueue(100 + i, Dirs[i]); + } + } + + for (const auto& Dir : Dirs) + { + CHECK(!std::filesystem::exists(Dir)); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ShutdownIsIdempotent") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "idempotent"; + CreateDirectories(Dir); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(42, Dir); + Deleter.Shutdown(); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.HandlesNonExistentPath") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path NoSuchDir = TempDir.Path() / "does_not_exist"; + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(99, NoSuchDir); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ExplicitShutdownBeforeDestruction") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "explicit"; + CreateDirectories(Dir / "inner"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(7, Dir); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "markready"; + CreateDirectories(Dir / "child"); + + DeferredDirectoryDeleter 
Deleter; + Deleter.Enqueue(50, Dir); + + // Without MarkReady the full deferral (60s) would apply. + // MarkReady shortens it to 5s, and shutdown bypasses even that. + Deleter.MarkReady(50); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/runners/deferreddeleter.h b/src/zencompute/runners/deferreddeleter.h new file mode 100644 index 000000000..9b010aa0f --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <condition_variable> +# include <deque> +# include <filesystem> +# include <mutex> +# include <thread> +# include <vector> + +namespace zen::compute { + +/// Deletes directories on a background thread to avoid blocking callers. +/// Useful when DeleteDirectories may stall (e.g. Wine's deferred-unlink semantics). +/// +/// Enqueued directories wait for a deferral period before deletion, giving +/// file handles time to close. Call MarkReady() with the ActionLsn to shorten +/// the wait to a brief grace period (e.g. once a client has collected results). +/// On shutdown, all pending directories are deleted immediately. +class DeferredDirectoryDeleter +{ + DeferredDirectoryDeleter(const DeferredDirectoryDeleter&) = delete; + DeferredDirectoryDeleter& operator=(const DeferredDirectoryDeleter&) = delete; + +public: + DeferredDirectoryDeleter(); + ~DeferredDirectoryDeleter(); + + /// Enqueue a directory for deferred deletion, associated with an action LSN. + void Enqueue(int ActionLsn, std::filesystem::path Path); + + /// Signal that the action result has been consumed and the directory + /// can be deleted after a short grace period instead of the full deferral. + void MarkReady(int ActionLsn); + + /// Drain the queue and join the background thread. 
Idempotent. + void Shutdown(); + +private: + struct QueueEntry + { + int ActionLsn; + std::filesystem::path Path; + }; + + std::mutex m_Mutex; + std::condition_variable m_Cv; + std::deque<QueueEntry> m_Queue; + std::vector<int> m_ReadyLsns; + bool m_Done = false; + std::thread m_Thread; + void ThreadFunction(); +}; + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS +namespace zen::compute { +void deferreddeleter_forcelink(); // internal +} // namespace zen::compute +#endif diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp new file mode 100644 index 000000000..768cdf1e1 --- /dev/null +++ b/src/zencompute/runners/functionrunner.cpp @@ -0,0 +1,365 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/filesystem.h> +# include <zencore/trace.h> + +# include <fmt/format.h> +# include <vector> + +namespace zen::compute { + +FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") +{ +} + +FunctionRunner::~FunctionRunner() = default; + +size_t +FunctionRunner::QueryCapacity() +{ + return 1; +} + +std::vector<SubmitResult> +FunctionRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +void +FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) +{ + if (m_DumpActions) + { + std::string UniqueId = fmt::format("{}.ddb", ActionLsn); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); + } +} + +////////////////////////////////////////////////////////////////////////// + +void +BaseRunnerGroup::AddRunnerInternal(FunctionRunner* Runner) +{ 
+ m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); +} + +size_t +BaseRunnerGroup::QueryCapacity() +{ + size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + for (const auto& Runner : m_Runners) + { + TotalCapacity += Runner->QueryCapacity(); + } + }); + return TotalCapacity; +} + +SubmitResult +BaseRunnerGroup::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitAction"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int InitialIndex = m_NextSubmitIndex.load(std::memory_order_acquire); + int Index = InitialIndex; + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return {.IsAccepted = false, .Reason = "No runners available"}; + } + + do + { + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + + auto& Runner = m_Runners[Index++]; + + SubmitResult Result = Runner->SubmitAction(Action); + + if (Result.IsAccepted == true) + { + m_NextSubmitIndex = Index % RunnerCount; + + return Result; + } + + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + } while (Index != InitialIndex); + + return {.IsAccepted = false}; +} + +std::vector<SubmitResult> +BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + + if (RunnerCount == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); + } + + // Query capacity per runner and compute total + std::vector<size_t> Capacities(RunnerCount); + size_t TotalCapacity = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = m_Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + + if (TotalCapacity == 0) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); + } + 
+ // Distribute actions across runners proportionally to their available capacity + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + if (Capacities[i] == 0) + { + continue; + } + + size_t Share = (Actions.size() * Capacities[i] + TotalCapacity - 1) / TotalCapacity; + Share = std::min(Share, Capacities[i]); + + for (size_t j = 0; j < Share && ActionIdx < Actions.size(); ++j, ++ActionIdx) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + } + } + + // Assign any remaining actions to runners with capacity (round-robin) + for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + { + if (Capacities[i] > PerRunnerActions[i].size()) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + ++ActionIdx; + } + } + + // Submit batches per runner + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + } + } + + // Reassemble results in original action order + std::vector<SubmitResult> Results(Actions.size()); + std::vector<size_t> PerRunnerIdx(RunnerCount, 0); + + for (size_t i = 0; i < Actions.size(); ++i) + { + size_t RunnerIdx = ActionRunnerIndex[i]; + size_t Idx = PerRunnerIdx[RunnerIdx]++; + Results[i] = std::move(PerRunnerResults[RunnerIdx][Idx]); + } + + return Results; +} + +size_t +BaseRunnerGroup::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + size_t TotalCount = 0; + + for (const auto& Runner : m_Runners) + { + TotalCount += Runner->GetSubmittedActionCount(); + } + + return TotalCount; +} + +void +BaseRunnerGroup::RegisterWorker(CbPackage Worker) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner 
: m_Runners) + { + Runner->RegisterWorker(Worker); + } +} + +void +BaseRunnerGroup::Shutdown() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->Shutdown(); + } +} + +bool +BaseRunnerGroup::CancelAction(int ActionLsn) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + if (Runner->CancelAction(ActionLsn)) + { + return true; + } + } + + return false; +} + +void +BaseRunnerGroup::CancelRemoteQueue(int QueueId) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->CancelRemoteQueue(QueueId); + } +} + +////////////////////////////////////////////////////////////////////////// + +RunnerAction::RunnerAction(ComputeServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) +{ + this->Timestamps[static_cast<int>(State::New)] = DateTime::Now().GetTicks(); +} + +RunnerAction::~RunnerAction() +{ +} + +bool +RunnerAction::ResetActionStateToPending() +{ + // Only allow reset from Failed or Abandoned states + State CurrentState = m_ActionState.load(); + + if (CurrentState != State::Failed && CurrentState != State::Abandoned) + { + return false; + } + + if (!m_ActionState.compare_exchange_strong(CurrentState, State::Pending)) + { + return false; + } + + // Clear timestamps from Submitting through _Count + for (int i = static_cast<int>(State::Submitting); i < static_cast<int>(State::_Count); ++i) + { + this->Timestamps[i] = 0; + } + + // Record new Pending timestamp + this->Timestamps[static_cast<int>(State::Pending)] = DateTime::Now().GetTicks(); + + // Clear execution fields + ExecutionLocation.clear(); + CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); + CpuSeconds.store(0.0f, std::memory_order_relaxed); + + // Increment retry count + RetryCount.fetch_add(1, std::memory_order_relaxed); + + // Re-enter the scheduler pipeline + m_OwnerSession->PostUpdate(this); + + return true; +} + +void +RunnerAction::SetActionState(State NewState) +{ + 
ZEN_ASSERT(NewState < State::_Count); + this->Timestamps[static_cast<int>(NewState)] = DateTime::Now().GetTicks(); + + do + { + if (State CurrentState = m_ActionState.load(); CurrentState == NewState) + { + // No state change + return; + } + else + { + if (NewState <= CurrentState) + { + // Cannot transition to an earlier or same state + return; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) + { + // Successful state change + + m_OwnerSession->PostUpdate(this); + + return; + } + } + } while (true); +} + +void +RunnerAction::SetResult(CbPackage&& Result) +{ + m_Result = std::move(Result); +} + +CbPackage& +RunnerAction::GetResult() +{ + ZEN_ASSERT(IsCompleted()); + return m_Result; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h new file mode 100644 index 000000000..f67414dbb --- /dev/null +++ b/src/zencompute/runners/functionrunner.h @@ -0,0 +1,214 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencompute/computeservice.h> + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <atomic> +# include <filesystem> +# include <vector> + +namespace zen::compute { + +struct SubmitResult +{ + bool IsAccepted = false; + std::string Reason; +}; + +/** Base interface for classes implementing a remote execution "runner" + */ +class FunctionRunner : public RefCounted +{ + FunctionRunner(FunctionRunner&&) = delete; + FunctionRunner& operator=(FunctionRunner&&) = delete; + +public: + FunctionRunner(std::filesystem::path BasePath); + virtual ~FunctionRunner() = 0; + + virtual void Shutdown() = 0; + virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + + [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; + [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; + [[nodiscard]] virtual bool IsHealthy() = 0; + [[nodiscard]] virtual size_t QueryCapacity(); + [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + + // Best-effort cancellation of a specific in-flight action. Returns true if the + // cancellation signal was successfully sent. The action will transition to Cancelled + // asynchronously once the platform-level termination completes. + virtual bool CancelAction(int /*ActionLsn*/) { return false; } + + // Cancel the remote queue corresponding to the given local QueueId. + // Only meaningful for remote runners; local runners ignore this. 
+ virtual void CancelRemoteQueue(int /*QueueId*/) {} + +protected: + std::filesystem::path m_ActionsPath; + bool m_DumpActions = false; + void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); +}; + +/** Base class for RunnerGroup that operates on generic FunctionRunner references. + * All scheduling, capacity, and lifecycle logic lives here. + */ +class BaseRunnerGroup +{ +public: + size_t QueryCapacity(); + SubmitResult SubmitAction(Ref<RunnerAction> Action); + std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + size_t GetSubmittedActionCount(); + void RegisterWorker(CbPackage Worker); + void Shutdown(); + bool CancelAction(int ActionLsn); + void CancelRemoteQueue(int QueueId); + + size_t GetRunnerCount() + { + return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); + } + +protected: + void AddRunnerInternal(FunctionRunner* Runner); + + RwLock m_RunnersLock; + std::vector<Ref<FunctionRunner>> m_Runners; + std::atomic<int> m_NextSubmitIndex{0}; +}; + +/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. + */ +template<typename RunnerType> +struct RunnerGroup : public BaseRunnerGroup +{ + void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); } + + template<typename Predicate> + size_t RemoveRunnerIf(Predicate&& Pred) + { + size_t RemovedCount = 0; + m_RunnersLock.WithExclusiveLock([&] { + auto It = m_Runners.begin(); + while (It != m_Runners.end()) + { + if (Pred(static_cast<RunnerType&>(**It))) + { + (*It)->Shutdown(); + It = m_Runners.erase(It); + ++RemovedCount; + } + else + { + ++It; + } + } + }); + return RemovedCount; + } +}; + +/** + * This represents an action going through different stages of scheduling and execution. 
+ */ +struct RunnerAction : public RefCounted +{ + explicit RunnerAction(ComputeServiceSession* OwnerSession); + ~RunnerAction(); + + int ActionLsn = 0; + int QueueId = 0; + WorkerDesc Worker; + IoHash ActionId; + CbObject ActionObj; + int Priority = 0; + std::string ExecutionLocation; // "local" or remote hostname + + // CPU usage and total CPU time of the running process, sampled periodically by the local runner. + // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. + // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled. + std::atomic<float> CpuUsagePercent{-1.0f}; + std::atomic<float> CpuSeconds{0.0f}; + std::atomic<int> RetryCount{0}; + + enum class State + { + New, + Pending, + Submitting, + Running, + Completed, + Failed, + Abandoned, + Cancelled, + _Count + }; + + static const char* ToString(State _) + { + switch (_) + { + case State::New: + return "New"; + case State::Pending: + return "Pending"; + case State::Submitting: + return "Submitting"; + case State::Running: + return "Running"; + case State::Completed: + return "Completed"; + case State::Failed: + return "Failed"; + case State::Abandoned: + return "Abandoned"; + case State::Cancelled: + return "Cancelled"; + default: + return "Unknown"; + } + } + + static State FromString(std::string_view Name, State Default = State::Failed) + { + for (int i = 0; i < static_cast<int>(State::_Count); ++i) + { + if (Name == ToString(static_cast<State>(i))) + { + return static_cast<State>(i); + } + } + return Default; + } + + uint64_t Timestamps[static_cast<int>(State::_Count)] = {}; + + State ActionState() const { return m_ActionState; } + void SetActionState(State NewState); + + bool IsSuccess() const { return ActionState() == State::Completed; } + bool ResetActionStateToPending(); + bool IsCompleted() const + { + return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == 
State::Abandoned || + ActionState() == State::Cancelled; + } + + void SetResult(CbPackage&& Result); + CbPackage& GetResult(); + + ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; } + +private: + std::atomic<State> m_ActionState = State::New; + ComputeServiceSession* m_OwnerSession = nullptr; + CbPackage m_Result; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp new file mode 100644 index 000000000..e79a6c90f --- /dev/null +++ b/src/zencompute/runners/linuxrunner.cpp @@ -0,0 +1,734 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "linuxrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <sched.h> +# include <signal.h> +# include <sys/mount.h> +# include <sys/stat.h> +# include <sys/syscall.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. 
// Write Len bytes from Buf to Fd, retrying on short writes. Gives up silently
// on error or EOF (write() returning <= 0) — callers treat this as best-effort.
// Uses only the raw write() syscall, so it is async-signal-safe.
void WriteToFd(int Fd, const char* Buf, size_t Len)
{
    while (Len > 0)
    {
        ssize_t Written = write(Fd, Buf, Len);
        if (Written <= 0)
        {
            break;
        }
        Buf += Written;
        Len -= static_cast<size_t>(Written);
    }
}

// Report a fatal child-side error over the error pipe and terminate the child
// with exit code 127 (the conventional "could not exec" status). The parent
// reads this pipe after fork() to distinguish setup failures from a successful
// execve (which closes the pipe via O_CLOEXEC).
//
// The message length is computed with a manual loop instead of strlen() to
// stay syscall-only.
//
// NOTE(review): strerror() is NOT on the POSIX async-signal-safe function
// list and may allocate/touch locale state on first use — calling it between
// fork() and execve() in a multithreaded parent can deadlock. Consider
// formatting the raw errno value manually instead. Also: strerror requires
// <cstring>, which is not visibly included in this file — presumably pulled
// in transitively; TODO confirm.
[[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno)
{
    // Write the message prefix
    size_t MsgLen = 0;
    for (const char* P = Msg; *P; ++P)
    {
        ++MsgLen;
    }
    WriteToFd(ErrorPipeFd, Msg, MsgLen);

    // Append ": " and the errno string if non-zero
    if (Errno != 0)
    {
        WriteToFd(ErrorPipeFd, ": ", 2);
        const char* ErrStr = strerror(Errno);
        size_t ErrLen = 0;
        for (const char* P = ErrStr; *P; ++P)
        {
            ++ErrLen;
        }
        WriteToFd(ErrorPipeFd, ErrStr, ErrLen);
    }

    _exit(127);
}

// mkdir() that tolerates the directory already existing.
// Returns 0 on success or EEXIST, -1 on any other failure (errno preserved).
int MkdirIfNeeded(const char* Path, mode_t Mode)
{
    if (mkdir(Path, Mode) != 0 && errno != EEXIST)
    {
        return -1;
    }
    return 0;
}

// Recursively bind-mount Src onto Dst, then remount the bind read-only.
// The two-step dance is required because MS_RDONLY is ignored on the initial
// MS_BIND mount; only a MS_REMOUNT | MS_BIND can change the flags.
// Returns 0 on success, -1 on failure (errno preserved).
int BindMountReadOnly(const char* Src, const char* Dst)
{
    if (mount(Src, Dst, nullptr, MS_BIND | MS_REC, nullptr) != 0)
    {
        return -1;
    }

    // Remount read-only
    if (mount(nullptr, Dst, nullptr, MS_REMOUNT | MS_BIND | MS_RDONLY | MS_REC, nullptr) != 0)
    {
        return -1;
    }

    return 0;
}

// Set up namespace-based sandbox isolation in the child process.
// This is called after fork(), before execve(). All operations must be
// async-signal-safe.
//
// NOTE(review): the snprintf() calls below are not guaranteed async-signal-
// safe by POSIX (glibc's implementation is safe in practice for these simple
// format strings) — worth confirming against the project's supported libc.
// snprintf also needs <cstdio>, which is not visibly included here.
//
// The sandbox layout after pivot_root:
//   /       -> the sandbox directory (tmpfs-like, was SandboxPath)
//   /usr    -> bind-mount of host /usr (read-only)
//   /lib    -> bind-mount of host /lib (read-only)
//   /lib64  -> bind-mount of host /lib64 (read-only, optional)
//   /etc    -> bind-mount of host /etc (read-only)
//   /worker -> bind-mount of worker directory (read-only)
//   /proc   -> proc filesystem
//   /dev    -> tmpfs with null, zero, urandom
//
// Any failure is reported over ErrorPipeFd and terminates the child (127).
void SetupNamespaceSandbox(const char* SandboxPath, uid_t Uid, gid_t Gid, const char* WorkerPath, int ErrorPipeFd)
{
    // 1. Unshare user, mount, and network namespaces.
    //    CLONE_NEWUSER first grants us "root" inside the new user namespace,
    //    which is what makes the subsequent mount operations possible without
    //    real privileges. CLONE_NEWNET leaves the child with no network access.
    if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "unshare() failed", errno);
    }

    // 2. Write UID/GID mappings
    //    Must deny setgroups first (required by the kernel before gid_map can
    //    be written from an unprivileged user namespace)
    {
        int Fd = open("/proc/self/setgroups", O_WRONLY);
        if (Fd >= 0)
        {
            WriteToFd(Fd, "deny", 4);
            close(Fd);
        }
        // setgroups file may not exist on older kernels; not fatal
    }

    {
        // uid_map: map our UID to 0 inside the namespace
        char Buf[64];
        int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Uid));

        int Fd = open("/proc/self/uid_map", O_WRONLY);
        if (Fd < 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "open uid_map failed", errno);
        }
        WriteToFd(Fd, Buf, static_cast<size_t>(Len));
        close(Fd);
    }

    {
        // gid_map: map our GID to 0 inside the namespace
        char Buf[64];
        int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast<unsigned>(Gid));

        int Fd = open("/proc/self/gid_map", O_WRONLY);
        if (Fd < 0)
        {
            WriteErrorAndExit(ErrorPipeFd, "open gid_map failed", errno);
        }
        WriteToFd(Fd, Buf, static_cast<size_t>(Len));
        close(Fd);
    }

    // 3. Privatize the entire mount tree so our mounts don't propagate
    //    back to the host mount namespace.
    if (mount(nullptr, "/", nullptr, MS_REC | MS_PRIVATE, nullptr) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mount MS_PRIVATE failed", errno);
    }

    // 4. Create mount points inside the sandbox and bind-mount system directories

    // Helper macro-like pattern for building paths inside sandbox.
    // We use stack buffers since we can't allocate heap memory safely.
    // Note the lambda reuses a single buffer, so each BuildPath() result is
    // only valid until the next call.
    char MountPoint[4096];

    auto BuildPath = [&](const char* Suffix) -> const char* {
        snprintf(MountPoint, sizeof(MountPoint), "%s/%s", SandboxPath, Suffix);
        return MountPoint;
    };

    // /usr (required)
    if (MkdirIfNeeded(BuildPath("usr"), 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/usr failed", errno);
    }
    if (BindMountReadOnly("/usr", BuildPath("usr")) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "bind mount /usr failed", errno);
    }

    // /lib (required)
    if (MkdirIfNeeded(BuildPath("lib"), 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/lib failed", errno);
    }
    if (BindMountReadOnly("/lib", BuildPath("lib")) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno);
    }

    // /lib64 (optional — not all distros have it)
    {
        struct stat St;
        if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode))
        {
            if (MkdirIfNeeded(BuildPath("lib64"), 0755) == 0)
            {
                BindMountReadOnly("/lib64", BuildPath("lib64"));
                // Failure is non-fatal for lib64
            }
        }
    }

    // /etc (required — for resolv.conf, ld.so.cache, etc.)
    if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno);
    }
    if (BindMountReadOnly("/etc", BuildPath("etc")) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno);
    }

    // /worker — bind-mount worker directory (contains the executable)
    if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno);
    }
    if (BindMountReadOnly(WorkerPath, BuildPath("worker")) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "bind mount worker dir failed", errno);
    }

    // 5. Mount /proc inside sandbox
    if (MkdirIfNeeded(BuildPath("proc"), 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/proc failed", errno);
    }
    if (mount("proc", BuildPath("proc"), "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, nullptr) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mount /proc failed", errno);
    }

    // 6. Mount tmpfs /dev and bind-mount essential device nodes
    if (MkdirIfNeeded(BuildPath("dev"), 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/dev failed", errno);
    }
    if (mount("tmpfs", BuildPath("dev"), "tmpfs", MS_NOSUID | MS_NOEXEC, "size=64k,mode=0755") != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mount tmpfs /dev failed", errno);
    }

    // Bind-mount /dev/null, /dev/zero, /dev/urandom
    {
        char DevSrc[64];
        char DevDst[4096];

        auto BindDev = [&](const char* Name) {
            snprintf(DevSrc, sizeof(DevSrc), "/dev/%s", Name);
            snprintf(DevDst, sizeof(DevDst), "%s/dev/%s", SandboxPath, Name);

            // Create the file to mount over (bind mounts need an existing
            // target of the same type; device nodes mount over plain files)
            int Fd = open(DevDst, O_WRONLY | O_CREAT, 0666);
            if (Fd >= 0)
            {
                close(Fd);
            }
            mount(DevSrc, DevDst, nullptr, MS_BIND, nullptr);
            // Non-fatal if individual devices fail
        };

        BindDev("null");
        BindDev("zero");
        BindDev("urandom");
    }

    // 7. pivot_root to sandbox
    //    pivot_root requires the new root and put_old to be mount points.
    //    Bind-mount sandbox onto itself to make it a mount point.
    if (mount(SandboxPath, SandboxPath, nullptr, MS_BIND | MS_REC, nullptr) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "bind mount sandbox onto itself failed", errno);
    }

    // Create .pivot_old inside sandbox — pivot_root parks the old root here
    char PivotOld[4096];
    snprintf(PivotOld, sizeof(PivotOld), "%s/.pivot_old", SandboxPath);
    if (MkdirIfNeeded(PivotOld, 0755) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "mkdir .pivot_old failed", errno);
    }

    // Raw syscall: glibc provides no pivot_root() wrapper
    if (syscall(SYS_pivot_root, SandboxPath, PivotOld) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "pivot_root failed", errno);
    }

    // 8. Now inside new root. Clean up old root.
    if (chdir("/") != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "chdir / failed", errno);
    }

    // MNT_DETACH lazily unmounts the old root so the host filesystem is
    // unreachable from inside the sandbox
    if (umount2("/.pivot_old", MNT_DETACH) != 0)
    {
        WriteErrorAndExit(ErrorPipeFd, "umount2 .pivot_old failed", errno);
    }

    rmdir("/.pivot_old");
}

} // anonymous namespace

// Constructs the Linux runner. Sandboxed toggles namespace isolation for
// child processes; MaxConcurrentActions is forwarded to the base class
// (0 = derive from core count, see LocalProcessRunner ctor).
LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver,
                                       const std::filesystem::path& BaseDir,
                                       DeferredDirectoryDeleter& Deleter,
                                       WorkerThreadPool& WorkerPool,
                                       bool Sandboxed,
                                       int32_t MaxConcurrentActions)
: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions)
, m_Sandboxed(Sandboxed)
{
    // Restore SIGCHLD to default behavior so waitpid() can properly collect
    // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which
    // causes the kernel to auto-reap children, making waitpid() return
    // -1/ECHILD instead of the exit status we need.
    struct sigaction Action = {};
    sigemptyset(&Action.sa_mask);
    Action.sa_handler = SIG_DFL;
    sigaction(SIGCHLD, &Action, nullptr);

    if (m_Sandboxed)
    {
        ZEN_INFO("namespace sandboxing enabled for child processes");
    }
}

// Prepare and launch one action as a child process.
//
// Returns IsAccepted=false if the runner is at capacity / shutting down
// (PrepareActionSubmission refuses) or if sandbox setup failed in the child.
// Throws zen::runtime_error on pipe2()/fork() failure.
//
// All heap-allocating work (path strings, argv/envp arrays) is done BEFORE
// fork() so the child only touches pre-built buffers — see the async-signal-
// safety notes on the helpers above.
SubmitResult
LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
{
    ZEN_TRACE_CPU("LinuxProcessRunner::SubmitAction");
    std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action);

    if (!Prepared)
    {
        return SubmitResult{.IsAccepted = false};
    }

    // Build environment array from worker descriptor

    CbObject WorkerDescription = Prepared->WorkerPackage.GetObject();

    std::vector<std::string> EnvStrings;
    for (auto& It : WorkerDescription["environment"sv])
    {
        EnvStrings.emplace_back(It.AsString());
    }

    // Envp points into EnvStrings, which must outlive the execve() call
    std::vector<char*> Envp;
    Envp.reserve(EnvStrings.size() + 1);
    for (auto& Str : EnvStrings)
    {
        Envp.push_back(Str.data());
    }
    Envp.push_back(nullptr);

    // Build argv: <worker_exe_path> -Build=build.action
    // Pre-compute all path strings before fork() for async-signal-safety.

    std::string_view ExecPath = WorkerDescription["path"sv].AsString();
    std::string      ExePathStr;
    std::string      SandboxedExePathStr;

    if (m_Sandboxed)
    {
        // After pivot_root, the worker dir is at /worker inside the new root
        std::filesystem::path SandboxedExePath = std::filesystem::path("/worker") / std::filesystem::path(ExecPath);
        SandboxedExePathStr = SandboxedExePath.string();
        // We still need the real path for logging
        ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string();
    }
    else
    {
        ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string();
    }

    std::string BuildArg = "-Build=build.action";

    // argv[0] should be the path the child will see
    const std::string& ChildExePath = m_Sandboxed ? SandboxedExePathStr : ExePathStr;

    std::vector<char*> ArgV;
    ArgV.push_back(const_cast<char*>(ChildExePath.data()));
    ArgV.push_back(BuildArg.data());
    ArgV.push_back(nullptr);

    ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed);

    std::string SandboxPathStr = Prepared->SandboxPath.string();
    std::string WorkerPathStr  = Prepared->WorkerPath.string();

    // Pre-fork: get uid/gid for namespace mapping, create error pipe.
    // O_CLOEXEC on the pipe is what lets the parent detect a successful
    // execve: the write end closes automatically and read() returns 0.
    uid_t CurrentUid   = 0;
    gid_t CurrentGid   = 0;
    int   ErrorPipe[2] = {-1, -1};

    if (m_Sandboxed)
    {
        CurrentUid = getuid();
        CurrentGid = getgid();

        if (pipe2(ErrorPipe, O_CLOEXEC) != 0)
        {
            throw zen::runtime_error("pipe2() for sandbox error pipe failed: {}", strerror(errno));
        }
    }

    pid_t ChildPid = fork();

    if (ChildPid < 0)
    {
        int SavedErrno = errno;
        if (m_Sandboxed)
        {
            close(ErrorPipe[0]);
            close(ErrorPipe[1]);
        }
        throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno));
    }

    if (ChildPid == 0)
    {
        // Child process

        if (m_Sandboxed)
        {
            // Close read end of error pipe — child only writes
            close(ErrorPipe[0]);

            SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), ErrorPipe[1]);

            // After pivot_root, CWD is "/" which is the sandbox root.
            // execve with the sandboxed path.
            execve(SandboxedExePathStr.c_str(), ArgV.data(), Envp.data());

            // Only reached if execve failed
            WriteErrorAndExit(ErrorPipe[1], "execve failed", errno);
        }
        else
        {
            if (chdir(SandboxPathStr.c_str()) != 0)
            {
                _exit(127);
            }

            execve(ExePathStr.c_str(), ArgV.data(), Envp.data());
            _exit(127);
        }
    }

    // Parent process

    if (m_Sandboxed)
    {
        // Close write end of error pipe — parent only reads
        close(ErrorPipe[1]);

        // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC
        // and read returns 0. If setup failed, child wrote an error message.
        // Note: this blocks until the child execs or fails.
        char    ErrBuf[512];
        ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1);
        close(ErrorPipe[0]);

        if (BytesRead > 0)
        {
            // Sandbox setup or execve failed
            ErrBuf[BytesRead] = '\0';

            // Reap the child (it called _exit(127))
            waitpid(ChildPid, nullptr, 0);

            // Clean up the sandbox in the background
            m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));

            ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);

            Action->SetActionState(RunnerAction::State::Failed);
            return SubmitResult{.IsAccepted = false};
        }
    }

    // Store child pid as void* (same convention as zencore/process.cpp)

    Ref<RunningAction> NewAction{new RunningAction()};
    NewAction->Action        = Action;
    NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid));
    NewAction->SandboxPath   = std::move(Prepared->SandboxPath);

    {
        RwLock::ExclusiveLockScope _(m_RunningLock);
        m_RunningMap[Prepared->ActionLsn] = std::move(NewAction);
    }

    Action->SetActionState(RunnerAction::State::Running);

    return SubmitResult{.IsAccepted = true};
}

// Poll all running children with waitpid(WNOHANG), collect exit codes for
// those that finished, and hand them to the shared post-processing path.
// Signaled processes are reported as 128 + signal number (shell convention).
//
// NOTE(review): the exclusive lock is held across the waitpid() syscalls for
// the whole map — fine for small maps, but worth revisiting if the runner is
// scaled up.
void
LinuxProcessRunner::SweepRunningActions()
{
    ZEN_TRACE_CPU("LinuxProcessRunner::SweepRunningActions");
    std::vector<Ref<RunningAction>> CompletedActions;

    m_RunningLock.WithExclusiveLock([&] {
        for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
        {
            Ref<RunningAction> Running = It->second;

            pid_t Pid    = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));
            int   Status = 0;

            pid_t Result = waitpid(Pid, &Status, WNOHANG);

            if (Result == Pid)
            {
                if (WIFEXITED(Status))
                {
                    Running->ExitCode = WEXITSTATUS(Status);
                }
                else if (WIFSIGNALED(Status))
                {
                    // Shell convention: 128 + terminating signal
                    Running->ExitCode = 128 + WTERMSIG(Status);
                }
                else
                {
                    Running->ExitCode = 1;
                }

                Running->ProcessHandle = nullptr;

                CompletedActions.push_back(std::move(Running));
                It = m_RunningMap.erase(It);
            }
            else
            {
                ++It;
            }
        }
    });

    // Must be called without holding m_RunningLock (see base class)
    ProcessCompletedActions(CompletedActions);
}

// Cancel ALL running actions (shutdown path): SIGTERM everything, give each
// process up to 2 seconds to exit, then escalate to SIGKILL. Sandboxes are
// handed to the deferred deleter and every action is marked Failed.
//
// The whole map is swapped out under the lock so the sweep thread cannot
// race us for the same PIDs.
void
LinuxProcessRunner::CancelRunningActions()
{
    ZEN_TRACE_CPU("LinuxProcessRunner::CancelRunningActions");
    Stopwatch                                Timer;
    std::unordered_map<int, Ref<RunningAction>> RunningMap;

    m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });

    if (RunningMap.empty())
    {
        return;
    }

    ZEN_INFO("cancelling all running actions");

    // Send SIGTERM to all running processes first

    for (const auto& [Lsn, Running] : RunningMap)
    {
        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));

        if (kill(Pid, SIGTERM) != 0)
        {
            ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno));
        }
    }

    // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up.

    for (auto& [Lsn, Running] : RunningMap)
    {
        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle));

        // Poll for up to 2 seconds
        bool Exited = false;
        for (int i = 0; i < 20; ++i)
        {
            int   Status     = 0;
            pid_t WaitResult = waitpid(Pid, &Status, WNOHANG);
            if (WaitResult == Pid)
            {
                Exited = true;
                ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn);
                break;
            }
            usleep(100000); // 100ms
        }

        if (!Exited)
        {
            ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn);
            kill(Pid, SIGKILL);
            waitpid(Pid, nullptr, 0);
        }

        m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath));
        Running->Action->SetActionState(RunnerAction::State::Failed);
    }

    ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
}

// Request cancellation of a single action by LSN. Sends SIGTERM only; the
// monitor thread later observes the exit via SweepRunningActions.
// Returns true if the signal was sent, false if the action was not found,
// had no live process handle, or kill() failed.
bool
LinuxProcessRunner::CancelAction(int ActionLsn)
{
    ZEN_TRACE_CPU("LinuxProcessRunner::CancelAction");

    // Hold the shared lock while sending the signal to prevent the sweep thread
    // from reaping the PID (via waitpid) between our lookup and kill(). Without
    // the lock held, the PID could be recycled by the kernel and we'd signal an
    // unrelated process.
    bool Sent = false;

    m_RunningLock.WithSharedLock([&] {
        auto It = m_RunningMap.find(ActionLsn);
        if (It == m_RunningMap.end())
        {
            return;
        }

        Ref<RunningAction> Target = It->second;
        if (!Target->ProcessHandle)
        {
            return;
        }

        pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle));

        if (kill(Pid, SIGTERM) != 0)
        {
            ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno));
            return;
        }

        ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid);
        Sent = true;
    });

    // The monitor thread will pick up the process exit and mark the action as Failed.
    return Sent;
}

// Read cumulative CPU time (utime + stime, in clock ticks) for Pid from
// /proc/<pid>/stat. Returns 0 if the process is gone or the entry is
// unreadable/unparseable — callers treat 0 as "no sample".
//
// NOTE(review): the 256-byte buffer is usually enough to reach fields 14/15
// (utime/stime), but a long comm name plus large counters could truncate —
// consider 512 bytes.
static uint64_t
ReadProcStatCpuTicks(pid_t Pid)
{
    char Path[64];
    snprintf(Path, sizeof(Path), "/proc/%d/stat", static_cast<int>(Pid));

    char Buf[256];
    int  Fd = open(Path, O_RDONLY);
    if (Fd < 0)
    {
        return 0;
    }

    ssize_t Len = read(Fd, Buf, sizeof(Buf) - 1);
    close(Fd);

    if (Len <= 0)
    {
        return 0;
    }

    Buf[Len] = '\0';

    // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens
    const char* P = strrchr(Buf, ')');
    if (!P)
    {
        return 0;
    }

    P += 2; // skip ') '

    // Remaining fields (space-separated, 0-indexed from here):
    //   0:state 1:ppid 2:pgrp 3:session 4:tty_nr 5:tty_pgrp 6:flags
    //   7:minflt 8:cminflt 9:majflt 10:cmajflt 11:utime 12:stime
    unsigned long UTime = 0;
    unsigned long STime = 0;
    sscanf(P, "%*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &UTime, &STime);
    return UTime + STime;
}

// Sample CPU usage for one running action: publishes cumulative CPU seconds
// on every call, and a percentage once two samples exist to diff against.
// Stores are relaxed atomics — readers only need eventually-consistent stats.
//
// NOTE(review): ClkTck is used as a divisor without checking for sysconf()
// failure (-1) — worth guarding, though _SC_CLK_TCK failing is unlikely.
void
LinuxProcessRunner::SampleProcessCpu(RunningAction& Running)
{
    static const long ClkTck = sysconf(_SC_CLK_TCK);

    const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle));

    const uint64_t NowTicks       = GetHifreqTimerValue();
    const uint64_t CurrentOsTicks = ReadProcStatCpuTicks(Pid);

    if (CurrentOsTicks == 0)
    {
        // Process gone or /proc entry unreadable — record timestamp without updating usage
        Running.LastCpuSampleTicks = NowTicks;
        Running.LastCpuOsTicks     = 0;
        return;
    }

    // Cumulative CPU seconds (absolute, available from first sample)
    Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / ClkTck), std::memory_order_relaxed);

    if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0)
    {
        const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks);
        if (ElapsedMs > 0)
        {
            const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks;
            const float    CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) * 1000.0 / ClkTck / ElapsedMs * 100.0);
            Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed);
        }
    }

    Running.LastCpuSampleTicks = NowTicks;
    Running.LastCpuOsTicks     = CurrentOsTicks;
}

} // namespace zen::compute

#endif
diff --git a/src/zencompute/runners/linuxrunner.h b/src/zencompute/runners/linuxrunner.h new file mode 100644 index 000000000..266de366b --- /dev/null +++ b/src/zencompute/runners/linuxrunner.h @@ -0,0 +1,44 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include "localrunner.h"

#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX

namespace zen::compute {

/** Native Linux process runner for executing Linux worker executables directly.

    Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting,
    input/output handling, and monitor thread infrastructure. Overrides only the
    platform-specific methods: process spawning, sweep, and cancellation.

    When Sandboxed is true, child processes are isolated using Linux namespaces:
    user, mount, and network namespaces are unshared so the child has no network
    access and can only see the sandbox directory (with system libraries bind-mounted
    read-only). This requires no special privileges thanks to user namespaces.
+ */ +class LinuxProcessRunner : public LocalProcessRunner +{ +public: + LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 9a27f3f3d..7aaefb06e 100644 --- a/src/zencompute/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -8,7 +8,7 @@ # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> # include <zencore/compress.h> -# include <zencore/except.h> +# include <zencore/except_fmt.h> # include <zencore/filesystem.h> # include <zencore/fmtutils.h> # include <zencore/iobuffer.h> @@ -16,6 +16,7 @@ # include <zencore/system.h> # include <zencore/scopeguard.h> # include <zencore/timer.h> +# include <zencore/trace.h> # include <zenstore/cidstore.h> # include <span> @@ -24,17 +25,28 @@ namespace zen::compute { using namespace std::literals; -LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir) +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) : FunctionRunner(BaseDir) , m_Log(logging::Get("local_exec")) , m_ChunkResolver(Resolver) , m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) , m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +, m_DeferredDeleter(Deleter) +, m_WorkerPool(WorkerPool) { 
SystemMetrics Sm = GetSystemMetricsForReporting(); m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + if (MaxConcurrentActions > 0) + { + m_MaxRunningActions = MaxConcurrentActions; + } + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); bool DidCleanup = false; @@ -116,6 +128,7 @@ LocalProcessRunner::~LocalProcessRunner() void LocalProcessRunner::Shutdown() { + ZEN_TRACE_CPU("LocalProcessRunner::Shutdown"); m_AcceptNewActions = false; m_MonitorThreadEnabled = false; @@ -131,6 +144,7 @@ LocalProcessRunner::Shutdown() std::filesystem::path LocalProcessRunner::CreateNewSandbox() { + ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox"); std::string UniqueId = std::to_string(++m_SandboxCounter); std::filesystem::path Path = m_SandboxPath / UniqueId; zen::CreateDirectories(Path); @@ -141,6 +155,7 @@ LocalProcessRunner::CreateNewSandbox() void LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) { + ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); if (m_DumpActions) { CbObject WorkerDescriptor = WorkerPackage.GetObject(); @@ -172,32 +187,84 @@ LocalProcessRunner::QueryCapacity() return 0; } - size_t RunningCount = m_RunningMap.size(); + const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed); - if (RunningCount >= size_t(m_MaxRunningActions)) + if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions) { return 0; } - - return m_MaxRunningActions - RunningCount; + else + { + return MaxRunningActions - InFlightCount; + } } std::vector<SubmitResult> LocalProcessRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { - std::vector<SubmitResult> Results; + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; - for (const Ref<RunnerAction>& Action : Actions) + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For nontrivial batches, check capacity 
upfront and accept what fits. + // Accepted actions are transitioned to Submitting and dispatched to the + // worker pool as fire-and-forget, so SubmitActions returns immediately + // and the scheduler thread is free to handle completions and updates. + + size_t Available = QueryCapacity(); + + std::vector<SubmitResult> Results(Actions.size()); + + size_t AcceptCount = std::min(Available, Actions.size()); + + for (size_t i = 0; i < AcceptCount; ++i) { - Results.push_back(SubmitAction(Action)); + const Ref<RunnerAction>& Action = Actions[i]; + + Action->SetActionState(RunnerAction::State::Submitting); + m_SubmittingCount.fetch_add(1, std::memory_order_relaxed); + + Results[i] = SubmitResult{.IsAccepted = true}; + + m_WorkerPool.ScheduleWork( + [this, Action]() { + auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); }); + + SubmitResult Result = SubmitAction(Action); + + if (!Result.IsAccepted) + { + // This might require another state? We should + // distinguish between outright rejections (e.g. invalid action) + // and transient failures (e.g. 
failed to launch process) which might + // be retried by the scheduler, but for now just fail the action + Action->SetActionState(RunnerAction::State::Failed); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + for (size_t i = AcceptCount; i < Actions.size(); ++i) + { + Results[i] = SubmitResult{.IsAccepted = false}; } return Results; } -SubmitResult -LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +std::optional<LocalProcessRunner::PreparedAction> +LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) { + ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission"); + // Verify whether we can accept more work { @@ -205,29 +272,29 @@ LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (!m_AcceptNewActions) { - return SubmitResult{.IsAccepted = false}; + return std::nullopt; } if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) { - return SubmitResult{.IsAccepted = false}; + return std::nullopt; } } - using namespace std::literals; - // Each enqueued action is assigned an integer index (logical sequence number), // which we use as a key for tracking data structures and as an opaque id which // may be used by clients to reference the scheduled action const int32_t ActionLsn = Action->ActionLsn; const CbObject& ActionObj = Action->ActionObj; - const IoHash ActionId = ActionObj.GetHash(); MaybeDumpAction(ActionLsn, ActionObj); std::filesystem::path SandboxPath = CreateNewSandbox(); + // Ensure the sandbox directory is cleaned up if any subsequent step throws + auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); }); + CbPackage WorkerPackage = Action->Worker.Descriptor; std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); @@ -251,89 +318,24 @@ LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) zen::WriteFile(FilePath, DataBuffer); }); -# if ZEN_PLATFORM_WINDOWS - // Set up environment variables + Action->ExecutionLocation = "local"; - 
StringBuilder<1024> EnvironmentBlock; + SandboxGuard.Dismiss(); - CbObject WorkerDescription = WorkerPackage.GetObject(); - - for (auto& It : WorkerDescription["environment"sv]) - { - EnvironmentBlock.Append(It.AsString()); - EnvironmentBlock.Append('\0'); - } - EnvironmentBlock.Append('\0'); - EnvironmentBlock.Append('\0'); - - // Execute process - this spawns the child process immediately without waiting - // for completion - - std::string_view ExecPath = WorkerDescription["path"sv].AsString(); - std::filesystem::path ExePath = WorkerPath / std::filesystem::path(ExecPath).make_preferred(); - - ExtendableWideStringBuilder<512> CommandLine; - CommandLine.Append(L'"'); - CommandLine.Append(ExePath.c_str()); - CommandLine.Append(L'"'); - CommandLine.Append(L" -Build=build.action"); - - LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; - LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; - BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = 0; - - STARTUPINFO StartupInfo{}; - StartupInfo.cb = sizeof StartupInfo; - - PROCESS_INFORMATION ProcessInformation{}; - - ZEN_DEBUG("Executing: {}", WideToUtf8(CommandLine.c_str())); - - CommandLine.EnsureNulTerminated(); - - BOOL Success = CreateProcessW(nullptr, - CommandLine.Data(), - lpProcessAttributes, - lpThreadAttributes, - bInheritHandles, - dwCreationFlags, - (LPVOID)EnvironmentBlock.Data(), // Environment block - SandboxPath.c_str(), // Current directory - &StartupInfo, - /* out */ &ProcessInformation); - - if (!Success) - { - // TODO: this is probably not the best way to report failure. 
The return - // object should include a failure state and context - - zen::ThrowLastError("Unable to launch process" /* TODO: Add context */); - } - - CloseHandle(ProcessInformation.hThread); - - Ref<RunningAction> NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(SandboxPath); - - { - RwLock::ExclusiveLockScope _(m_RunningLock); - - m_RunningMap[ActionLsn] = std::move(NewAction); - } - - Action->SetActionState(RunnerAction::State::Running); -# else - ZEN_UNUSED(ActionId); - - ZEN_NOT_IMPLEMENTED(); - - int ExitCode = 0; -# endif + return PreparedAction{ + .ActionLsn = ActionLsn, + .SandboxPath = std::move(SandboxPath), + .WorkerPath = std::move(WorkerPath), + .WorkerPackage = std::move(WorkerPackage), + }; +} - return SubmitResult{.IsAccepted = true}; +SubmitResult +LocalProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + // Base class is not directly usable — platform subclasses override this + ZEN_UNUSED(Action); + return SubmitResult{.IsAccepted = false}; } size_t @@ -346,6 +348,7 @@ LocalProcessRunner::GetSubmittedActionCount() std::filesystem::path LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) { + ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker"); RwLock::SharedLockScope _(m_WorkerLock); std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); @@ -405,6 +408,23 @@ LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromP std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + // Validate the resolved path stays within the sandbox to prevent directory traversal + // via malicious names like "../../etc/evil" + // + // This might be worth revisiting to frontload the validation and eliminate some memory + // allocations in the future. 
+ { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath); + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath); + std::string RootStr = CanonicalRoot.string(); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath); + } + } + SharedBuffer Decompressed = Compressed.Decompress(); zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); } @@ -421,12 +441,34 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, for (auto& It : WorkerDescription["executables"sv]) { DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); +# if !ZEN_PLATFORM_WINDOWS + std::string_view ExeName = It.AsObjectView()["name"sv].AsString(); + std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()}; + std::filesystem::permissions( + ExePath, + std::filesystem::perms::owner_exec | std::filesystem::perms::group_exec | std::filesystem::perms::others_exec, + std::filesystem::perm_options::add); +# endif } for (auto& It : WorkerDescription["dirs"sv]) { std::string_view Name = It.AsString(); std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + + // Validate dir path stays within sandbox + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath); + std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath); + std::string RootStr = CanonicalRoot.string(); + std::string DirStr = CanonicalDir.string(); + + if (DirStr.size() < RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath); + } + } + zen::CreateDirectories(DirPath); } @@ -441,6 
+483,7 @@ LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, CbPackage LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) { + ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs"); std::filesystem::path OutputFile = SandboxPath / "build.output"; FileContents OutputData = zen::ReadFile(OutputFile); @@ -542,134 +585,53 @@ LocalProcessRunner::MonitorThreadFunction() } SweepRunningActions(); + SampleRunningProcessCpu(); } // Signal received SweepRunningActions(); + SampleRunningProcessCpu(); } while (m_MonitorThreadEnabled); } void LocalProcessRunner::CancelRunningActions() { - Stopwatch Timer; - std::unordered_map<int, Ref<RunningAction>> RunningMap; - - m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); - - if (RunningMap.empty()) - { - return; - } - - ZEN_INFO("cancelling all running actions"); - - // For expedience we initiate the process termination for all known - // processes before attempting to wait for them to exit. 
- - std::vector<int> TerminatedLsnList; - - for (const auto& Kv : RunningMap) - { - Ref<RunningAction> Action = Kv.second; - - // Terminate running process + // Base class is not directly usable — platform subclasses override this +} -# if ZEN_PLATFORM_WINDOWS - BOOL Success = TerminateProcess(Action->ProcessHandle, 222); +void +LocalProcessRunner::SampleRunningProcessCpu() +{ + static constexpr uint64_t kSampleIntervalMs = 5'000; - if (Success) - { - TerminatedLsnList.push_back(Kv.first); - } - else + m_RunningLock.WithSharedLock([&] { + const uint64_t Now = GetHifreqTimerValue(); + for (auto& [Lsn, Running] : m_RunningMap) { - DWORD LastError = GetLastError(); - - if (LastError != ERROR_ACCESS_DENIED) + const bool NeverSampled = Running->LastCpuSampleTicks == 0; + const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs; + if (NeverSampled || IntervalElapsed) { - ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Action->Action->ActionLsn, GetSystemErrorAsString(LastError)); + SampleProcessCpu(*Running); } } -# else - ZEN_NOT_IMPLEMENTED("need to implement process termination"); -# endif - } - - // We only post results for processes we have terminated, in order - // to avoid multiple results getting posted for the same action - - for (int Lsn : TerminatedLsnList) - { - if (auto It = RunningMap.find(Lsn); It != RunningMap.end()) - { - Ref<RunningAction> Running = It->second; - -# if ZEN_PLATFORM_WINDOWS - if (Running->ProcessHandle != INVALID_HANDLE_VALUE) - { - DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); - - if (WaitResult != WAIT_OBJECT_0) - { - ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); - } - else - { - ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); - } - } -# endif - - // Clean up and post error result - - DeleteDirectories(Running->SandboxPath); - 
Running->Action->SetActionState(RunnerAction::State::Failed); - } - } - - ZEN_INFO("DONE - cancelled {} running processes (took {})", TerminatedLsnList.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); } void LocalProcessRunner::SweepRunningActions() { - std::vector<Ref<RunningAction>> CompletedActions; - - m_RunningLock.WithExclusiveLock([&] { - // TODO: It would be good to not hold the exclusive lock while making - // system calls and other expensive operations. - - for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) - { - Ref<RunningAction> Action = It->second; - -# if ZEN_PLATFORM_WINDOWS - DWORD ExitCode = 0; - BOOL IsSuccess = GetExitCodeProcess(Action->ProcessHandle, &ExitCode); - - if (IsSuccess && ExitCode != STILL_ACTIVE) - { - CloseHandle(Action->ProcessHandle); - Action->ProcessHandle = INVALID_HANDLE_VALUE; - - CompletedActions.push_back(std::move(Action)); - It = m_RunningMap.erase(It); - } - else - { - ++It; - } -# else - // TODO: implement properly for Mac/Linux - - ZEN_UNUSED(Action); -# endif - } - }); + ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions"); +} - // Notify outer. Note that this has to be done without holding any local locks +void +LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions"); + // Shared post-processing: gather outputs, set state, clean sandbox. + // Note that this must be called without holding any local locks // otherwise we may end up with deadlocks. 
for (Ref<RunningAction> Running : CompletedActions) @@ -687,11 +649,9 @@ LocalProcessRunner::SweepRunningActions() Running->Action->SetResult(std::move(OutputPackage)); Running->Action->SetActionState(RunnerAction::State::Completed); - // We can delete the files at this point - if (!DeleteDirectories(Running->SandboxPath)) - { - ZEN_WARN("Unable to delete directory '{}', this will continue to exist until service restart", Running->SandboxPath); - } + // Enqueue sandbox for deferred background deletion, giving + // file handles time to close before we attempt removal. + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); // Success -- continue with next iteration of the loop continue; @@ -702,17 +662,9 @@ LocalProcessRunner::SweepRunningActions() } } - // Failed - for now this is indicated with an empty package in - // the results map. We can clean out the sandbox directory immediately. - - std::error_code Ec; - DeleteDirectories(Running->SandboxPath, Ec); - - if (Ec) - { - ZEN_WARN("Unable to delete sandbox directory '{}': {}", Running->SandboxPath, Ec.message()); - } + // Failed - clean up the sandbox in the background. 
+ m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); Running->Action->SetActionState(RunnerAction::State::Failed); } } diff --git a/src/zencompute/localrunner.h b/src/zencompute/runners/localrunner.h index 35f464805..7493e980b 100644 --- a/src/zencompute/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -2,7 +2,7 @@ #pragma once -#include "zencompute/functionservice.h" +#include "zencompute/computeservice.h" #if ZEN_WITH_COMPUTE_SERVICES @@ -14,8 +14,13 @@ # include <zencore/compactbinarypackage.h> # include <zencore/logging.h> +# include "deferreddeleter.h" + +# include <zencore/workthreadpool.h> + # include <atomic> # include <filesystem> +# include <optional> # include <thread> namespace zen { @@ -38,7 +43,11 @@ class LocalProcessRunner : public FunctionRunner LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; public: - LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir); + LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); ~LocalProcessRunner(); virtual void Shutdown() override; @@ -60,6 +69,10 @@ protected: void* ProcessHandle = nullptr; int ExitCode = 0; std::filesystem::path SandboxPath; + + // State for periodic CPU usage sampling + uint64_t LastCpuSampleTicks = 0; // hifreq timer value at last sample + uint64_t LastCpuOsTicks = 0; // OS CPU ticks (platform-specific units) at last sample }; std::atomic_bool m_AcceptNewActions; @@ -75,12 +88,37 @@ protected: RwLock m_RunningLock; std::unordered_map<int, Ref<RunningAction>> m_RunningMap; + std::atomic<int32_t> m_SubmittingCount = 0; + DeferredDirectoryDeleter& m_DeferredDeleter; + WorkerThreadPool& m_WorkerPool; + std::thread m_MonitorThread; std::atomic<bool> m_MonitorThreadEnabled{true}; Event m_MonitorThreadEvent; void MonitorThreadFunction(); - void SweepRunningActions(); - void 
CancelRunningActions(); + virtual void SweepRunningActions(); + virtual void CancelRunningActions(); + + // Sample CPU usage for all currently running processes (throttled per-action). + void SampleRunningProcessCpu(); + + // Override in platform runners to sample one process. Called under a shared RunningLock. + virtual void SampleProcessCpu(RunningAction& /*Running*/) {} + + // Shared preamble for SubmitAction: capacity check, sandbox creation, + // worker manifesting, action writing, input manifesting. + struct PreparedAction + { + int32_t ActionLsn; + std::filesystem::path SandboxPath; + std::filesystem::path WorkerPath; + CbPackage WorkerPackage; + }; + std::optional<PreparedAction> PrepareActionSubmission(Ref<RunnerAction> Action); + + // Shared post-processing for SweepRunningActions: gather outputs, + // set state, clean sandbox. + void ProcessCompletedActions(std::vector<Ref<RunningAction>>& CompletedActions); std::filesystem::path CreateNewSandbox(); void ManifestWorker(const CbPackage& WorkerPackage, diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp new file mode 100644 index 000000000..5cec90699 --- /dev/null +++ b/src/zencompute/runners/macrunner.cpp @@ -0,0 +1,491 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "macrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <fcntl.h> +# include <libproc.h> +# include <sandbox.h> +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). 
They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. + + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast<size_t>(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + // Build a Seatbelt profile string that denies everything by default and + // allows only the minimum needed for the worker to execute: process ops, + // system library reads, worker directory (read-only), and sandbox directory + // (read-write). Network access is denied implicitly by the deny-default policy. 
+ std::string BuildSandboxProfile(const std::string& SandboxPath, const std::string& WorkerPath) + { + std::string Profile; + Profile.reserve(1024); + + Profile += "(version 1)\n"; + Profile += "(deny default)\n"; + Profile += "(allow process*)\n"; + Profile += "(allow sysctl-read)\n"; + Profile += "(allow file-read-metadata)\n"; + + // System library paths needed for dynamic linker and runtime + Profile += "(allow file-read* (subpath \"/usr\"))\n"; + Profile += "(allow file-read* (subpath \"/System\"))\n"; + Profile += "(allow file-read* (subpath \"/Library\"))\n"; + Profile += "(allow file-read* (subpath \"/dev\"))\n"; + Profile += "(allow file-read* (subpath \"/private/var/db/dyld\"))\n"; + Profile += "(allow file-read* (subpath \"/etc\"))\n"; + + // Worker directory: read-only + Profile += "(allow file-read* (subpath \""; + Profile += WorkerPath; + Profile += "\"))\n"; + + // Sandbox directory: read+write + Profile += "(allow file-read* file-write* (subpath \""; + Profile += SandboxPath; + Profile += "\"))\n"; + + return Profile; + } + +} // anonymous namespace + +MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("Seatbelt sandboxing enabled for child processes"); + } +} + +SubmitResult +MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("MacProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: build sandbox profile and create error pipe + std::string SandboxProfile; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + SandboxProfile = BuildSandboxProfile(SandboxPathStr, WorkerPathStr); + + if (pipe(ErrorPipe) != 0) + { + throw zen::runtime_error("pipe() for sandbox error pipe failed: {}", strerror(errno)); + } + fcntl(ErrorPipe[0], F_SETFD, FD_CLOEXEC); + fcntl(ErrorPipe[1], 
F_SETFD, FD_CLOEXEC); + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + // Apply Seatbelt sandbox profile + char* ErrorBuf = nullptr; + if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) + { + // sandbox_init failed — write error to pipe and exit + if (ErrorBuf) + { + WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0); + // WriteErrorAndExit does not return, but sandbox_free_error + // is not needed since we _exit + } + WriteErrorAndExit(ErrorPipe[1], "sandbox_init failed", errno); + } + if (ErrorBuf) + { + sandbox_free_error(ErrorBuf); + } + + if (chdir(SandboxPathStr.c_str()) != 0) + { + WriteErrorAndExit(ErrorPipe[1], "chdir to sandbox failed", errno); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. 
+ char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +MacProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } 
+ }); + + ProcessCompletedActions(CompletedActions); +} + +void +MacProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +MacProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via 
waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +void +MacProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + const pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running.ProcessHandle)); + + struct proc_taskinfo Info; + if (proc_pidinfo(Pid, PROC_PIDTASKINFO, 0, &Info, sizeof(Info)) <= 0) + { + return; + } + + // pti_total_user and pti_total_system are in nanoseconds + const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // ns → ms: divide by 1,000,000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); + 
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.h b/src/zencompute/runners/macrunner.h new file mode 100644 index 000000000..d653b923a --- /dev/null +++ b/src/zencompute/runners/macrunner.h @@ -0,0 +1,43 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +namespace zen::compute { + +/** Native macOS process runner for executing Mac worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using macOS Seatbelt + (sandbox_init): no network access and no filesystem access outside the + explicitly allowed sandbox and worker directories. This requires no elevation. 
+ */ +class MacProcessRunner : public LocalProcessRunner +{ +public: + MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index 98ced5fe8..672636d06 100644 --- a/src/zencompute/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -14,6 +14,8 @@ # include <zencore/iobuffer.h> # include <zencore/iohash.h> # include <zencore/scopeguard.h> +# include <zencore/system.h> +# include <zencore/trace.h> # include <zenhttp/httpcommon.h> # include <zenstore/cidstore.h> @@ -27,12 +29,18 @@ using namespace std::literals; ////////////////////////////////////////////////////////////////////////// -RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName) +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool) : FunctionRunner(BaseDir) , m_Log(logging::Get("http_exec")) , m_ChunkResolver{InChunkResolver} -, m_BaseUrl{fmt::format("{}/apply", HostName)} +, m_WorkerPool{InWorkerPool} +, m_HostName{HostName} +, m_BaseUrl{fmt::format("{}/compute", HostName)} , m_Http(m_BaseUrl) +, m_InstanceId(Oid::NewOid()) { m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } @@ -58,6 +66,7 @@ RemoteHttpRunner::Shutdown() void RemoteHttpRunner::RegisterWorker(const 
CbPackage& WorkerPackage) { + ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); const IoHash WorkerId = WorkerPackage.GetObjectHash(); CbPackage WorkerDesc = WorkerPackage; @@ -168,11 +177,38 @@ RemoteHttpRunner::QueryCapacity() std::vector<SubmitResult> RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { - std::vector<SubmitResult> Results; + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + + if (Actions.size() <= 1) + { + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For larger batches, submit HTTP requests in parallel via the shared worker pool + + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); for (const Ref<RunnerAction>& Action : Actions) { - Results.push_back(SubmitAction(Action)); + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); } return Results; @@ -181,6 +217,8 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) SubmitResult RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) { + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitAction"); + // Verify whether we can accept more work { @@ -197,18 +235,53 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) // which we use as a key for tracking data structures and as an opaque id which // may be used by clients to reference the scheduled action + Action->ExecutionLocation = m_HostName; + const int32_t ActionLsn = Action->ActionLsn; const CbObject& ActionObj = Action->ActionObj; const IoHash ActionId = ActionObj.GetHash(); MaybeDumpAction(ActionLsn, ActionObj); - // Enqueue job + // Determine 
the submission URL. If the action belongs to a queue, ensure a + // corresponding remote queue exists on the target node and submit via it. + + std::string SubmitUrl = "/jobs"; + if (const int QueueId = Action->QueueId; QueueId != 0) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + if (Oid Token = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } - CbObject Result; + // Enqueue job. If the remote returns FailedDependency (424), it means it + // cannot resolve the worker/function — re-register the worker and retry once. - HttpClient::Response WorkResponse = m_Http.Post("/jobs", ActionObj); - HttpResponseCode WorkResponseCode = WorkResponse.StatusCode; + CbObject Result; + HttpClient::Response WorkResponse; + HttpResponseCode WorkResponseCode{}; + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); + + RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } if (WorkResponseCode == HttpResponseCode::OK) { @@ -250,11 +323,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) // Post resulting package - HttpClient::Response PayloadResponse = m_Http.Post("/jobs", Pkg); + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); if (!PayloadResponse) { - ZEN_WARN("unable to register payloads for action {} at {}/jobs", ActionId, m_Http.GetBaseUri()); + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); // TODO: include more information about the failure in the response @@ 
-270,17 +343,19 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - ZEN_WARN("unable to register payloads for action {} at {}/jobs (error: {} {})", + ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", ActionId, m_Http.GetBaseUri(), + SubmitUrl, ResponseStatusCode, ToString(ResponseStatusCode)); return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}/jobs", + .Reason = fmt::format("unexpected response code {} {} from {}{}", ResponseStatusCode, ToString(ResponseStatusCode), - m_Http.GetBaseUri())}; + m_Http.GetBaseUri(), + SubmitUrl)}; } } @@ -309,6 +384,82 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) return {}; } +Oid +RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) +{ + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + return It->second; + } + } + + // Build a stable idempotency key that uniquely identifies this (runner instance, local queue) + // pair. The server uses this to return the same remote queue token for concurrent or redundant + // requests, preventing orphaned remote queues when multiple threads race through here. + // Also send hostname so the server can associate the queue with its origin for diagnostics. 
+ CbObjectWriter Body; + Body << "idempotency_key"sv << fmt::format("{}/{}", m_InstanceId, QueueId); + Body << "hostname"sv << GetMachineName(); + if (Metadata) + { + Body << "metadata"sv << Metadata; + } + if (Config) + { + Body << "config"sv << Config; + } + + HttpClient::Response Resp = m_Http.Post("/queues/remote", Body.Save()); + if (!Resp) + { + ZEN_WARN("failed to create remote queue for local queue {} on {}", QueueId, m_HostName); + return Oid::Zero; + } + + Oid Token = Oid::TryFromHexString(Resp.AsObject()["queue_token"sv].AsString()); + if (Token == Oid::Zero) + { + return Oid::Zero; + } + + ZEN_DEBUG("created remote queue '{}' for local queue {} on {}", Token, QueueId, m_HostName); + + RwLock::ExclusiveLockScope _(m_QueueTokenLock); + auto [It, Inserted] = m_RemoteQueueTokens.try_emplace(QueueId, Token); + return It->second; +} + +void +RemoteHttpRunner::CancelRemoteQueue(int QueueId) +{ + Oid Token; + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + Token = It->second; + } + } + + if (Token == Oid::Zero) + { + return; + } + + HttpClient::Response Resp = m_Http.Delete(fmt::format("/queues/{}", Token)); + + if (Resp.StatusCode == HttpResponseCode::NoContent) + { + ZEN_DEBUG("cancelled remote queue '{}' (local queue {}) on {}", Token, QueueId, m_HostName); + } + else + { + ZEN_WARN("failed to cancel remote queue '{}' on {}: {}", Token, m_HostName, int(Resp.StatusCode)); + } +} + bool RemoteHttpRunner::IsHealthy() { @@ -337,7 +488,7 @@ RemoteHttpRunner::MonitorThreadFunction() do { - const int NormalWaitingTime = 1000; + const int NormalWaitingTime = 200; int WaitTimeMs = NormalWaitingTime; auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; auto SweepOnce = [&] { @@ -376,6 +527,7 @@ RemoteHttpRunner::MonitorThreadFunction() size_t RemoteHttpRunner::SweepRunningActions() { + ZEN_TRACE_CPU("RemoteHttpRunner::SweepRunningActions"); 
std::vector<HttpRunningAction> CompletedActions; // Poll remote for list of completed actions @@ -386,29 +538,38 @@ RemoteHttpRunner::SweepRunningActions() { for (auto& FieldIt : Completed["completed"sv]) { - const int32_t CompleteLsn = FieldIt.AsInt32(); + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); - if (HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn))) - { - m_RunningLock.WithExclusiveLock([&] { - if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) - { - HttpRunningAction CompletedAction = std::move(CompleteIt->second); - CompletedAction.ActionResults = ResponseJob.AsPackage(); - CompletedAction.Success = true; + RunnerAction::State RemoteState = RunnerAction::FromString(StateName); - CompletedActions.push_back(std::move(CompletedAction)); - m_RemoteRunningMap.erase(CompleteIt); - } - else + // Always fetch to drain the result from the remote's results map, + // but only keep the result package for successfully completed actions. 
+ HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn)); + + m_RunningLock.WithExclusiveLock([&] { + if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) + { + HttpRunningAction CompletedAction = std::move(CompleteIt->second); + CompletedAction.RemoteState = RemoteState; + + if (RemoteState == RunnerAction::State::Completed && ResponseJob) { - // we received a completion notice for an action we don't know about, - // this can happen if the runner is used by multiple upstream schedulers, - // or if this compute node was recently restarted and lost track of - // previously scheduled actions + CompletedAction.ActionResults = ResponseJob.AsPackage(); } - }); - } + + CompletedActions.push_back(std::move(CompletedAction)); + m_RemoteRunningMap.erase(CompleteIt); + } + else + { + // we received a completion notice for an action we don't know about, + // this can happen if the runner is used by multiple upstream schedulers, + // or if this compute node was recently restarted and lost track of + // previously scheduled actions + } + }); } if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) @@ -435,18 +596,18 @@ RemoteHttpRunner::SweepRunningActions() { const int ActionLsn = HttpAction.Action->ActionLsn; - if (HttpAction.Success) - { - ZEN_DEBUG("completed: {} LSN {} (remote LSN {})", HttpAction.Action->ActionId, ActionLsn, HttpAction.RemoteActionLsn); - - HttpAction.Action->SetActionState(RunnerAction::State::Completed); + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); - HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); - } - else + if (HttpAction.RemoteState == RunnerAction::State::Completed) { - HttpAction.Action->SetActionState(RunnerAction::State::Failed); + HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); } + + 
HttpAction.Action->SetActionState(HttpAction.RemoteState); } return CompletedActions.size(); diff --git a/src/zencompute/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index 1e885da3d..9119992a9 100644 --- a/src/zencompute/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -2,7 +2,7 @@ #pragma once -#include "zencompute/functionservice.h" +#include "zencompute/computeservice.h" #if ZEN_WITH_COMPUTE_SERVICES @@ -10,12 +10,15 @@ # include <zencore/compactbinarypackage.h> # include <zencore/logging.h> +# include <zencore/uid.h> +# include <zencore/workthreadpool.h> # include <zencore/zencore.h> # include <zenhttp/httpclient.h> # include <atomic> # include <filesystem> # include <thread> +# include <unordered_map> namespace zen { class CidStore; @@ -35,7 +38,10 @@ class RemoteHttpRunner : public FunctionRunner RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; public: - RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName); + RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool); ~RemoteHttpRunner(); virtual void Shutdown() override; @@ -45,24 +51,29 @@ public: [[nodiscard]] virtual size_t GetSubmittedActionCount() override; [[nodiscard]] virtual size_t QueryCapacity() override; [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; + virtual void CancelRemoteQueue(int QueueId) override; + + std::string_view GetHostName() const { return m_HostName; } protected: LoggerRef Log() { return m_Log; } private: - LoggerRef m_Log; - ChunkResolver& m_ChunkResolver; - std::string m_BaseUrl; - HttpClient m_Http; + LoggerRef m_Log; + ChunkResolver& m_ChunkResolver; + WorkerThreadPool& m_WorkerPool; + std::string m_HostName; + std::string m_BaseUrl; + HttpClient m_Http; int32_t m_MaxRunningActions = 256; // arbitrary limit 
for testing struct HttpRunningAction { - Ref<RunnerAction> Action; - int RemoteActionLsn = 0; // Remote LSN - bool Success = false; - CbPackage ActionResults; + Ref<RunnerAction> Action; + int RemoteActionLsn = 0; // Remote LSN + RunnerAction::State RemoteState = RunnerAction::State::Failed; + CbPackage ActionResults; }; RwLock m_RunningLock; @@ -73,6 +84,15 @@ private: Event m_MonitorThreadEvent; void MonitorThreadFunction(); size_t SweepRunningActions(); + + RwLock m_QueueTokenLock; + std::unordered_map<int, Oid> m_RemoteQueueTokens; // local QueueId → remote queue token + + // Stable identity for this runner instance, used as part of the idempotency key when + // creating remote queues. Generated once at construction and never changes. + Oid m_InstanceId; + + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); }; } // namespace zen::compute diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp new file mode 100644 index 000000000..e9a1ae8b6 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.cpp @@ -0,0 +1,460 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "windowsrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/except_fmt.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/scopeguard.h> +# include <zencore/trace.h> +# include <zencore/system.h> +# include <zencore/timer.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <userenv.h> +# include <aclapi.h> +# include <sddl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + if (!m_Sandboxed) + { + return; + } + + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); + + PSID Sid = nullptr; + + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); + + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast<uint32_t>(Hr)); + } + + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); +} + +WindowsProcessRunner::~WindowsProcessRunner() +{ + if (m_AppContainerSid) + { + FreeSid(m_AppContainerSid); + m_AppContainerSid = nullptr; + } + + if (!m_AppContainerName.empty()) + { + 
DeleteAppContainerProfile(m_AppContainerName.c_str()); + } +} + +void +WindowsProcessRunner::GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask) +{ + PACL ExistingDacl = nullptr; + PSECURITY_DESCRIPTOR SecurityDescriptor = nullptr; + + DWORD Result = GetNamedSecurityInfoW(Path.c_str(), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + &ExistingDacl, + nullptr, + &SecurityDescriptor); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("GetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $0 = MakeGuard([&] { LocalFree(SecurityDescriptor); }); + + EXPLICIT_ACCESSW Access{}; + Access.grfAccessPermissions = AccessMask; + Access.grfAccessMode = SET_ACCESS; + Access.grfInheritance = OBJECT_INHERIT_ACE | CONTAINER_INHERIT_ACE; + Access.Trustee.TrusteeForm = TRUSTEE_IS_SID; + Access.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP; + Access.Trustee.ptstrName = static_cast<LPWSTR>(m_AppContainerSid); + + PACL NewDacl = nullptr; + + Result = SetEntriesInAclW(1, &Access, ExistingDacl, &NewDacl); + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetEntriesInAclW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $1 = MakeGuard([&] { LocalFree(NewDacl); }); + + Result = SetNamedSecurityInfoW(const_cast<LPWSTR>(Path.c_str()), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + NewDacl, + nullptr); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } +} + +SubmitResult +WindowsProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Set up environment variables + + CbObject WorkerDescription = 
Prepared->WorkerPackage.GetObject(); + + StringBuilder<1024> EnvironmentBlock; + + for (auto& It : WorkerDescription["environment"sv]) + { + EnvironmentBlock.Append(It.AsString()); + EnvironmentBlock.Append('\0'); + } + EnvironmentBlock.Append('\0'); + EnvironmentBlock.Append('\0'); + + // Execute process - this spawns the child process immediately without waiting + // for completion + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + ExtendableWideStringBuilder<512> CommandLine; + CommandLine.Append(L'"'); + CommandLine.Append(ExePath.c_str()); + CommandLine.Append(L'"'); + CommandLine.Append(L" -Build=build.action"); + + LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; + LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; + BOOL bInheritHandles = FALSE; + DWORD dwCreationFlags = 0; + + ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + + CommandLine.EnsureNulTerminated(); + + PROCESS_INFORMATION ProcessInformation{}; + + if (m_Sandboxed) + { + // Grant AppContainer access to sandbox and worker directories + GrantAppContainerAccess(Prepared->SandboxPath, FILE_ALL_ACCESS); + GrantAppContainerAccess(Prepared->WorkerPath, FILE_GENERIC_READ | FILE_GENERIC_EXECUTE); + + // Set up extended startup info with AppContainer security capabilities + SECURITY_CAPABILITIES SecurityCapabilities{}; + SecurityCapabilities.AppContainerSid = m_AppContainerSid; + SecurityCapabilities.Capabilities = nullptr; + SecurityCapabilities.CapabilityCount = 0; + + SIZE_T AttrListSize = 0; + InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize); + + auto AttrList = static_cast<PPROC_THREAD_ATTRIBUTE_LIST>(malloc(AttrListSize)); + auto $0 = MakeGuard([&] { free(AttrList); }); + + if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize)) + { + zen::ThrowLastError("InitializeProcThreadAttributeList 
failed"); + } + + auto $1 = MakeGuard([&] { DeleteProcThreadAttributeList(AttrList); }); + + if (!UpdateProcThreadAttribute(AttrList, + 0, + PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES, + &SecurityCapabilities, + sizeof(SecurityCapabilities), + nullptr, + nullptr)) + { + zen::ThrowLastError("UpdateProcThreadAttribute (SECURITY_CAPABILITIES) failed"); + } + + STARTUPINFOEXW StartupInfoEx{}; + StartupInfoEx.StartupInfo.cb = sizeof(STARTUPINFOEXW); + StartupInfoEx.lpAttributeList = AttrList; + + dwCreationFlags |= EXTENDED_STARTUPINFO_PRESENT; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfoEx.StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch sandboxed process"); + } + } + else + { + STARTUPINFO StartupInfo{}; + StartupInfo.cb = sizeof StartupInfo; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch process"); + } + } + + CloseHandle(ProcessInformation.hThread); + + Ref<RunningAction> NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WindowsProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + 
m_RunningLock.WithExclusiveLock([&] {
+ for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;)
+ {
+ Ref<RunningAction> Running = It->second;
+
+ DWORD ExitCode = 0;
+ BOOL IsSuccess = GetExitCodeProcess(Running->ProcessHandle, &ExitCode);
+
+ if (IsSuccess && ExitCode != STILL_ACTIVE)
+ {
+ CloseHandle(Running->ProcessHandle);
+ Running->ProcessHandle = INVALID_HANDLE_VALUE;
+ Running->ExitCode = ExitCode;
+
+ CompletedActions.push_back(std::move(Running));
+ It = m_RunningMap.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+ });
+
+ ProcessCompletedActions(CompletedActions);
+}
+
+void
+WindowsProcessRunner::CancelRunningActions()
+{
+ ZEN_TRACE_CPU("WindowsProcessRunner::CancelRunningActions");
+ Stopwatch Timer;
+ std::unordered_map<int, Ref<RunningAction>> RunningMap;
+
+ m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); });
+
+ if (RunningMap.empty())
+ {
+ return;
+ }
+
+ ZEN_INFO("cancelling all running actions");
+
+ // For expedience we initiate the process termination for all known
+ // processes before attempting to wait for them to exit.
+
+ // NOTE: TerminateProcess can fail with ERROR_ACCESS_DENIED when the target is already exiting, so that case is not warned about below.
+
+ for (const auto& Kv : RunningMap)
+ {
+ Ref<RunningAction> Running = Kv.second;
+
+ BOOL TermSuccess = TerminateProcess(Running->ProcessHandle, 222);
+
+ if (!TermSuccess)
+ {
+ DWORD LastError = GetLastError();
+
+ if (LastError != ERROR_ACCESS_DENIED)
+ {
+ ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Running->Action->ActionLsn, GetSystemErrorAsString(LastError));
+ }
+ }
+ }
+
+ // Wait for all processes and clean up, regardless of whether TerminateProcess succeeded.
+ + for (auto& [Lsn, Running] : RunningMap) + { + if (Running->ProcessHandle != INVALID_HANDLE_VALUE) + { + DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); + + if (WaitResult != WAIT_OBJECT_0) + { + ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); + } + else + { + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + } + + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +WindowsProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelAction"); + + // Hold the shared lock while terminating to prevent the sweep thread from + // closing the handle between our lookup and TerminateProcess call. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref<RunningAction> Target = It->second; + if (Target->ProcessHandle == INVALID_HANDLE_VALUE) + { + return; + } + + BOOL TermSuccess = TerminateProcess(Target->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("CancelAction: TerminateProcess for LSN {} not successful: {}", ActionLsn, GetSystemErrorAsString(LastError)); + } + + return; + } + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. 
+ return Sent; +} + +void +WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + FILETIME CreationTime, ExitTime, KernelTime, UserTime; + if (!GetProcessTimes(Running.ProcessHandle, &CreationTime, &ExitTime, &KernelTime, &UserTime)) + { + return; + } + + auto FtToU64 = [](FILETIME Ft) -> uint64_t { return (static_cast<uint64_t>(Ft.dwHighDateTime) << 32) | Ft.dwLowDateTime; }; + + // FILETIME values are in 100-nanosecond intervals + const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime); + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds + Running.Action->CpuSeconds.store(static_cast<float>(static_cast<double>(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // 100ns → ms: divide by 10000; then as percent of elapsed ms + const float CpuPct = static_cast<float>(static_cast<double>(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h new file mode 100644 index 000000000..9f2385cc4 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.h @@ -0,0 +1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include <zencore/windows.h> + +# include <string> + +namespace zen::compute { + +/** Windows process runner using CreateProcessW for executing worker executables. 
+ + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using a Windows AppContainer: + no network access (AppContainer blocks network by default when no capabilities are + granted) and no filesystem access outside explicitly granted sandbox and worker + directories. This requires no elevation. + */ +class WindowsProcessRunner : public LocalProcessRunner +{ +public: + WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + ~WindowsProcessRunner(); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + void GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask); + + bool m_Sandboxed = false; + PSID m_AppContainerSid = nullptr; + std::wstring m_AppContainerName; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp new file mode 100644 index 000000000..506bec73b --- /dev/null +++ b/src/zencompute/runners/winerunner.cpp @@ -0,0 +1,237 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "winerunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/timer.h> +# include <zencore/trace.h> + +# include <signal.h> +# include <sys/wait.h> +# include <unistd.h> + +namespace zen::compute { + +using namespace std::literals; + +WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); +} + +SubmitResult +WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) +{ + ZEN_TRACE_CPU("WineProcessRunner::SubmitAction"); + std::optional<PreparedAction> Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector<std::string> EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector<char*> Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: wine <worker_exe_path> -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string WinePathStr = m_WinePath; + std::string BuildArg = "-Build=build.action"; + + std::vector<char*> ArgV; + ArgV.push_back(WinePathStr.data()); + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing via Wine: {} {} {}", WinePathStr, ExePathStr, BuildArg); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + throw std::runtime_error(fmt::format("fork() failed: {}", strerror(errno))); + } + + if (ChildPid == 0) + { + // Child process + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(WinePathStr.c_str(), ArgV.data(), Envp.data()); + + // execve only returns on failure + _exit(127); + } + + // Parent: store child pid as void* (same convention as zencore/process.cpp) + + Ref<RunningAction> NewAction{new 
RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast<void*>(static_cast<intptr_t>(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WineProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::SweepRunningActions"); + std::vector<Ref<RunningAction>> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref<RunningAction> Running = It->second; + + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WineProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map<int, Ref<RunningAction>> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", 
Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast<pid_t>(reinterpret_cast<intptr_t>(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.h b/src/zencompute/runners/winerunner.h new file mode 100644 index 000000000..7df62e7c0 --- /dev/null +++ b/src/zencompute/runners/winerunner.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include <string> + +namespace zen::compute { + +/** Wine-based process runner for executing Windows worker executables on Linux. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. 
+ */ +class WineProcessRunner : public LocalProcessRunner +{ +public: + WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool); + + [[nodiscard]] SubmitResult SubmitAction(Ref<RunnerAction> Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + +private: + std::string m_WinePath = "wine"; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp new file mode 100644 index 000000000..dd09312df --- /dev/null +++ b/src/zencompute/testing/mockimds.cpp @@ -0,0 +1,205 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencompute/mockimds.h> + +#include <zencore/fmtutils.h> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// 
--------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) 
+ if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/timeline/workertimeline.cpp b/src/zencompute/timeline/workertimeline.cpp new 
file mode 100644 index 000000000..88ef5b62d --- /dev/null +++ b/src/zencompute/timeline/workertimeline.cpp @@ -0,0 +1,430 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "workertimeline.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/basicfile.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryfile.h> + +# include <algorithm> + +namespace zen::compute { + +WorkerTimeline::WorkerTimeline(std::string_view WorkerId) : m_WorkerId(WorkerId) +{ +} + +WorkerTimeline::~WorkerTimeline() +{ +} + +void +WorkerTimeline::RecordProvisioned() +{ + AppendEvent({ + .Type = EventType::WorkerProvisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordDeprovisioned() +{ + AppendEvent({ + .Type = EventType::WorkerDeprovisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordActionAccepted(int ActionLsn, const IoHash& ActionId) +{ + AppendEvent({ + .Type = EventType::ActionAccepted, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + }); +} + +void +WorkerTimeline::RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason) +{ + AppendEvent({ + .Type = EventType::ActionRejected, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .Reason = std::string(Reason), + }); +} + +void +WorkerTimeline::RecordActionStateChanged(int ActionLsn, + const IoHash& ActionId, + RunnerAction::State PreviousState, + RunnerAction::State NewState) +{ + AppendEvent({ + .Type = EventType::ActionStateChanged, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .ActionState = NewState, + .PreviousState = PreviousState, + }); +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryTimeline(DateTime StartTime, DateTime EndTime) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + for (const auto& Evt : m_Events) + { + if 
(Evt.Timestamp >= StartTime && Evt.Timestamp <= EndTime) + { + Result.push_back(Evt); + } + } + }); + + return Result; +} + +std::vector<WorkerTimeline::Event> +WorkerTimeline::QueryRecent(int Limit) const +{ + std::vector<Event> Result; + + m_EventsLock.WithSharedLock([&] { + const int Count = std::min(Limit, gsl::narrow<int>(m_Events.size())); + auto It = m_Events.end() - Count; + Result.assign(It, m_Events.end()); + }); + + return Result; +} + +size_t +WorkerTimeline::GetEventCount() const +{ + size_t Count = 0; + m_EventsLock.WithSharedLock([&] { Count = m_Events.size(); }); + return Count; +} + +WorkerTimeline::TimeRange +WorkerTimeline::GetTimeRange() const +{ + TimeRange Range; + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Range.First = m_Events.front().Timestamp; + Range.Last = m_Events.back().Timestamp; + } + }); + return Range; +} + +void +WorkerTimeline::AppendEvent(Event&& Evt) +{ + m_EventsLock.WithExclusiveLock([&] { + while (m_Events.size() >= m_MaxEvents) + { + m_Events.pop_front(); + } + + m_Events.push_back(std::move(Evt)); + }); +} + +const char* +WorkerTimeline::ToString(EventType Type) +{ + switch (Type) + { + case EventType::WorkerProvisioned: + return "provisioned"; + case EventType::WorkerDeprovisioned: + return "deprovisioned"; + case EventType::ActionAccepted: + return "accepted"; + case EventType::ActionRejected: + return "rejected"; + case EventType::ActionStateChanged: + return "state_changed"; + default: + return "unknown"; + } +} + +static WorkerTimeline::EventType +EventTypeFromString(std::string_view Str) +{ + if (Str == "provisioned") + return WorkerTimeline::EventType::WorkerProvisioned; + if (Str == "deprovisioned") + return WorkerTimeline::EventType::WorkerDeprovisioned; + if (Str == "accepted") + return WorkerTimeline::EventType::ActionAccepted; + if (Str == "rejected") + return WorkerTimeline::EventType::ActionRejected; + if (Str == "state_changed") + return WorkerTimeline::EventType::ActionStateChanged; + 
/** Serialize the timeline to a Compact Binary file at Path.
 *
 * Layout: a root object with "worker_id", optional "time_first"/"time_last"
 * (written only when events exist — these allow ReadTimeRange() to answer
 * range queries without deserializing the event array), and an "events"
 * array of per-event objects. Field names here are the on-disk format read
 * back by ReadFrom(); keep them in sync.
 */
void
WorkerTimeline::WriteTo(const std::filesystem::path& Path) const
{
    CbObjectWriter Cbo;
    Cbo << "worker_id" << m_WorkerId;

    // Hold the shared lock for the whole event walk so the snapshot is
    // consistent (time_first/time_last match the serialized array).
    m_EventsLock.WithSharedLock([&] {
        if (!m_Events.empty())
        {
            Cbo.AddDateTime("time_first", m_Events.front().Timestamp);
            Cbo.AddDateTime("time_last", m_Events.back().Timestamp);
        }

        Cbo.BeginArray("events");
        for (const auto& Evt : m_Events)
        {
            Cbo.BeginObject();
            Cbo << "type" << ToString(Evt.Type);
            Cbo.AddDateTime("ts", Evt.Timestamp);

            // ActionLsn == 0 is treated as "no action context" (worker
            // lifecycle events); lsn/action_id are omitted in that case.
            if (Evt.ActionLsn != 0)
            {
                Cbo << "lsn" << Evt.ActionLsn;
                Cbo << "action_id" << Evt.ActionId;
            }

            // State fields are only meaningful for state-change events.
            if (Evt.Type == EventType::ActionStateChanged)
            {
                Cbo << "prev_state" << static_cast<int32_t>(Evt.PreviousState);
                Cbo << "state" << static_cast<int32_t>(Evt.ActionState);
            }

            if (!Evt.Reason.empty())
            {
                Cbo << "reason" << std::string_view(Evt.Reason);
            }

            Cbo.EndObject();
        }
        Cbo.EndArray();
    });

    CbObject Obj = Cbo.Save();

    // kTruncate replaces any existing timeline file for this worker.
    BasicFile File(Path, BasicFile::Mode::kTruncate);
    File.Write(Obj.GetBuffer().GetView(), 0);
}
(!Reason.empty()) + { + Evt.Reason = std::string(Reason); + } + + LoadedEvents.push_back(std::move(Evt)); + } + + m_EventsLock.WithExclusiveLock([&] { m_Events = std::move(LoadedEvents); }); +} + +WorkerTimeline::TimeRange +WorkerTimeline::ReadTimeRange(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + + if (!Loaded.Object) + { + return {}; + } + + return { + .First = Loaded.Object["time_first"].AsDateTime(), + .Last = Loaded.Object["time_last"].AsDateTime(), + }; +} + +// WorkerTimelineStore + +static constexpr std::string_view kTimelineExtension = ".ztimeline"; + +WorkerTimelineStore::WorkerTimelineStore(std::filesystem::path PersistenceDir) : m_PersistenceDir(std::move(PersistenceDir)) +{ + std::error_code Ec; + std::filesystem::create_directories(m_PersistenceDir, Ec); +} + +Ref<WorkerTimeline> +WorkerTimelineStore::GetOrCreate(std::string_view WorkerId) +{ + // Fast path: check if it already exists in memory + { + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + } + + // Slow path: create under exclusive lock, loading from disk if available + RwLock::ExclusiveLockScope _(m_Lock); + + auto& Entry = m_Timelines[std::string(WorkerId)]; + if (!Entry) + { + Entry = Ref<WorkerTimeline>(new WorkerTimeline(WorkerId)); + + std::filesystem::path Path = TimelinePath(WorkerId); + std::error_code Ec; + if (std::filesystem::is_regular_file(Path, Ec)) + { + Entry->ReadFrom(Path); + } + } + return Entry; +} + +Ref<WorkerTimeline> +WorkerTimelineStore::Find(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + return {}; +} + +std::vector<std::string> +WorkerTimelineStore::GetActiveWorkerIds() const +{ + std::vector<std::string> Result; + + RwLock::SharedLockScope $(m_Lock); + 
Result.reserve(m_Timelines.size()); + for (const auto& [Id, _] : m_Timelines) + { + Result.push_back(Id); + } + + return Result; +} + +std::vector<WorkerTimelineStore::WorkerTimelineInfo> +WorkerTimelineStore::GetAllWorkerInfo() const +{ + std::unordered_map<std::string, WorkerTimeline::TimeRange> InfoMap; + + { + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + InfoMap[Id] = Timeline->GetTimeRange(); + } + } + + std::error_code Ec; + for (const auto& Entry : std::filesystem::directory_iterator(m_PersistenceDir, Ec)) + { + if (!Entry.is_regular_file()) + { + continue; + } + + const auto& Path = Entry.path(); + if (Path.extension().string() != kTimelineExtension) + { + continue; + } + + std::string Id = Path.stem().string(); + if (InfoMap.find(Id) == InfoMap.end()) + { + InfoMap[Id] = WorkerTimeline::ReadTimeRange(Path); + } + } + + std::vector<WorkerTimelineInfo> Result; + Result.reserve(InfoMap.size()); + for (auto& [Id, Range] : InfoMap) + { + Result.push_back({.WorkerId = std::move(Id), .Range = Range}); + } + return Result; +} + +void +WorkerTimelineStore::Save(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + It->second->WriteTo(TimelinePath(WorkerId)); + } +} + +void +WorkerTimelineStore::SaveAll() +{ + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + Timeline->WriteTo(TimelinePath(Id)); + } +} + +std::filesystem::path +WorkerTimelineStore::TimelinePath(std::string_view WorkerId) const +{ + return m_PersistenceDir / (std::string(WorkerId) + std::string(kTimelineExtension)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/timeline/workertimeline.h b/src/zencompute/timeline/workertimeline.h new file mode 100644 index 000000000..87e19bc28 --- /dev/null +++ b/src/zencompute/timeline/workertimeline.h @@ -0,0 +1,169 @@ +// 
Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../runners/functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zenbase/refcount.h> +# include <zencore/compactbinary.h> +# include <zencore/iohash.h> +# include <zencore/thread.h> +# include <zencore/timer.h> + +# include <deque> +# include <filesystem> +# include <string> +# include <string_view> +# include <unordered_map> +# include <vector> + +namespace zen::compute { + +struct RunnerAction; + +/** Worker activity timeline for tracking and visualizing worker activity over time. + * + * Records worker lifecycle events (provisioning/deprovisioning) and action lifecycle + * events (accept, reject, state changes) with timestamps, enabling time-range queries + * for dashboard visualization. + */ +class WorkerTimeline : public RefCounted +{ +public: + explicit WorkerTimeline(std::string_view WorkerId); + ~WorkerTimeline() override; + + struct TimeRange + { + DateTime First = DateTime(0); + DateTime Last = DateTime(0); + + explicit operator bool() const { return First.GetTicks() != 0; } + }; + + enum class EventType + { + WorkerProvisioned, + WorkerDeprovisioned, + ActionAccepted, + ActionRejected, + ActionStateChanged + }; + + static const char* ToString(EventType Type); + + struct Event + { + EventType Type; + DateTime Timestamp = DateTime(0); + + // Action context (only set for action events) + int ActionLsn = 0; + IoHash ActionId; + RunnerAction::State ActionState = RunnerAction::State::New; + RunnerAction::State PreviousState = RunnerAction::State::New; + + // Optional reason (e.g. rejection reason) + std::string Reason; + }; + + /** Record that this worker has been provisioned and is available for work. */ + void RecordProvisioned(); + + /** Record that this worker has been deprovisioned and is no longer available. */ + void RecordDeprovisioned(); + + /** Record that an action was accepted by this worker. 
*/ + void RecordActionAccepted(int ActionLsn, const IoHash& ActionId); + + /** Record that an action was rejected by this worker. */ + void RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason); + + /** Record an action state transition on this worker. */ + void RecordActionStateChanged(int ActionLsn, const IoHash& ActionId, RunnerAction::State PreviousState, RunnerAction::State NewState); + + /** Query events within a time range (inclusive). Returns events ordered by timestamp. */ + [[nodiscard]] std::vector<Event> QueryTimeline(DateTime StartTime, DateTime EndTime) const; + + /** Query the most recent N events. */ + [[nodiscard]] std::vector<Event> QueryRecent(int Limit = 100) const; + + /** Return the total number of recorded events. */ + [[nodiscard]] size_t GetEventCount() const; + + /** Return the time range covered by the events in this timeline. */ + [[nodiscard]] TimeRange GetTimeRange() const; + + [[nodiscard]] const std::string& GetWorkerId() const { return m_WorkerId; } + + /** Write the timeline to a file at the given path. */ + void WriteTo(const std::filesystem::path& Path) const; + + /** Read the timeline from a file at the given path. Replaces current in-memory events. */ + void ReadFrom(const std::filesystem::path& Path); + + /** Read only the time range from a persisted timeline file, without loading events. */ + [[nodiscard]] static TimeRange ReadTimeRange(const std::filesystem::path& Path); + +private: + void AppendEvent(Event&& Evt); + + std::string m_WorkerId; + mutable RwLock m_EventsLock; + std::deque<Event> m_Events; + size_t m_MaxEvents = 10'000; +}; + +/** Manages a set of WorkerTimeline instances, keyed by worker ID. + * + * Provides thread-safe lookup and on-demand creation of timelines, backed by + * a persistence directory. Each timeline is stored as a separate file named + * {WorkerId}.ztimeline within the directory. 
+ */ +class WorkerTimelineStore +{ +public: + explicit WorkerTimelineStore(std::filesystem::path PersistenceDir); + ~WorkerTimelineStore() = default; + + WorkerTimelineStore(const WorkerTimelineStore&) = delete; + WorkerTimelineStore& operator=(const WorkerTimelineStore&) = delete; + + /** Get the timeline for a worker, creating one if it does not exist. + * If a persisted file exists on disk it will be loaded on first access. */ + Ref<WorkerTimeline> GetOrCreate(std::string_view WorkerId); + + /** Get the timeline for a worker, or null ref if it does not exist in memory. */ + [[nodiscard]] Ref<WorkerTimeline> Find(std::string_view WorkerId); + + /** Return the worker IDs of currently loaded (in-memory) timelines. */ + [[nodiscard]] std::vector<std::string> GetActiveWorkerIds() const; + + struct WorkerTimelineInfo + { + std::string WorkerId; + WorkerTimeline::TimeRange Range; + }; + + /** Return info for all known timelines (in-memory and on-disk), including time range. */ + [[nodiscard]] std::vector<WorkerTimelineInfo> GetAllWorkerInfo() const; + + /** Persist a single worker's timeline to disk. */ + void Save(std::string_view WorkerId); + + /** Persist all in-memory timelines to disk. 
*/ + void SaveAll(); + +private: + [[nodiscard]] std::filesystem::path TimelinePath(std::string_view WorkerId) const; + + std::filesystem::path m_PersistenceDir; + mutable RwLock m_Lock; + std::unordered_map<std::string, Ref<WorkerTimeline>> m_Timelines; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua index 50877508c..ed0af66a5 100644 --- a/src/zencompute/xmake.lua +++ b/src/zencompute/xmake.lua @@ -6,4 +6,14 @@ target('zencompute') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) + add_includedirs(".", {private=true}) add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") + add_packages("json11") + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end + + if is_plat("windows") then + add_syslinks("Userenv") + end diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp index 633250f4e..1f3f6d3f9 100644 --- a/src/zencompute/zencompute.cpp +++ b/src/zencompute/zencompute.cpp @@ -2,11 +2,20 @@ #include "zencompute/zencompute.h" +#if ZEN_WITH_TESTS +# include "runners/deferreddeleter.h" +# include <zencompute/cloudmetadata.h> +#endif + namespace zen { void zencompute_forcelinktests() { +#if ZEN_WITH_TESTS + compute::cloudmetadata_forcelink(); + compute::deferreddeleter_forcelink(); +#endif } } // namespace zen diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index bf3c15d3d..fecbe2dbe 100644 --- a/src/zencore/include/zencore/system.h +++ b/src/zencore/include/zencore/system.h @@ -4,6 +4,8 @@ #include <zencore/zencore.h> +#include <chrono> +#include <memory> #include <string> namespace zen { @@ -12,6 +14,7 @@ class CbWriter; std::string GetMachineName(); std::string_view GetOperatingSystemName(); +std::string_view GetRuntimePlatformName(); // "windows", "wine", "linux", or "macos" std::string_view GetCpuName(); struct SystemMetrics @@ -25,7 +28,13 @@ 
struct SystemMetrics uint64_t AvailVirtualMemoryMiB = 0; uint64_t PageFileMiB = 0; uint64_t AvailPageFileMiB = 0; - float CpuUsagePercent = 0.0f; +}; + +/// Extended metrics that include CPU usage percentage, which requires +/// stateful delta tracking via SystemMetricsTracker. +struct ExtendedSystemMetrics : SystemMetrics +{ + float CpuUsagePercent = 0.0f; }; SystemMetrics GetSystemMetrics(); @@ -33,6 +42,31 @@ SystemMetrics GetSystemMetrics(); void SetCpuCountForReporting(int FakeCpuCount); SystemMetrics GetSystemMetricsForReporting(); +ExtendedSystemMetrics ApplyReportingOverrides(ExtendedSystemMetrics Metrics); + void Describe(const SystemMetrics& Metrics, CbWriter& Writer); +void Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer); + +/// Stateful tracker that computes CPU usage as a delta between consecutive +/// Query() calls. The first call returns CpuUsagePercent = 0 (no previous +/// sample). Thread-safe: concurrent calls are serialised internally. +/// CPU sampling is rate-limited to MinInterval (default 1 s); calls that +/// arrive sooner return the previously cached value. +class SystemMetricsTracker +{ +public: + explicit SystemMetricsTracker(std::chrono::milliseconds MinInterval = std::chrono::seconds(1)); + ~SystemMetricsTracker(); + + SystemMetricsTracker(const SystemMetricsTracker&) = delete; + SystemMetricsTracker& operator=(const SystemMetricsTracker&) = delete; + + /// Collect current metrics. CPU usage is computed as delta since last Query(). 
+ ExtendedSystemMetrics Query(); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; } // namespace zen diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index 267c87e12..833d3c04b 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -7,6 +7,8 @@ #include <zencore/memory/memory.h> #include <zencore/string.h> +#include <mutex> + #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -133,33 +135,6 @@ GetSystemMetrics() Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; } - // Query CPU usage using PDH - // - // TODO: This should be changed to not require a Sleep, perhaps by using some - // background metrics gathering mechanism. - - { - PDH_HQUERY QueryHandle = nullptr; - PDH_HCOUNTER CounterHandle = nullptr; - - if (PdhOpenQueryW(nullptr, 0, &QueryHandle) == ERROR_SUCCESS) - { - if (PdhAddEnglishCounterW(QueryHandle, L"\\Processor(_Total)\\% Processor Time", 0, &CounterHandle) == ERROR_SUCCESS) - { - PdhCollectQueryData(QueryHandle); - Sleep(100); - PdhCollectQueryData(QueryHandle); - - PDH_FMT_COUNTERVALUE CounterValue; - if (PdhGetFormattedCounterValue(CounterHandle, PDH_FMT_DOUBLE, nullptr, &CounterValue) == ERROR_SUCCESS) - { - Metrics.CpuUsagePercent = static_cast<float>(CounterValue.doubleValue); - } - } - PdhCloseQuery(QueryHandle); - } - } - return Metrics; } #elif ZEN_PLATFORM_LINUX @@ -235,39 +210,6 @@ GetSystemMetrics() } } - // Query CPU usage - Metrics.CpuUsagePercent = 0.0f; - if (FILE* Stat = fopen("/proc/stat", "r")) - { - char Line[256]; - unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq; - static unsigned long PrevUser = 0, PrevNice = 0, PrevSystem = 0, PrevIdle = 0, PrevIoWait = 0, PrevIrq = 0, PrevSoftIrq = 0; - - if (fgets(Line, sizeof(Line), Stat)) - { - if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7) - { - unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) - - (PrevUser + 
PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq); - unsigned long IdleDelta = Idle - PrevIdle; - - if (TotalDelta > 0) - { - Metrics.CpuUsagePercent = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; - } - - PrevUser = User; - PrevNice = Nice; - PrevSystem = System; - PrevIdle = Idle; - PrevIoWait = IoWait; - PrevIrq = Irq; - PrevSoftIrq = SoftIrq; - } - } - fclose(Stat); - } - // Get memory information long Pages = sysconf(_SC_PHYS_PAGES); long PageSize = sysconf(_SC_PAGE_SIZE); @@ -348,25 +290,6 @@ GetSystemMetrics() sysctlbyname("hw.packages", &Packages, &Size, nullptr, 0); Metrics.CpuCount = Packages > 0 ? Packages : 1; - // Query CPU usage using host_statistics64 - Metrics.CpuUsagePercent = 0.0f; - host_cpu_load_info_data_t CpuLoad; - mach_msg_type_number_t CpuCount = sizeof(CpuLoad) / sizeof(natural_t); - if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &CpuCount) == KERN_SUCCESS) - { - unsigned long TotalTicks = 0; - for (int i = 0; i < CPU_STATE_MAX; ++i) - { - TotalTicks += CpuLoad.cpu_ticks[i]; - } - - if (TotalTicks > 0) - { - unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE]; - Metrics.CpuUsagePercent = 100.0f * (TotalTicks - IdleTicks) / TotalTicks; - } - } - // Get memory information uint64_t MemSize = 0; Size = sizeof(MemSize); @@ -401,6 +324,17 @@ GetSystemMetrics() # error "Unknown platform" #endif +ExtendedSystemMetrics +ApplyReportingOverrides(ExtendedSystemMetrics Metrics) +{ + if (g_FakeCpuCount) + { + Metrics.CoreCount = g_FakeCpuCount; + Metrics.LogicalProcessorCount = g_FakeCpuCount; + } + return Metrics; +} + SystemMetrics GetSystemMetricsForReporting() { @@ -415,6 +349,225 @@ GetSystemMetricsForReporting() return Sm; } +/////////////////////////////////////////////////////////////////////////// +// SystemMetricsTracker +/////////////////////////////////////////////////////////////////////////// + +// Per-platform CPU sampling helper. Called with m_Mutex held. 
// Samples CPU usage by reading /proc/stat. Used natively on Linux and as a
// Wine fallback on Windows (where /proc/stat is accessible via the Z: drive).
//
// Usage is the busy/total tick delta between two consecutive Sample() calls.
// The first call only establishes the baseline and returns 0.0, matching the
// PDH sampler and the SystemMetricsTracker contract documented in system.h
// ("The first call returns CpuUsagePercent = 0"). Previously the first call
// computed a delta against zero-initialized counters and thus reported the
// average usage since boot.
struct ProcStatCpuSampler
{
    const char*   Path = "/proc/stat";
    bool          HasPreviousSample = false;
    unsigned long PrevUser = 0;
    unsigned long PrevNice = 0;
    unsigned long PrevSystem = 0;
    unsigned long PrevIdle = 0;
    unsigned long PrevIoWait = 0;
    unsigned long PrevIrq = 0;
    unsigned long PrevSoftIrq = 0;

    explicit ProcStatCpuSampler(const char* InPath = "/proc/stat") : Path(InPath) {}

    float Sample()
    {
        float CpuUsage = 0.0f;

        if (FILE* Stat = fopen(Path, "r"))
        {
            char          Line[256];
            unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq;

            // First line of /proc/stat aggregates all CPUs: "cpu user nice ..."
            if (fgets(Line, sizeof(Line), Stat))
            {
                if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7)
                {
                    unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) -
                                               (PrevUser + PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq);
                    unsigned long IdleDelta = Idle - PrevIdle;

                    // Only report a percentage once a real previous sample
                    // exists; otherwise the "delta" is the since-boot total.
                    if (HasPreviousSample && TotalDelta > 0)
                    {
                        CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta;
                    }

                    PrevUser = User;
                    PrevNice = Nice;
                    PrevSystem = System;
                    PrevIdle = Idle;
                    PrevIoWait = IoWait;
                    PrevIrq = Irq;
                    PrevSoftIrq = SoftIrq;
                    HasPreviousSample = true;
                }
            }
            fclose(Stat);
        }

        return CpuUsage;
    }
};
// macOS CPU sampler: computes usage from host_statistics() tick-counter
// deltas between consecutive Sample() calls.
struct CpuSampler
{
    // Tick counters from the previous Sample() call (0 until first call).
    unsigned long PrevTotalTicks = 0;
    unsigned long PrevIdleTicks = 0;

    float Sample()
    {
        float CpuUsage = 0.0f;

        host_cpu_load_info_data_t CpuLoad;
        mach_msg_type_number_t Count = sizeof(CpuLoad) / sizeof(natural_t);
        if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &Count) == KERN_SUCCESS)
        {
            // cpu_ticks[] holds cumulative-since-boot counts per CPU state.
            unsigned long TotalTicks = 0;
            for (int i = 0; i < CPU_STATE_MAX; ++i)
            {
                TotalTicks += CpuLoad.cpu_ticks[i];
            }
            unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE];

            unsigned long TotalDelta = TotalTicks - PrevTotalTicks;
            unsigned long IdleDelta = IdleTicks - PrevIdleTicks;

            // PrevTotalTicks > 0 guards the first sample: with no baseline
            // yet, report 0 (per the SystemMetricsTracker contract) instead
            // of the average since boot.
            if (TotalDelta > 0 && PrevTotalTicks > 0)
            {
                CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta;
            }

            PrevTotalTicks = TotalTicks;
            PrevIdleTicks = IdleTicks;
        }

        // On host_statistics failure the previous baseline is kept and 0 is
        // returned for this sample.
        return CpuUsage;
    }
};
/** Collect a full metrics snapshot.
 *
 * Memory and CPU-count figures are gathered fresh on every call via
 * GetSystemMetrics(); only the CPU-usage figure needs the tracker's shared
 * delta state, so the mutex is taken just around SampleCpu(). SampleCpu()
 * itself rate-limits actual sampling to MinInterval and otherwise returns
 * the cached percentage (see Impl above).
 */
ExtendedSystemMetrics
SystemMetricsTracker::Query()
{
    ExtendedSystemMetrics Metrics;
    // Slice-assign the base portion; CpuUsagePercent is filled in below.
    static_cast<SystemMetrics&>(Metrics) = GetSystemMetrics();

    // Serialise concurrent Query() calls over the sampler's previous-sample
    // counters and the rate-limit cache.
    std::lock_guard Lock(m_Impl->Mutex);
    Metrics.CpuUsagePercent = m_Impl->SampleCpu();
    return Metrics;
}
+ +#include "hordeagent.h" +#include "hordetransportaes.h" + +#include <zencore/basicfile.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/trace.h> + +#include <cstring> +#include <unordered_map> + +namespace zen::horde { + +HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) +{ + ZEN_TRACE_CPU("HordeAgent::Connect"); + + auto Transport = std::make_unique<TcpComputeTransport>(Info); + if (!Transport->IsValid()) + { + ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); + return; + } + + // The 64-byte nonce is always sent unencrypted as the first thing on the wire. + // The Horde agent uses this to identify which lease this connection belongs to. + Transport->Send(Info.Nonce, sizeof(Info.Nonce)); + + std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); + if (Info.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); + if (!FinalTransport->IsValid()) + { + ZEN_WARN("failed to create AES transport"); + return; + } + } + + // Create multiplexed socket and channels + m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); + + // Channel 0 is the agent control channel (handles Attach/Fork handshake). + // Channel 100 is the child I/O channel (handles file upload and remote execution). 
+ Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); + Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + + if (!AgentComputeChannel || !ChildComputeChannel) + { + ZEN_WARN("failed to create compute channels"); + return; + } + + m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); + m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + + m_IsValid = true; +} + +HordeAgent::~HordeAgent() +{ + CloseConnection(); +} + +bool +HordeAgent::BeginCommunication() +{ + ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); + + if (!m_IsValid) + { + return false; + } + + // Start the send/recv pump threads + m_Socket->StartCommunication(); + + // Wait for Attach on agent channel + AgentMessageType Type = m_AgentChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on agent channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); + return false; + } + + // Fork tells the remote agent to create child channel 100 with a 4MB buffer. + // After this, the agent will send an Attach on the child channel. 
+ m_AgentChannel->Fork(100, 4 * 1024 * 1024); + + // Wait for Attach on child channel + Type = m_ChildChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on child channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); + return false; + } + + return true; +} + +bool +HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +{ + ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); + + m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + + std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + + auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { + std::string Key(Locator); + + if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) + { + return It->second.get(); + } + + const std::filesystem::path Path = BundleDir / (Key + ".blob"); + std::error_code Ec; + auto File = std::make_unique<BasicFile>(); + File->Open(Path, BasicFile::Mode::kRead, Ec); + + if (Ec) + { + ZEN_ERROR("cannot read blob file: '{}'", Path); + return nullptr; + } + + BasicFile* Ptr = File.get(); + BlobFiles.emplace(std::move(Key), std::move(File)); + return Ptr; + }; + + // The upload protocol is request-driven: we send WriteFiles, then the remote agent + // sends ReadBlob requests for each blob it needs. We respond with Blob data until + // the agent sends WriteFilesResponse indicating the upload is complete. 
+ constexpr int32_t ReadResponseTimeoutMs = 1000; + + for (;;) + { + bool TimedOut = false; + + if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) + { + if (TimedOut) + { + continue; + } + // End of stream - check if it was a successful upload + if (Type == AgentMessageType::WriteFilesResponse) + { + return true; + } + else if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + } + else + { + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + } + return false; + } + + BlobRequest Req; + m_ChildChannel->ReadBlobRequest(Req); + + BasicFile* File = FindOrOpenBlob(Req.Locator); + if (!File) + { + return false; + } + + // Read from offset to end of file + const uint64_t TotalSize = File->FileSize(); + const uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + continue; + } + + const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); + } +} + +void +HordeAgent::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + bool UseWine) +{ + ZEN_TRACE_CPU("HordeAgent::Execute"); + m_ChildChannel + ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? 
ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); +} + +bool +HordeAgent::Poll(bool LogOutput) +{ + constexpr int32_t ReadResponseTimeoutMs = 100; + AgentMessageType Type; + + while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + { + switch (Type) + { + case AgentMessageType::ExecuteOutput: + { + if (LogOutput && m_ChildChannel->GetResponseSize() > 0) + { + const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); + size_t ResponseSize = m_ChildChannel->GetResponseSize(); + + // Trim trailing newlines + while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) + { + --ResponseSize; + } + + if (ResponseSize > 0) + { + const std::string_view Output(ResponseData, ResponseSize); + ZEN_INFO("[remote] {}", Output); + } + } + break; + } + + case AgentMessageType::ExecuteResult: + { + if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) + { + int32_t ExitCode; + memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); + ZEN_INFO("remote process exited with code {}", ExitCode); + } + m_IsValid = false; + return false; + } + + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + m_HasErrors = true; + break; + } + + default: + break; + } + } + + return m_IsValid && !m_HasErrors; +} + +void +HordeAgent::CloseConnection() +{ + if (m_ChildChannel) + { + m_ChildChannel->Close(); + } + if (m_AgentChannel) + { + m_AgentChannel->Close(); + } +} + +bool +HordeAgent::IsValid() const +{ + return m_IsValid && !m_HasErrors; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h new file mode 100644 index 000000000..e0ae89ead --- /dev/null +++ b/src/zenhorde/hordeagent.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "hordeagentmessage.h" +#include "hordecomputesocket.h" + +#include <zenhorde/hordeclient.h> + +#include <zencore/logbase.h> + +#include <filesystem> +#include <memory> +#include <string> + +namespace zen::horde { + +/** Manages the lifecycle of a single Horde compute agent. + * + * Handles the full connection sequence for one provisioned machine: + * 1. Connect via TCP transport (with optional AES encryption wrapping) + * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) + * 3. Perform the Attach/Fork handshake to establish the child channel + * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol + * 5. Execute zenserver remotely via ExecuteV2 + * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + */ +class HordeAgent +{ +public: + explicit HordeAgent(const MachineInfo& Info); + ~HordeAgent(); + + HordeAgent(const HordeAgent&) = delete; + HordeAgent& operator=(const HordeAgent&) = delete; + + /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). + * Returns false if the handshake times out or receives an unexpected message. */ + bool BeginCommunication(); + + /** Upload binary files to the remote agent. + * @param BundleDir Directory containing .blob files. + * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ + bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + + /** Execute a command on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir = nullptr, + const char* const* EnvVars = nullptr, + size_t NumEnvVars = 0, + bool UseWine = false); + + /** Poll for output and results. Returns true if the agent is still running. + * When LogOutput is true, remote stdout is logged via ZEN_INFO. 
*/ + bool Poll(bool LogOutput = true); + + void CloseConnection(); + bool IsValid() const; + + const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + +private: + LoggerRef Log() { return m_Log; } + + std::unique_ptr<ComputeSocket> m_Socket; + std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control + std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O + + LoggerRef m_Log; + bool m_IsValid = false; + bool m_HasErrors = false; + MachineInfo m_MachineInfo; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp new file mode 100644 index 000000000..998134a96 --- /dev/null +++ b/src/zenhorde/hordeagentmessage.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordeagentmessage.h" + +#include <zencore/intmath.h> + +#include <cassert> +#include <cstring> + +namespace zen::horde { + +AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) +{ +} + +AgentMessageChannel::~AgentMessageChannel() = default; + +void +AgentMessageChannel::Close() +{ + CreateMessage(AgentMessageType::None, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Ping() +{ + CreateMessage(AgentMessageType::Ping, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Fork(int ChannelId, int BufferSize) +{ + CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(ChannelId); + WriteInt32(BufferSize); + FlushMessage(); +} + +void +AgentMessageChannel::Attach() +{ + CreateMessage(AgentMessageType::Attach, 0); + FlushMessage(); +} + +void +AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +{ + CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Path); + WriteString(Locator); + FlushMessage(); +} + +void +AgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* 
WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) +{ + size_t RequiredSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) + { + RequiredSize += strlen(Args[i]) + 10; + } + if (WorkingDir) + { + RequiredSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + RequiredSize += strlen(EnvVars[i]) + 20; + } + + CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); + WriteString(Exe); + + WriteUnsignedVarInt(NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Args[i]); + } + + WriteOptionalString(WorkingDir); + + // ExecuteV2 protocol requires env vars as separate key/value pairs. + // Callers pass "KEY=VALUE" strings; we split on the first '=' here. + WriteUnsignedVarInt(NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(nullptr); + } + else + { + WriteOptionalString(Eq + 1); + } + } + + WriteInt32(static_cast<int>(Flags)); + FlushMessage(); +} + +void +AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +{ + // Blob responses are chunked to fit within the compute buffer's chunk size. + // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). 
+ const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; + for (size_t ChunkOffset = 0; ChunkOffset < Length;) + { + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); + + CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(static_cast<int>(ChunkOffset)); + WriteInt32(static_cast<int>(Length)); + WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); + FlushMessage(); + + ChunkOffset += ChunkLength; + } +} + +AgentMessageType +AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +{ + // Deferred advance: the previous response's buffer is only released when the next + // ReadResponse is called. This allows callers to read response data between calls + // without copying, since the pointer comes directly from the ring buffer. + if (m_ResponseData) + { + m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); + m_ResponseData = nullptr; + m_ResponseLength = 0; + } + + const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + uint32_t Length; + memcpy(&Length, Header + 1, sizeof(uint32_t)); + + Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + m_ResponseType = static_cast<AgentMessageType>(Header[0]); + m_ResponseData = Header + MessageHeaderLength; + m_ResponseLength = Length; + + return m_ResponseType; +} + +void +AgentMessageChannel::ReadException(ExceptionInfo& Ex) +{ + assert(m_ResponseType == AgentMessageType::Exception); + const uint8_t* Pos = m_ResponseData; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); +} + +int +AgentMessageChannel::ReadExecuteResult() +{ + assert(m_ResponseType == AgentMessageType::ExecuteResult); + const uint8_t* Pos = m_ResponseData; + return ReadInt32(&Pos); +} + +void 
+AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +{ + assert(m_ResponseType == AgentMessageType::ReadBlob); + const uint8_t* Pos = m_ResponseData; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); +} + +void +AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +{ + m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); + m_RequestData[0] = static_cast<uint8_t>(Type); + m_MaxRequestSize = MaxLength; + m_RequestSize = 0; +} + +void +AgentMessageChannel::FlushMessage() +{ + const uint32_t Size = static_cast<uint32_t>(m_RequestSize); + memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); + m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); + m_RequestSize = 0; + m_MaxRequestSize = 0; + m_RequestData = nullptr; +} + +void +AgentMessageChannel::WriteInt32(int Value) +{ + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); +} + +int +AgentMessageChannel::ReadInt32(const uint8_t** Pos) +{ + int Value; + memcpy(&Value, *Pos, sizeof(int)); + *Pos += sizeof(int); + return Value; +} + +void +AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +{ + assert(m_RequestSize + Length <= m_MaxRequestSize); + memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); + m_RequestSize += Length; +} + +const uint8_t* +AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +{ + const uint8_t* Data = *Pos; + *Pos += Length; + return Data; +} + +size_t +AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +{ + if (Value == 0) + { + return 1; + } + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; +} + +void +AgentMessageChannel::WriteUnsignedVarInt(size_t Value) +{ + const size_t ByteCount = MeasureUnsignedVarInt(Value); + assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + + uint8_t* Output = m_RequestData + MessageHeaderLength + 
m_RequestSize; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; + } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); + + m_RequestSize += ByteCount; +} + +size_t +AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +{ + const uint8_t* Data = *Pos; + const uint8_t FirstByte = Data[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) + { + Value <<= 8; + Value |= Data[i]; + } + + *Pos += NumBytes; + return Value; +} + +size_t +AgentMessageChannel::MeasureString(const char* Text) const +{ + const size_t Length = strlen(Text); + return MeasureUnsignedVarInt(Length) + Length; +} + +void +AgentMessageChannel::WriteString(const char* Text) +{ + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Length); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); +} + +void +AgentMessageChannel::WriteString(std::string_view Text) +{ + WriteUnsignedVarInt(Text.size()); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); +} + +std::string_view +AgentMessageChannel::ReadString(const uint8_t** Pos) +{ + const size_t Length = ReadUnsignedVarInt(Pos); + const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); + return std::string_view(Start, Length); +} + +void +AgentMessageChannel::WriteOptionalString(const char* Text) +{ + // Optional strings use length+1 encoding: 0 means null/absent, + // N>0 means a string of length N-1 follows. This matches the UE + // FAgentMessageChannel serialization convention. 
+ if (!Text) + { + WriteUnsignedVarInt(0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Length + 1); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h new file mode 100644 index 000000000..38c4375fd --- /dev/null +++ b/src/zenhorde/hordeagentmessage.h @@ -0,0 +1,161 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#include "hordecomputechannel.h" + +#include <cstddef> +#include <cstdint> +#include <string> +#include <string_view> +#include <vector> + +namespace zen::horde { + +/** Agent message types matching the UE EAgentMessageType byte values. + * These are the message opcodes exchanged over the agent/child channels. */ +enum class AgentMessageType : uint8_t +{ + None = 0x00, + Ping = 0x01, + Exception = 0x02, + Fork = 0x03, + Attach = 0x04, + WriteFiles = 0x10, + WriteFilesResponse = 0x11, + DeleteFiles = 0x12, + ExecuteV2 = 0x22, + ExecuteOutput = 0x17, + ExecuteResult = 0x18, + ReadBlob = 0x20, + ReadBlobResponse = 0x21, +}; + +/** Flags for the ExecuteV2 message. */ +enum class ExecuteProcessFlags : uint8_t +{ + None = 0, + UseWine = 1, ///< Run the executable under Wine on Linux agents +}; + +/** Parsed exception information from an Exception message. */ +struct ExceptionInfo +{ + std::string_view Message; + std::string_view Description; +}; + +/** Parsed blob read request from a ReadBlob message. */ +struct BlobRequest +{ + std::string_view Locator; + size_t Offset = 0; + size_t Length = 0; +}; + +/** Channel for sending and receiving agent messages over a ComputeChannel. + * + * Implements the Horde agent message protocol, matching the UE + * FAgentMessageChannel serialization format exactly. Messages are framed as + * [type (1B)][payload length (4B)][payload]. 
Strings use length-prefixed UTF-8; + * integers use variable-length encoding. + * + * The protocol has two directions: + * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob + * - Responses (remote -> initiator): ReadResponse returns the type, then call the + * appropriate Read* method to parse the payload. + */ +class AgentMessageChannel +{ +public: + explicit AgentMessageChannel(Ref<ComputeChannel> Channel); + ~AgentMessageChannel(); + + AgentMessageChannel(const AgentMessageChannel&) = delete; + AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + + // --- Requests (Initiator -> Remote) --- + + /** Close the channel. */ + void Close(); + + /** Send a keepalive ping. */ + void Ping(); + + /** Fork communication to a new channel with the given ID and buffer size. */ + void Fork(int ChannelId, int BufferSize); + + /** Send an attach request (used during channel setup handshake). */ + void Attach(); + + /** Request the remote agent to write files from the given bundle locator. */ + void UploadFiles(const char* Path, const char* Locator); + + /** Execute a process on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags = ExecuteProcessFlags::None); + + /** Send blob data in response to a ReadBlob request. */ + void Blob(const uint8_t* Data, size_t Length); + + // --- Responses (Remote -> Initiator) --- + + /** Read the next response message. Returns the message type, or None on timeout. + * After this returns, use GetResponseData()/GetResponseSize() or the typed + * Read* methods to access the payload. */ + AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + + const void* GetResponseData() const { return m_ResponseData; } + size_t GetResponseSize() const { return m_ResponseLength; } + + /** Parse an Exception response payload. 
*/ + void ReadException(ExceptionInfo& Ex); + + /** Parse an ExecuteResult response payload. Returns the exit code. */ + int ReadExecuteResult(); + + /** Parse a ReadBlob response payload into a BlobRequest. */ + void ReadBlobRequest(BlobRequest& Req); + +private: + static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + + Ref<ComputeChannel> m_Channel; + + uint8_t* m_RequestData = nullptr; + size_t m_RequestSize = 0; + size_t m_MaxRequestSize = 0; + + AgentMessageType m_ResponseType = AgentMessageType::None; + const uint8_t* m_ResponseData = nullptr; + size_t m_ResponseLength = 0; + + void CreateMessage(AgentMessageType Type, size_t MaxLength); + void FlushMessage(); + + void WriteInt32(int Value); + static int ReadInt32(const uint8_t** Pos); + + void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); + static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); + + static size_t MeasureUnsignedVarInt(size_t Value); + void WriteUnsignedVarInt(size_t Value); + static size_t ReadUnsignedVarInt(const uint8_t** Pos); + + size_t MeasureString(const char* Text) const; + void WriteString(const char* Text); + void WriteString(std::string_view Text); + static std::string_view ReadString(const uint8_t** Pos); + + void WriteOptionalString(const char* Text); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp new file mode 100644 index 000000000..d3974bc28 --- /dev/null +++ b/src/zenhorde/hordebundle.cpp @@ -0,0 +1,619 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "hordebundle.h" + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/intmath.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <chrono> +#include <cstring> + +namespace zen::horde { + +static LoggerRef +Log() +{ + static auto s_Logger = zen::logging::Get("horde.bundle"); + return s_Logger; +} + +static constexpr uint8_t PacketSignature[3] = {'U', 'B', 'N'}; +static constexpr uint8_t PacketVersion = 5; +static constexpr int32_t CurrentPacketBaseIdx = -2; +static constexpr int ImportBias = 3; +static constexpr uint32_t ChunkSize = 64 * 1024; // 64KB fixed chunks +static constexpr uint32_t LargeFileThreshold = 128 * 1024; // 128KB + +// BlobType: 20 bytes each = FGuid (16 bytes, 4x uint32 LE) + Version (int32 LE) +// Values from UE SDK: GUIDs stored as 4 uint32 LE values. + +// ChunkLeaf v1: {0xB27AFB68, 0x4A4B9E20, 0x8A78D8A4, 0x39D49840} +static constexpr uint8_t BlobType_ChunkLeafV1[20] = {0x68, 0xFB, 0x7A, 0xB2, 0x20, 0x9E, 0x4B, 0x4A, 0xA4, 0xD8, + 0x78, 0x8A, 0x40, 0x98, 0xD4, 0x39, 0x01, 0x00, 0x00, 0x00}; // version 1 + +// ChunkInterior v2: {0xF4DEDDBC, 0x4C7A70CB, 0x11F04783, 0xB9CDCCAF} +static constexpr uint8_t BlobType_ChunkInteriorV2[20] = {0xBC, 0xDD, 0xDE, 0xF4, 0xCB, 0x70, 0x7A, 0x4C, 0x83, 0x47, + 0xF0, 0x11, 0xAF, 0xCC, 0xCD, 0xB9, 0x02, 0x00, 0x00, 0x00}; // version 2 + +// Directory v1: {0x0714EC11, 0x4D07291A, 0x8AE77F86, 0x799980D6} +static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1A, 0x29, 0x07, 0x4D, 0x86, 0x7F, + 0xE7, 0x8A, 0xD6, 0x80, 0x99, 0x79, 0x01, 0x00, 0x00, 0x00}; // version 1 + +static constexpr size_t BlobTypeSize = 20; + +// ─── VarInt helpers (UE format) ───────────────────────────────────────────── + +static size_t +MeasureVarInt(size_t Value) +{ + if (Value == 0) + { + return 1; + } + return 
(FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1; +} + +static void +WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value) +{ + const size_t ByteCount = MeasureVarInt(Value); + const size_t Offset = Buffer.size(); + Buffer.resize(Offset + ByteCount); + + uint8_t* Output = Buffer.data() + Offset; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; + } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); +} + +// ─── Binary helpers ───────────────────────────────────────────────────────── + +static void +WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value) +{ + uint8_t Bytes[4]; + memcpy(Bytes, &Value, 4); + Buffer.insert(Buffer.end(), Bytes, Bytes + 4); +} + +static void +WriteByte(std::vector<uint8_t>& Buffer, uint8_t Value) +{ + Buffer.push_back(Value); +} + +static void +WriteBytes(std::vector<uint8_t>& Buffer, const void* Data, size_t Size) +{ + auto* Ptr = static_cast<const uint8_t*>(Data); + Buffer.insert(Buffer.end(), Ptr, Ptr + Size); +} + +static void +WriteString(std::vector<uint8_t>& Buffer, std::string_view Str) +{ + WriteVarInt(Buffer, Str.size()); + WriteBytes(Buffer, Str.data(), Str.size()); +} + +static void +AlignTo4(std::vector<uint8_t>& Buffer) +{ + while (Buffer.size() % 4 != 0) + { + Buffer.push_back(0); + } +} + +static void +PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value) +{ + memcpy(Buffer.data() + Offset, &Value, 4); +} + +// ─── Packet builder ───────────────────────────────────────────────────────── + +// Builds a single uncompressed Horde V2 packet. Layout: +// [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header) +// [TypeTableOffset(4) + ImportTableOffset(4) + ExportTableOffset(4)] 12 bytes +// [Export data...] 
// ─── Packet builder ─────────────────────────────────────────────────────────

// Builds a single uncompressed Horde V2 packet. Layout:
//   [Signature(3) + Version(1) + PacketLength(4)]                      8 bytes (header)
//   [TypeTableOffset(4) + ImportTableOffset(4) + ExportTableOffset(4)] 12 bytes
//   [Export data...]
//   [Type table: count(4) + count * 20 bytes]
//   [Import table: count(4) + (count+1) offset entries(4 each) + import data]
//   [Export table: count(4) + (count+1) offset entries(4 each)]
//
// ALL offsets are absolute from byte 0 of the full packet (including the 8-byte header).
// PacketLength in the header = total packet size including the 8-byte header.
//
// Usage: BeginExport() -> WritePayload()* -> CompleteExport(), repeated per
// export, then Finish() exactly once to emit the tables and patch the header.

struct PacketBuilder
{
    std::vector<uint8_t> Data;
    std::vector<int32_t> ExportOffsets; // Absolute byte offset of each export from byte 0

    // Type table: unique 20-byte BlobType entries
    std::vector<const uint8_t*> Types;

    // Import table entries: (baseIdx, fragment)
    struct ImportEntry
    {
        int32_t     BaseIdx;
        std::string Fragment;
    };
    std::vector<ImportEntry> Imports;

    // Current export's start offset (absolute from byte 0)
    size_t CurrentExportStart = 0;

    PacketBuilder()
    {
        // Reserve packet header (8 bytes) + table offsets (12 bytes) = 20 bytes
        Data.resize(20, 0);

        // Write signature
        Data[0] = PacketSignature[0];
        Data[1] = PacketSignature[1];
        Data[2] = PacketSignature[2];
        Data[3] = PacketVersion;
        // PacketLength, TypeTableOffset, ImportTableOffset, ExportTableOffset
        // will be patched in Finish()
    }

    // Deduplicating add: returns the index of an existing identical 20-byte
    // BlobType entry, or appends a new one. Linear scan is fine — the table
    // holds at most a handful of entries.
    int AddType(const uint8_t* BlobType)
    {
        for (size_t i = 0; i < Types.size(); ++i)
        {
            if (memcmp(Types[i], BlobType, BlobTypeSize) == 0)
            {
                return static_cast<int>(i);
            }
        }
        Types.push_back(BlobType);
        return static_cast<int>(Types.size() - 1);
    }

    // Append an import entry; returns its index for use in CompleteExport.
    int AddImport(int32_t BaseIdx, std::string Fragment)
    {
        Imports.push_back({BaseIdx, std::move(Fragment)});
        return static_cast<int>(Imports.size() - 1);
    }

    // Start a new export: align to 4 bytes and reserve the payload length
    // field, which CompleteExport back-patches.
    void BeginExport()
    {
        AlignTo4(Data);
        CurrentExportStart = Data.size();
        // Reserve space for payload length
        WriteLE32(Data, 0);
    }

    // Write raw payload data into the current export
    void WritePayload(const void* Payload, size_t Size) { WriteBytes(Data, Payload, Size); }

    // Complete the current export: patches payload length, writes type+imports metadata.
    // Must be paired with a preceding BeginExport. Returns the export's index.
    int CompleteExport(const uint8_t* BlobType, const std::vector<int>& ImportIndices)
    {
        const int ExportIndex = static_cast<int>(ExportOffsets.size());

        // Patch payload length (does not include the 4-byte length field itself)
        const size_t  PayloadStart = CurrentExportStart + 4;
        const int32_t PayloadLen   = static_cast<int32_t>(Data.size() - PayloadStart);
        PatchLE32(Data, CurrentExportStart, PayloadLen);

        // Write type index (varint)
        const int TypeIdx = AddType(BlobType);
        WriteVarInt(Data, static_cast<size_t>(TypeIdx));

        // Write import count + indices
        WriteVarInt(Data, ImportIndices.size());
        for (int Idx : ImportIndices)
        {
            WriteVarInt(Data, static_cast<size_t>(Idx));
        }

        // Record export offset (absolute from byte 0)
        ExportOffsets.push_back(static_cast<int32_t>(CurrentExportStart));

        return ExportIndex;
    }

    // Finalize the packet: write type/import/export tables, patch header.
    // Moves Data out; the builder must not be reused afterwards.
    std::vector<uint8_t> Finish()
    {
        AlignTo4(Data);

        // ── Type table: count(int32) + count * BlobTypeSize bytes ──
        const int32_t TypeTableOffset = static_cast<int32_t>(Data.size());
        WriteLE32(Data, static_cast<int32_t>(Types.size()));
        for (const uint8_t* TypeEntry : Types)
        {
            WriteBytes(Data, TypeEntry, BlobTypeSize);
        }

        // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ──
        const int32_t ImportTableOffset = static_cast<int32_t>(Data.size());
        const int32_t ImportCount       = static_cast<int32_t>(Imports.size());
        WriteLE32(Data, ImportCount);

        // Reserve space for (count+1) offset entries — will be patched below
        const size_t ImportOffsetsStart = Data.size();
        for (int32_t i = 0; i <= ImportCount; ++i)
        {
            WriteLE32(Data, 0); // placeholder
        }

        // Write import data and record offsets
        for (int32_t i = 0; i < ImportCount; ++i)
        {
            // Record absolute offset of this import's data
            PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(i) * 4, static_cast<int32_t>(Data.size()));

            ImportEntry& Imp = Imports[static_cast<size_t>(i)];
            // BaseIdx encoded as unsigned VarInt with bias: VarInt(BaseIdx + ImportBias)
            const size_t EncodedBaseIdx = static_cast<size_t>(static_cast<int64_t>(Imp.BaseIdx) + ImportBias);
            WriteVarInt(Data, EncodedBaseIdx);
            // Fragment: raw UTF-8 bytes, NO length prefix (length determined by offset table)
            WriteBytes(Data, Imp.Fragment.data(), Imp.Fragment.size());
        }

        // Sentinel offset (points past the last import's data)
        PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size()));

        // ── Export table: count(int32) + (count+1) offsets(int32 each) ──
        const int32_t ExportTableOffset = static_cast<int32_t>(Data.size());
        const int32_t ExportCount       = static_cast<int32_t>(ExportOffsets.size());
        WriteLE32(Data, ExportCount);

        for (int32_t Off : ExportOffsets)
        {
            WriteLE32(Data, Off);
        }
        // Sentinel: points to the start of the type table (end of export data region)
        WriteLE32(Data, TypeTableOffset);

        // ── Patch header ──
        // PacketLength = total packet size including the 8-byte header
        const int32_t PacketLength = static_cast<int32_t>(Data.size());
        PatchLE32(Data, 4, PacketLength);
        PatchLE32(Data, 8, TypeTableOffset);
        PatchLE32(Data, 12, ImportTableOffset);
        PatchLE32(Data, 16, ExportTableOffset);

        return std::move(Data);
    }
};
+ +static std::vector<uint8_t> +EncodePacket(std::vector<uint8_t> UncompressedPacket) +{ + const int32_t DecompressedLen = static_cast<int32_t>(UncompressedPacket.size()); + // HeaderLength includes the 8-byte outer signature header itself + const int32_t HeaderLength = 8 + 4 + 1 + DecompressedLen; + + std::vector<uint8_t> Encoded; + Encoded.reserve(static_cast<size_t>(HeaderLength)); + + // Outer signature: 'U','B','N', version=5, HeaderLength (LE int32) + WriteByte(Encoded, PacketSignature[0]); // 'U' + WriteByte(Encoded, PacketSignature[1]); // 'B' + WriteByte(Encoded, PacketSignature[2]); // 'N' + WriteByte(Encoded, PacketVersion); // 5 + WriteLE32(Encoded, HeaderLength); + + // Decompressed length + compression format + WriteLE32(Encoded, DecompressedLen); + WriteByte(Encoded, 0); // CompressionFormat::None + + // Packet data + WriteBytes(Encoded, UncompressedPacket.data(), UncompressedPacket.size()); + + return Encoded; +} + +// ─── Bundle blob name generation ──────────────────────────────────────────── + +static std::string +GenerateBlobName() +{ + static std::atomic<uint32_t> s_Counter{0}; + + const int Pid = GetCurrentProcessId(); + + auto Now = std::chrono::steady_clock::now().time_since_epoch(); + auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count(); + + ExtendableStringBuilder<64> Name; + Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1); + return std::string(Name.ToView()); +} + +// ─── File info for bundling ───────────────────────────────────────────────── + +struct FileInfo +{ + std::filesystem::path Path; + std::string Name; // Filename only (for directory entry) + uint64_t FileSize; + IoHash ContentHash; // IoHash of file content + BLAKE3 StreamHash; // Full BLAKE3 for stream hash + int DirectoryExportImportIndex; // Import index referencing this file's root export + IoHash RootExportHash; // IoHash of the root export for this file +}; + +// ─── CreateBundle implementation 
// ────────────────────────────────────────────

// Builds a single Horde storage V2 bundle from the given input files.
// Produces one <blob_name>.blob file containing chunk-leaf exports (files
// split into fixed-size chunks when above LargeFileThreshold), optional
// interior chunk nodes, and one directory export listing all files; plus a
// <first_file>.Bundle.ref file holding the root locator string.
// Missing files marked Optional are skipped; a missing required file fails
// the whole operation. Returns true on success; OutResult receives the
// locator and the output directory.
bool
BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult)
{
	ZEN_TRACE_CPU("BundleCreator::CreateBundle");

	std::error_code Ec;

	// Collect files that exist
	std::vector<FileInfo> ValidFiles;
	for (const BundleFile& F : Files)
	{
		if (!std::filesystem::exists(F.Path, Ec))
		{
			if (F.Optional)
			{
				continue;
			}
			ZEN_ERROR("required bundle file does not exist: {}", F.Path.string());
			return false;
		}
		FileInfo Info;
		Info.Path = F.Path;
		Info.Name = F.Path.filename().string();
		Info.FileSize = std::filesystem::file_size(F.Path, Ec);
		if (Ec)
		{
			ZEN_ERROR("failed to get file size: {}", F.Path.string());
			return false;
		}
		ValidFiles.push_back(std::move(Info));
	}

	if (ValidFiles.empty())
	{
		ZEN_ERROR("no valid files to bundle");
		return false;
	}

	std::filesystem::create_directories(OutputDir, Ec);
	if (Ec)
	{
		ZEN_ERROR("failed to create output directory: {}", OutputDir.string());
		return false;
	}

	const std::string BlobName = GenerateBlobName();
	PacketBuilder     Packet;

	// Process each file: create chunk exports
	for (FileInfo& Info : ValidFiles)
	{
		BasicFile File;
		File.Open(Info.Path, BasicFile::Mode::kRead, Ec);
		if (Ec)
		{
			ZEN_ERROR("failed to open file: {}", Info.Path.string());
			return false;
		}

		// Compute stream hash (full BLAKE3) and content hash (IoHash) while reading
		BLAKE3Stream StreamHasher;
		IoHashStream ContentHasher;

		if (Info.FileSize <= LargeFileThreshold)
		{
			// Small file: single chunk leaf export holding the whole content
			IoBuffer    Content = File.ReadAll();
			const auto* Data = static_cast<const uint8_t*>(Content.GetData());
			const size_t Size = Content.GetSize();

			StreamHasher.Append(Data, Size);
			ContentHasher.Append(Data, Size);

			Packet.BeginExport();
			Packet.WritePayload(Data, Size);

			// For a single-leaf file the root export hash is the leaf's own hash
			const IoHash ChunkHash = IoHash::HashBuffer(Data, Size);
			const int ExportIndex = Packet.CompleteExport(BlobType_ChunkLeafV1, {});
			Info.RootExportHash = ChunkHash;
			Info.ContentHash = ContentHasher.GetHash();
			Info.StreamHash = StreamHasher.GetHash();

			// Add import for this file's root export (references export within same packet)
			ExtendableStringBuilder<32> Fragment;
			Fragment << "exp=" << ExportIndex;
			Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
		}
		else
		{
			// Large file: split into fixed 64KB chunks, then create interior node
			std::vector<int>    ChunkExportIndices;
			std::vector<IoHash> ChunkHashes;

			uint64_t Remaining = Info.FileSize;
			uint64_t Offset = 0;

			while (Remaining > 0)
			{
				// ChunkSize is a file-local constant defined above this view
				const uint64_t ReadSize = std::min(static_cast<uint64_t>(ChunkSize), Remaining);
				IoBuffer       Chunk = File.ReadRange(Offset, ReadSize);
				const auto*    Data = static_cast<const uint8_t*>(Chunk.GetData());
				const size_t   Size = Chunk.GetSize();

				StreamHasher.Append(Data, Size);
				ContentHasher.Append(Data, Size);

				Packet.BeginExport();
				Packet.WritePayload(Data, Size);

				const IoHash ChunkHash = IoHash::HashBuffer(Data, Size);
				const int    ExpIdx = Packet.CompleteExport(BlobType_ChunkLeafV1, {});

				ChunkExportIndices.push_back(ExpIdx);
				ChunkHashes.push_back(ChunkHash);

				Offset += ReadSize;
				Remaining -= ReadSize;
			}

			Info.ContentHash = ContentHasher.GetHash();
			Info.StreamHash = StreamHasher.GetHash();

			// Create interior node referencing all chunk leaves
			// Interior payload: for each child: [IoHash(20)][node_type=1(1)] + imports
			std::vector<int> InteriorImports;
			for (size_t i = 0; i < ChunkExportIndices.size(); ++i)
			{
				ExtendableStringBuilder<32> Fragment;
				Fragment << "exp=" << ChunkExportIndices[i];
				const int ImportIdx = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
				InteriorImports.push_back(ImportIdx);
			}

			Packet.BeginExport();

			// Write interior payload: [hash(20)][type(1)] per child
			for (size_t i = 0; i < ChunkHashes.size(); ++i)
			{
				Packet.WritePayload(ChunkHashes[i].Hash, sizeof(IoHash));
				const uint8_t NodeType = 1; // ChunkNode type
				Packet.WritePayload(&NodeType, 1);
			}

			// Hash the interior payload to get the interior node hash.
			// NOTE(review): this reaches into PacketBuilder internals
			// (Data/CurrentExportStart) and assumes a 4-byte length prefix
			// precedes the export payload — confirm against PacketBuilder.
			const IoHash InteriorHash = IoHash::HashBuffer(Packet.Data.data() + (Packet.CurrentExportStart + 4),
														   Packet.Data.size() - (Packet.CurrentExportStart + 4));

			const int InteriorExportIndex = Packet.CompleteExport(BlobType_ChunkInteriorV2, InteriorImports);

			Info.RootExportHash = InteriorHash;

			// Add import for directory to reference this interior node
			ExtendableStringBuilder<32> Fragment;
			Fragment << "exp=" << InteriorExportIndex;
			Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
		}
	}

	// Create directory node export
	// Payload: [flags(varint=0)] [file_count(varint)] [file_entries...] [dir_count(varint=0)]
	// FileEntry: [import(varint)] [IoHash(20)] [name(string)] [flags(varint)] [length(varint)] [IoHash_stream(20)]

	Packet.BeginExport();

	// Build directory payload into a temporary buffer, then write it
	std::vector<uint8_t> DirPayload;
	WriteVarInt(DirPayload, 0);                 // flags
	WriteVarInt(DirPayload, ValidFiles.size()); // file_count

	std::vector<int> DirImports;
	for (size_t i = 0; i < ValidFiles.size(); ++i)
	{
		FileInfo& Info = ValidFiles[i];
		DirImports.push_back(Info.DirectoryExportImportIndex);

		// IoHash of target (20 bytes) — import is consumed sequentially from the
		// export's import list by ReadBlobRef, not encoded in the payload
		WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash));
		// name (string)
		WriteString(DirPayload, Info.Name);
		// flags (varint): 1 = Executable
		// NOTE(review): every bundled file is flagged executable here —
		// confirm this is intended for non-executable payloads as well.
		WriteVarInt(DirPayload, 1);
		// length (varint)
		WriteVarInt(DirPayload, static_cast<size_t>(Info.FileSize));
		// stream hash: IoHash from full BLAKE3, truncated to 20 bytes
		const IoHash StreamIoHash = IoHash::FromBLAKE3(Info.StreamHash);
		WriteBytes(DirPayload, StreamIoHash.Hash, sizeof(IoHash));
	}

	WriteVarInt(DirPayload, 0); // dir_count

	Packet.WritePayload(DirPayload.data(), DirPayload.size());
	const int DirExportIndex = Packet.CompleteExport(BlobType_DirectoryV1, DirImports);

	// Finalize packet and encode (outer 'UBN' framing, uncompressed)
	std::vector<uint8_t> UncompressedPacket = Packet.Finish();
	std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket));

	// Write .blob file
	const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob");
	{
		BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec);
		if (Ec)
		{
			ZEN_ERROR("failed to create blob file: {}", BlobFilePath.string());
			return false;
		}
		BlobFile.Write(EncodedPacket.data(), EncodedPacket.size(), 0);
	}

	// Build locator: <blob_name>#pkt=0,<encoded_len>&exp=<dir_export_index>
	ExtendableStringBuilder<256> Locator;
	Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex;
	const std::string LocatorStr(Locator.ToView());

	// Write .ref file (use first file's name as the ref base)
	const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref");
	{
		BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec);
		if (Ec)
		{
			ZEN_ERROR("failed to create ref file: {}", RefFilePath.string());
			return false;
		}
		RefFile.Write(LocatorStr.data(), LocatorStr.size(), 0);
	}

	OutResult.Locator = LocatorStr;
	OutResult.BundleDir = OutputDir;

	ZEN_INFO("created V2 bundle: blob={}.blob locator={} files={}", BlobName, LocatorStr, ValidFiles.size());
	return true;
}

// Reads a locator string back from a .ref file written by CreateBundle.
bool
BundleCreator::ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator)
{
	BasicFile       File;
	std::error_code Ec;
	File.Open(RefFile, BasicFile::Mode::kRead, Ec);
	if (Ec)
	{
		return false;
	}

	IoBuffer Content = File.ReadAll();
	OutLocator.assign(static_cast<const
char*>(Content.GetData()), Content.GetSize()); + + // Strip trailing whitespace/newlines + while (!OutLocator.empty() && (OutLocator.back() == '\n' || OutLocator.back() == '\r' || OutLocator.back() == '\0')) + { + OutLocator.pop_back(); + } + + return !OutLocator.empty(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordebundle.h b/src/zenhorde/hordebundle.h new file mode 100644 index 000000000..052f60435 --- /dev/null +++ b/src/zenhorde/hordebundle.h @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> +#include <string> +#include <vector> + +namespace zen::horde { + +/** Describes a file to include in a Horde bundle. */ +struct BundleFile +{ + std::filesystem::path Path; ///< Local file path + bool Optional; ///< If true, skip without error if missing +}; + +/** Result of a successful bundle creation. */ +struct BundleResult +{ + std::string Locator; ///< Root directory locator for WriteFiles + std::filesystem::path BundleDir; ///< Directory containing .blob files +}; + +/** Creates Horde V2 bundles from local files for upload to remote agents. + * + * Produces a proper Horde storage V2 bundle containing: + * - Chunk leaf exports for file data (split into 64KB chunks for large files) + * - Optional interior chunk nodes referencing leaf chunks + * - A directory node listing all bundled files with metadata + * + * The bundle is written as a single .blob file with a corresponding .ref file + * containing the locator string. The locator format is: + * <blob_name>#pkt=0,<encoded_len>&exp=<directory_export_index> + */ +struct BundleCreator +{ + /** Create a V2 bundle from one or more input files. + * @param Files Files to include in the bundle. + * @param OutputDir Directory where .blob and .ref files will be written. + * @param OutResult Receives the locator and output directory on success. + * @return True on success. 
 */
	static bool CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult);

	/** Read a locator string from a .ref file. Strips trailing whitespace/newlines. */
	static bool ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator);
};

} // namespace zen::horde
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp
new file mode 100644
index 000000000..fb981f0ba
--- /dev/null
+++ b/src/zenhorde/hordeclient.cpp
@@ -0,0 +1,382 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/fmtutils.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/memoryview.h>
#include <zencore/trace.h>
#include <zenhorde/hordeclient.h>
#include <zenhttp/httpclient.h>

ZEN_THIRD_PARTY_INCLUDES_START
#include <json11.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

namespace zen::horde {

// Stores the configuration by value; the client owns its own log category.
HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client"))
{
}

HordeClient::~HordeClient() = default;

// Creates the HTTP client used for all Horde REST calls and, when an auth
// token is configured, installs it as a bearer token and performs an initial
// authentication round-trip. Returns false if that authentication fails.
bool
HordeClient::Initialize()
{
	ZEN_TRACE_CPU("HordeClient::Initialize");

	HttpClientSettings Settings;
	Settings.LogCategory = "horde.http";
	// Generous timeouts: compute requests can be slow while agents are provisioned
	Settings.ConnectTimeout = std::chrono::milliseconds{10000};
	Settings.Timeout = std::chrono::milliseconds{60000};
	Settings.RetryCount = 1;
	// 503/429 are expected during resource contention and should not be logged as hard errors
	Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests};

	if (!m_Config.AuthToken.empty())
	{
		// Static token provider: the configured token is reported as valid for
		// 24h. NOTE(review): the token itself is never refreshed — confirm the
		// caller re-initializes when tokens rotate.
		Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken {
			HttpClientAccessToken Token;
			Token.Value = token;
			Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24};
			return Token;
		};
	}

	m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings);

	if (!m_Config.AuthToken.empty())
	{
		if (!m_Http->Authenticate())
		{
			ZEN_WARN("failed to authenticate with Horde server");
			return false;
		}
	}

	return true;
}

// Serializes the compute resource request as JSON: requirements (pool,
// platform condition, exclusivity) and connection preferences (mode,
// encryption, forwarded ports).
std::string
HordeClient::BuildRequestBody() const
{
	json11::Json::object Requirements;

	if (m_Config.Mode == ConnectionMode::Direct && !m_Config.Pool.empty())
	{
		Requirements["pool"] = m_Config.Pool;
	}

	// Platform condition is selected at compile time for the requesting host
	std::string Condition;
#if ZEN_PLATFORM_WINDOWS
	ExtendableStringBuilder<256> CondBuf;
	CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')";
	Condition = std::string(CondBuf);
#elif ZEN_PLATFORM_MAC
	Condition = "OSFamily == 'MacOS'";
#else
	Condition = "OSFamily == 'Linux'";
#endif

	if (!m_Config.Condition.empty())
	{
		// NOTE(review): the user-supplied condition is appended with only a
		// space separator (no '&&') — confirm the Horde condition grammar
		// treats adjacent clauses as a conjunction.
		Condition += " ";
		Condition += m_Config.Condition;
	}

	Requirements["condition"] = Condition;
	Requirements["exclusive"] = true;

	json11::Json::object Connection;
	Connection["modePreference"] = ToString(m_Config.Mode);

	if (m_Config.EncryptionMode != Encryption::None)
	{
		Connection["encryption"] = ToString(m_Config.EncryptionMode);
	}

	// Request configured zen service port to be forwarded. The Horde agent will map this
	// to a local port on the provisioned machine and report it back in the response.
	json11::Json::object PortsObj;
	PortsObj["ZenPort"] = json11::Json(m_Config.ZenServicePort);
	Connection["ports"] = PortsObj;

	json11::Json::object Root;
	Root["requirements"] = Requirements;
	Root["connection"] = Connection;

	return json11::Json(Root).dump();
}

// Asks the Horde server which compute cluster should serve this request by
// POSTing the request body to api/v2/compute/_cluster. On success fills
// OutCluster.ClusterId. Returns false on transport errors, on 503/429
// (no resources, logged at debug only), on 401 (expired token), on any
// other non-success status, or when the response JSON lacks 'clusterId'.
bool
HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster)
{
	ZEN_TRACE_CPU("HordeClient::ResolveCluster");

	const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON);

	const HttpClient::Response Response = m_Http->Post("api/v2/compute/_cluster", Payload);

	if (Response.Error)
	{
		ZEN_WARN("cluster resolution failed: {}", Response.Error->ErrorMessage);
		return false;
	}

	const int StatusCode = static_cast<int>(Response.StatusCode);

	// 503/429 are the server's "no capacity right now" answers — expected, debug-level only
	if (StatusCode == 503 || StatusCode == 429)
	{
		ZEN_DEBUG("cluster resolution returned HTTP/{}: no resources", StatusCode);
		return false;
	}

	if (StatusCode == 401)
	{
		ZEN_WARN("cluster resolution returned HTTP/401: token expired");
		return false;
	}

	if (!Response.IsSuccess())
	{
		ZEN_WARN("cluster resolution failed with HTTP/{}", StatusCode);
		return false;
	}

	const std::string Body(Response.AsText());
	std::string       Err;
	const json11::Json Json = json11::Json::parse(Body, Err);

	if (!Err.empty())
	{
		ZEN_WARN("invalid JSON response for cluster resolution: {}", Err);
		return false;
	}

	const json11::Json ClusterIdVal = Json["clusterId"];
	if (!ClusterIdVal.is_string() || ClusterIdVal.string_value().empty())
	{
		ZEN_WARN("missing 'clusterId' in cluster resolution response");
		return false;
	}

	OutCluster.ClusterId = ClusterIdVal.string_value();
	return true;
}

// Decodes a fixed-size hex string (exactly OutSize*2 chars, upper or lower
// case) into Out. Returns false on wrong length or any non-hex character;
// Out may be partially written in that case.
bool
HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize)
{
	if (Hex.size() != OutSize * 2)
	{
		return false;
	}

	for (size_t i = 0; i < OutSize; ++i)
	{
		// Maps one hex digit to its value, or -1 when the char is not hex
		auto HexToByte = [](char c) -> int {
			if
(c >= '0' && c <= '9')
				return c - '0';
			if (c >= 'a' && c <= 'f')
				return c - 'a' + 10;
			if (c >= 'A' && c <= 'F')
				return c - 'A' + 10;
			return -1;
		};

		const int Hi = HexToByte(Hex[i * 2]);
		const int Lo = HexToByte(Hex[i * 2 + 1]);
		if (Hi < 0 || Lo < 0)
		{
			return false;
		}
		Out[i] = static_cast<uint8_t>((Hi << 4) | Lo);
	}

	return true;
}

// Requests a compute machine lease from the given cluster (or "default")
// and parses the assignment response into OutMachine: endpoint (ip/port,
// hex nonce), forwarded ports, connection mode/address, machine properties
// (OS family, core counts), optional AES key, and lease id. Returns false
// on transport/JSON errors, 404/503/429 (no resources), 401, or when the
// required nonce/ip/port fields are missing or malformed.
bool
HordeClient::RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine)
{
	ZEN_TRACE_CPU("HordeClient::RequestMachine");

	ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str());

	ExtendableStringBuilder<128> ResourcePath;
	ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str());

	const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON);
	const HttpClient::Response Response = m_Http->Post(ResourcePath.ToView(), Payload);

	// Reset output to invalid state (0xFFFF marks "no port assigned")
	OutMachine = {};
	OutMachine.Port = 0xFFFF;

	if (Response.Error)
	{
		ZEN_WARN("machine request failed: {}", Response.Error->ErrorMessage);
		return false;
	}

	const int StatusCode = static_cast<int>(Response.StatusCode);

	// 404 additionally covers an unknown cluster id; all three mean "nothing available"
	if (StatusCode == 404 || StatusCode == 503 || StatusCode == 429)
	{
		ZEN_DEBUG("machine request returned HTTP/{}: no resources", StatusCode);
		return false;
	}

	if (StatusCode == 401)
	{
		ZEN_WARN("machine request returned HTTP/401: token expired");
		return false;
	}

	if (!Response.IsSuccess())
	{
		ZEN_WARN("machine request failed with HTTP/{}", StatusCode);
		return false;
	}

	const std::string Body(Response.AsText());
	std::string       Err;
	const json11::Json Json = json11::Json::parse(Body, Err);

	if (!Err.empty())
	{
		ZEN_WARN("invalid JSON response for machine request: {}", Err);
		return false;
	}

	// Required fields
	const json11::Json NonceVal = Json["nonce"];
	const json11::Json IpVal = Json["ip"];
	const json11::Json PortVal = Json["port"];

	if (!NonceVal.is_string() || !IpVal.is_string() || !PortVal.is_number())
	{
		ZEN_WARN("missing 'nonce', 'ip', or 'port' in machine response");
		return false;
	}

	OutMachine.Ip = IpVal.string_value();
	OutMachine.Port = static_cast<uint16_t>(PortVal.int_value());

	// Nonce arrives hex-encoded; NonceSize is the fixed binary length
	if (!ParseHexBytes(NonceVal.string_value(), OutMachine.Nonce, NonceSize))
	{
		ZEN_WARN("invalid nonce hex string in machine response");
		return false;
	}

	// Optional map of forwarded ports: name -> {port, agentPort}
	if (const json11::Json PortsVal = Json["ports"]; PortsVal.is_object())
	{
		for (const auto& [Key, Val] : PortsVal.object_items())
		{
			PortInfo Info;
			if (Val["port"].is_number())
			{
				Info.Port = static_cast<uint16_t>(Val["port"].int_value());
			}
			if (Val["agentPort"].is_number())
			{
				Info.AgentPort = static_cast<uint16_t>(Val["agentPort"].int_value());
			}
			OutMachine.Ports[Key] = Info;
		}
	}

	// connectionAddress is only honored when the connection mode parses
	if (const json11::Json ConnectionModeVal = Json["connectionMode"]; ConnectionModeVal.is_string())
	{
		if (FromString(OutMachine.Mode, ConnectionModeVal.string_value()))
		{
			if (const json11::Json ConnectionAddressVal = Json["connectionAddress"]; ConnectionAddressVal.is_string())
			{
				OutMachine.ConnectionAddress = ConnectionAddressVal.string_value();
			}
		}
	}

	// Properties are a flat string array of "Key=Value" pairs describing the machine.
	// We extract OS family and core counts for sizing decisions. If neither core count
	// is available, we fall back to 16 as a conservative default.
	uint16_t LogicalCores = 0;
	uint16_t PhysicalCores = 0;

	if (const json11::Json PropertiesVal = Json["properties"]; PropertiesVal.is_array())
	{
		for (const json11::Json& PropVal : PropertiesVal.array_items())
		{
			if (!PropVal.is_string())
			{
				continue;
			}

			const std::string Prop = PropVal.string_value();
			if (Prop.starts_with("OSFamily="))
			{
				// Offsets 9/13/14 below are the lengths of the matched prefixes
				if (Prop.substr(9) == "Windows")
				{
					OutMachine.IsWindows = true;
				}
			}
			else if (Prop.starts_with("LogicalCores="))
			{
				LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13));
			}
			else if (Prop.starts_with("PhysicalCores="))
			{
				PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14));
			}
		}
	}

	if (LogicalCores > 0)
	{
		OutMachine.LogicalCores = LogicalCores;
	}
	else if (PhysicalCores > 0)
	{
		// Assume SMT/hyper-threading (2 logical per physical) when only
		// physical cores are reported
		OutMachine.LogicalCores = PhysicalCores * 2;
	}
	else
	{
		OutMachine.LogicalCores = 16;
	}

	// AES key is optional and failure to parse it is non-fatal here;
	// the connection attempt will fail later if the key is actually needed
	if (const json11::Json EncryptionVal = Json["encryption"]; EncryptionVal.is_string())
	{
		if (FromString(OutMachine.EncryptionMode, EncryptionVal.string_value()))
		{
			if (OutMachine.EncryptionMode == Encryption::AES)
			{
				const json11::Json KeyVal = Json["key"];
				if (KeyVal.is_string() && !KeyVal.string_value().empty())
				{
					if (!ParseHexBytes(KeyVal.string_value(), OutMachine.Key, KeySize))
					{
						ZEN_WARN("invalid AES key in machine response");
					}
				}
				else
				{
					ZEN_WARN("AES encryption requested but no key provided");
				}
			}
		}
	}

	if (const json11::Json LeaseIdVal = Json["leaseId"]; LeaseIdVal.is_string())
	{
		OutMachine.LeaseId = LeaseIdVal.string_value();
	}

	ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}",
			 OutMachine.GetConnectionAddress(),
			 OutMachine.GetConnectionPort(),
			 OutMachine.LogicalCores,
			 OutMachine.LeaseId);

	return true;
}

} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp
new file mode 100644
index 000000000..0d032b5d5
--- /dev/null
+++
b/src/zenhorde/hordecomputebuffer.cpp
@@ -0,0 +1,454 @@
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordecomputebuffer.h"

#include <algorithm>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstring>

namespace zen::horde {

// Simplified ring buffer implementation for in-process use only.
// Uses a single contiguous buffer with write/read cursors and
// mutex+condvar for synchronization. This is simpler than the UE version
// which uses lock-free atomics and shared memory, but sufficient for our
// use case where we're the initiator side of the compute protocol.

// Shared, ref-counted state for one buffer. All fields are guarded by
// Mutex; readers/writers coordinate via ReadCV/WriteCV.
struct ComputeBuffer::Detail : TRefCounted<Detail>
{
	// Backing storage: NumChunks contiguous chunks of ChunkLength bytes each
	std::vector<uint8_t> Data;
	size_t NumChunks = 0;
	size_t ChunkLength = 0;

	// Current write state
	size_t WriteChunkIdx = 0;
	size_t WriteOffset = 0;
	bool WriteComplete = false; // Writer called MarkComplete()/Close(); no more data will arrive

	// Current read state
	size_t ReadChunkIdx = 0;
	size_t ReadOffset = 0;
	bool Detached = false; // Reader detached; pending waits should abort

	// Per-chunk written length
	std::vector<size_t> ChunkWrittenLength;
	std::vector<bool> ChunkFinished; // Writer moved to next chunk

	std::mutex Mutex;
	std::condition_variable ReadCV;  ///< Signaled when new data is written or stream completes
	std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space

	bool HasWriter = false;
	bool HasReader = false;

	// Base address of the given chunk within the contiguous backing store
	uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
	const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
};

// ComputeBuffer

ComputeBuffer::ComputeBuffer()
{
}
ComputeBuffer::~ComputeBuffer()
{
}

// Allocates the shared Detail state and backing storage. Always succeeds
// (allocation failure throws); the bool return mirrors the UE API shape.
bool
ComputeBuffer::CreateNew(const Params& InParams)
{
	auto* NewDetail = new Detail();
	NewDetail->NumChunks = InParams.NumChunks;
	NewDetail->ChunkLength = InParams.ChunkLength;
	NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
	NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
	NewDetail->ChunkFinished.resize(InParams.NumChunks, false);

	// Ownership transfers to the Ref; readers/writers share this state
	m_Detail = NewDetail;
	return true;
}

void
ComputeBuffer::Close()
{
	m_Detail = nullptr;
}

bool
ComputeBuffer::IsValid() const
{
	return static_cast<bool>(m_Detail);
}

// Hands out an additional reference to the shared state; only a single
// concurrent reader is supported by the locking scheme.
ComputeBufferReader
ComputeBuffer::CreateReader()
{
	assert(m_Detail);
	m_Detail->HasReader = true;
	return ComputeBufferReader(m_Detail);
}

// As above, for the single writer endpoint.
ComputeBufferWriter
ComputeBuffer::CreateWriter()
{
	assert(m_Detail);
	m_Detail->HasWriter = true;
	return ComputeBufferWriter(m_Detail);
}

// ComputeBufferReader

ComputeBufferReader::ComputeBufferReader()
{
}
ComputeBufferReader::~ComputeBufferReader()
{
}

ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default;
ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default;
ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default;
ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;

ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
{
}

void
ComputeBufferReader::Close()
{
	m_Detail = nullptr;
}

// Marks the shared state as detached and wakes any blocked readers so
// pending WaitToRead calls return nullptr.
void
ComputeBufferReader::Detach()
{
	if (m_Detail)
	{
		std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
		m_Detail->Detached = true;
		m_Detail->ReadCV.notify_all();
	}
}

bool
ComputeBufferReader::IsValid() const
{
	return static_cast<bool>(m_Detail);
}

// True once the writer has completed AND the reader has consumed everything
// up to the writer's chunk, or the buffer was detached/never created.
bool
ComputeBufferReader::IsComplete() const
{
	if (!m_Detail)
	{
		return true;
	}
	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	if (m_Detail->Detached)
	{
		return true;
	}
	return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
		   m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
}

// Consumes Size bytes from the current chunk; advances to the next chunk
// when the writer has finished this one and everything in it was read.
void
ComputeBufferReader::AdvanceReadPosition(size_t Size)
{
	if (!m_Detail)
	{
		return;
	}

std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + + m_Detail->ReadOffset += Size; + + // Check if we need to move to next chunk + const size_t ReadChunk = m_Detail->ReadChunkIdx; + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + } + + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferReader::GetMaxReadSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t ReadChunk = m_Detail->ReadChunkIdx; + return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; +} + +const uint8_t* +ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock<std::mutex> Lock(m_Detail->Mutex); + + auto Predicate = [&]() -> bool { + if (m_Detail->Detached) + { + return true; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available >= MinSize) + { + return true; + } + + // If chunk is finished and we've read everything, try to move to next + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + // Move to next chunk + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + return false; // Re-check with new chunk + } + + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + + return false; + }; + + if (TimeoutMs < 0) + { + m_Detail->ReadCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->ReadCV.wait_for(Lock, 
std::chrono::milliseconds(TimeoutMs), Predicate)) + { + if (OutTimedOut) + { + *OutTimedOut = true; + } + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available < MinSize) + { + return nullptr; // End of stream + } + + return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; +} + +size_t +ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) +{ + const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); + if (!Data) + { + return 0; + } + + const size_t Available = GetMaxReadSize(); + const size_t ToCopy = std::min(Available, MaxSize); + memcpy(Buffer, Data, ToCopy); + AdvanceReadPosition(ToCopy); + return ToCopy; +} + +// ComputeBufferWriter + +ComputeBufferWriter::ComputeBufferWriter() = default; +ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; +ComputeBufferWriter::~ComputeBufferWriter() = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; + +ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferWriter::Close() +{ + if (m_Detail) + { + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + if (!m_Detail->WriteComplete) + { + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } + } + m_Detail = nullptr; + } +} + +bool +ComputeBufferWriter::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +void +ComputeBufferWriter::MarkComplete() +{ + if (m_Detail) + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + m_Detail->WriteComplete = true; + 
		m_Detail->ReadCV.notify_all();
	}
}

// Publishes Size bytes previously filled in via WaitToWrite's pointer:
// extends the current chunk's written length and wakes waiting readers.
void
ComputeBufferWriter::AdvanceWritePosition(size_t Size)
{
	if (!m_Detail || Size == 0)
	{
		return;
	}

	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	const size_t WriteChunk = m_Detail->WriteChunkIdx;
	m_Detail->ChunkWrittenLength[WriteChunk] += Size;
	m_Detail->WriteOffset += Size;
	m_Detail->ReadCV.notify_all();
}

// Bytes of free space remaining in the writer's current chunk.
size_t
ComputeBufferWriter::GetMaxWriteSize() const
{
	if (!m_Detail)
	{
		return 0;
	}
	std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
	const size_t WriteChunk = m_Detail->WriteChunkIdx;
	return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
}

// Fixed chunk capacity (an upper bound on any single zero-copy write).
size_t
ComputeBufferWriter::GetChunkMaxLength() const
{
	if (!m_Detail)
	{
		return 0;
	}
	return m_Detail->ChunkLength;
}

// Convenience copy-in write: blocks for at least one byte of space, copies up
// to MaxSize bytes and publishes them. Returns 0 on timeout or completed stream.
// Note: may copy fewer than MaxSize bytes (bounded by the space left in the
// current chunk); callers loop if they need the full amount written.
size_t
ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
{
	uint8_t* Dest = WaitToWrite(1, TimeoutMs);
	if (!Dest)
	{
		return 0;
	}

	const size_t Available = GetMaxWriteSize();
	const size_t ToCopy = std::min(Available, MaxSize);
	memcpy(Dest, Buffer, ToCopy);
	AdvanceWritePosition(ToCopy);
	return ToCopy;
}

// Blocks until at least MinSize bytes of space are available and returns a
// zero-copy pointer to fill; nullptr on timeout, detach, or completed stream.
uint8_t*
ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
{
	if (!m_Detail)
	{
		return nullptr;
	}

	std::unique_lock<std::mutex> Lock(m_Detail->Mutex);

	if (m_Detail->WriteComplete)
	{
		return nullptr;
	}

	const size_t WriteChunk = m_Detail->WriteChunkIdx;
	const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];

	// If current chunk has enough space, return pointer
	if (Available >= MinSize)
	{
		return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
	}

	// Current chunk is full - mark it as finished and move to next.
	// The writer cannot advance until the reader has fully consumed the next chunk,
	// preventing the writer from overwriting data the reader hasn't processed yet.
+ */ +class ComputeBuffer +{ +public: + struct Params + { + size_t NumChunks = 2; + size_t ChunkLength = 512 * 1024; + }; + + ComputeBuffer(); + ~ComputeBuffer(); + + ComputeBuffer(const ComputeBuffer&) = delete; + ComputeBuffer& operator=(const ComputeBuffer&) = delete; + + bool CreateNew(const Params& InParams); + void Close(); + + bool IsValid() const; + + ComputeBufferReader CreateReader(); + ComputeBufferWriter CreateWriter(); + +private: + struct Detail; + Ref<Detail> m_Detail; + + friend class ComputeBufferReader; + friend class ComputeBufferWriter; +}; + +/** Read endpoint for a ComputeBuffer. + * + * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer + * directly into the buffer memory (zero-copy); the caller must call + * AdvanceReadPosition() after consuming the data. + */ +class ComputeBufferReader +{ +public: + ComputeBufferReader(); + ComputeBufferReader(const ComputeBufferReader&); + ComputeBufferReader(ComputeBufferReader&&) noexcept; + ~ComputeBufferReader(); + + ComputeBufferReader& operator=(const ComputeBufferReader&); + ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; + + void Close(); + void Detach(); + bool IsValid() const; + bool IsComplete() const; + + void AdvanceReadPosition(size_t Size); + size_t GetMaxReadSize() const; + + /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ + size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); + + /** Wait until at least MinSize bytes are available and return a direct pointer. + * Returns nullptr on timeout or if the writer has completed. */ + const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); + +private: + friend class ComputeBuffer; + explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail); + + Ref<ComputeBuffer::Detail> m_Detail; +}; + +/** Write endpoint for a ComputeBuffer. + * + * Provides blocking writes into the ring buffer. 
 * WaitToWrite() returns a pointer
 * directly into the buffer memory (zero-copy); the caller must call
 * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal
 * that no more data will be written.
 */
class ComputeBufferWriter
{
public:
    ComputeBufferWriter();
    ComputeBufferWriter(const ComputeBufferWriter&);
    ComputeBufferWriter(ComputeBufferWriter&&) noexcept;
    ~ComputeBufferWriter();

    ComputeBufferWriter& operator=(const ComputeBufferWriter&);
    ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept;

    void Close();
    bool IsValid() const;

    /** Signal that no more data will be written. Unblocks any waiting readers. */
    void MarkComplete();

    /** Commit Size bytes previously filled via WaitToWrite(). */
    void AdvanceWritePosition(size_t Size);

    /** Number of bytes currently writable without blocking. */
    size_t GetMaxWriteSize() const;

    // NOTE(review): presumably the configured Params::ChunkLength, i.e. the
    // upper bound for a single zero-copy write — confirm against Detail.
    size_t GetChunkMaxLength() const;

    /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */
    size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1);

    /** Wait until at least MinSize bytes of write space are available and return a direct pointer.
     * Returns nullptr on timeout. */
    uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1);

private:
    friend class ComputeBuffer;
    explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail);

    Ref<ComputeBuffer::Detail> m_Detail;
};

} // namespace zen::horde

// ----- file: src/zenhorde/hordecomputechannel.cpp -----
// Copyright Epic Games, Inc. All Rights Reserved.
+ +#include "hordecomputechannel.h" + +namespace zen::horde { + +ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) +: Reader(std::move(InReader)) +, Writer(std::move(InWriter)) +{ +} + +bool +ComputeChannel::IsValid() const +{ + return Reader.IsValid() && Writer.IsValid(); +} + +size_t +ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) +{ + return Writer.Write(Data, Size, TimeoutMs); +} + +size_t +ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) +{ + return Reader.Read(Data, Size, TimeoutMs); +} + +void +ComputeChannel::MarkComplete() +{ + Writer.MarkComplete(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h new file mode 100644 index 000000000..c1dff20e4 --- /dev/null +++ b/src/zenhorde/hordecomputechannel.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" + +namespace zen::horde { + +/** Bidirectional communication channel using a pair of compute buffers. + * + * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter + * (for sending data). Used by ComputeSocket to represent one logical channel + * within a multiplexed connection. + */ +class ComputeChannel : public TRefCounted<ComputeChannel> +{ +public: + ComputeBufferReader Reader; + ComputeBufferWriter Writer; + + ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); + + bool IsValid() const; + + size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); + size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); + + /** Signal that no more data will be sent on this channel. 
*/ + void MarkComplete(); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp new file mode 100644 index 000000000..6ef67760c --- /dev/null +++ b/src/zenhorde/hordecomputesocket.cpp @@ -0,0 +1,204 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputesocket.h" + +#include <zencore/logging.h> + +namespace zen::horde { + +ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) +: m_Log(zen::logging::Get("horde.socket")) +, m_Transport(std::move(Transport)) +{ +} + +ComputeSocket::~ComputeSocket() +{ + // Shutdown order matters: first stop the ping thread, then unblock send threads + // by detaching readers, then join send threads, and finally close the transport + // to unblock the recv thread (which is blocked on RecvMessage). + { + std::lock_guard<std::mutex> Lock(m_PingMutex); + m_PingShouldStop = true; + m_PingCV.notify_all(); + } + + for (auto& Reader : m_Readers) + { + Reader.Detach(); + } + + for (auto& [Id, Thread] : m_SendThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Transport->Close(); + + if (m_RecvThread.joinable()) + { + m_RecvThread.join(); + } + if (m_PingThread.joinable()) + { + m_PingThread.join(); + } +} + +Ref<ComputeChannel> +ComputeSocket::CreateChannel(int ChannelId) +{ + ComputeBuffer::Params Params; + + ComputeBuffer RecvBuffer; + if (!RecvBuffer.CreateNew(Params)) + { + return {}; + } + + ComputeBuffer SendBuffer; + if (!SendBuffer.CreateNew(Params)) + { + return {}; + } + + Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); + + // Attach recv buffer writer (transport recv thread writes into this) + { + std::lock_guard<std::mutex> Lock(m_WritersMutex); + m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); + } + + // Attach send buffer reader (send thread reads from this) + { + ComputeBufferReader Reader = SendBuffer.CreateReader(); + 
m_Readers.push_back(Reader); + m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); + } + + return Channel; +} + +void +ComputeSocket::StartCommunication() +{ + m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); + m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); +} + +void +ComputeSocket::PingThreadProc() +{ + while (true) + { + { + std::unique_lock<std::mutex> Lock(m_PingMutex); + if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + { + break; + } + } + + std::lock_guard<std::mutex> Lock(m_SendMutex); + FrameHeader Header; + Header.Channel = 0; + Header.Size = ControlPing; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +void +ComputeSocket::RecvThreadProc() +{ + // Writers are cached locally to avoid taking m_WritersMutex on every frame. + // The shared m_Writers map is only accessed when a channel is seen for the first time. + std::unordered_map<int, ComputeBufferWriter> CachedWriters; + + FrameHeader Header; + while (m_Transport->RecvMessage(&Header, sizeof(Header))) + { + if (Header.Size >= 0) + { + // Data frame + auto It = CachedWriters.find(Header.Channel); + if (It == CachedWriters.end()) + { + std::lock_guard<std::mutex> Lock(m_WritersMutex); + auto WIt = m_Writers.find(Header.Channel); + if (WIt == m_Writers.end()) + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + // Skip the data + std::vector<uint8_t> Discard(Header.Size); + m_Transport->RecvMessage(Discard.data(), Header.Size); + continue; + } + It = CachedWriters.emplace(Header.Channel, WIt->second).first; + } + + ComputeBufferWriter& Writer = It->second; + uint8_t* Dest = Writer.WaitToWrite(Header.Size); + if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) + { + ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); + return; + } + Writer.AdvanceWritePosition(Header.Size); + } + else 
if (Header.Size == ControlDetach) + { + // Detach the recv buffer for this channel + CachedWriters.erase(Header.Channel); + + std::lock_guard<std::mutex> Lock(m_WritersMutex); + auto It = m_Writers.find(Header.Channel); + if (It != m_Writers.end()) + { + It->second.MarkComplete(); + m_Writers.erase(It); + } + } + else if (Header.Size == ControlPing) + { + // Ping response - ignore + } + else + { + ZEN_WARN("invalid frame header size: {}", Header.Size); + return; + } + } +} + +void +ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +{ + // Each channel has its own send thread. All send threads share m_SendMutex + // to serialize writes to the transport, since TCP requires atomic frame writes. + FrameHeader Header; + Header.Channel = Channel; + + const uint8_t* Data; + while ((Data = Reader.WaitToRead(1)) != nullptr) + { + std::lock_guard<std::mutex> Lock(m_SendMutex); + + Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); + m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->SendMessage(Data, Header.Size); + Reader.AdvanceReadPosition(Header.Size); + } + + if (Reader.IsComplete()) + { + std::lock_guard<std::mutex> Lock(m_SendMutex); + Header.Size = ControlDetach; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h new file mode 100644 index 000000000..0c3cb4195 --- /dev/null +++ b/src/zenhorde/hordecomputesocket.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" +#include "hordecomputechannel.h" +#include "hordetransport.h" + +#include <zencore/logbase.h> + +#include <condition_variable> +#include <memory> +#include <mutex> +#include <thread> +#include <unordered_map> +#include <vector> + +namespace zen::horde { + +/** Multiplexed socket that routes data between multiple channels over a single transport. 
 *
 * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers.
 * A recv thread demultiplexes incoming frames to channel-specific buffers, while
 * per-channel send threads multiplex outgoing data onto the shared transport.
 *
 * Wire format per frame: [channelId (4B)][size (4B)][data]
 * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping.
 */
class ComputeSocket
{
public:
    explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport);
    ~ComputeSocket();

    // Owns threads and the transport; copying would be meaningless.
    ComputeSocket(const ComputeSocket&) = delete;
    ComputeSocket& operator=(const ComputeSocket&) = delete;

    /** Create a channel with the given ID.
     * Allocates anonymous in-process buffers and spawns a send thread for the channel. */
    Ref<ComputeChannel> CreateChannel(int ChannelId);

    /** Start the recv pump and ping threads. Must be called after all channels are created. */
    void StartCommunication();

private:
    /** On-the-wire frame header; control frames carry a negative Size. */
    struct FrameHeader
    {
        int32_t Channel = 0;
        int32_t Size = 0;
    };

    static constexpr int32_t ControlDetach = -2; // remote closed the channel
    static constexpr int32_t ControlPing = -3;   // keep-alive, no payload

    LoggerRef Log() { return m_Log; }

    void RecvThreadProc();
    void SendThreadProc(int Channel, ComputeBufferReader Reader);
    void PingThreadProc();

    LoggerRef m_Log;
    std::unique_ptr<ComputeTransport> m_Transport;
    std::mutex m_SendMutex; ///< Serializes writes to the transport

    std::mutex m_WritersMutex;
    std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID

    std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction
    std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel

    std::thread m_RecvThread;
    std::thread m_PingThread;

    // Ping thread shutdown signalling (flag guarded by m_PingMutex).
    bool m_PingShouldStop = false;
    std::mutex m_PingMutex;
    std::condition_variable m_PingCV;
};

} // namespace zen::horde
file mode 100644 index 000000000..2dca228d9 --- /dev/null +++ b/src/zenhorde/hordeconfig.cpp @@ -0,0 +1,89 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhorde/hordeconfig.h> + +namespace zen::horde { + +bool +HordeConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + // Relay mode implies AES encryption + if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) + { + return false; + } + + return true; +} + +const char* +ToString(ConnectionMode Mode) +{ + switch (Mode) + { + case ConnectionMode::Direct: + return "direct"; + case ConnectionMode::Tunnel: + return "tunnel"; + case ConnectionMode::Relay: + return "relay"; + } + return "direct"; +} + +const char* +ToString(Encryption Enc) +{ + switch (Enc) + { + case Encryption::None: + return "none"; + case Encryption::AES: + return "aes"; + } + return "none"; +} + +bool +FromString(ConnectionMode& OutMode, std::string_view Str) +{ + if (Str == "direct") + { + OutMode = ConnectionMode::Direct; + return true; + } + if (Str == "tunnel") + { + OutMode = ConnectionMode::Tunnel; + return true; + } + if (Str == "relay") + { + OutMode = ConnectionMode::Relay; + return true; + } + return false; +} + +bool +FromString(Encryption& OutEnc, std::string_view Str) +{ + if (Str == "none") + { + OutEnc = Encryption::None; + return true; + } + if (Str == "aes") + { + OutEnc = Encryption::AES; + return true; + } + return false; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp new file mode 100644 index 000000000..f88c95da2 --- /dev/null +++ b/src/zenhorde/hordeprovisioner.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 

#include <zenhorde/hordeclient.h>
#include <zenhorde/hordeprovisioner.h>

#include "hordeagent.h"
#include "hordebundle.h"

#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/scopeguard.h>
#include <zencore/thread.h>
#include <zencore/trace.h>

#include <chrono>
#include <thread>

namespace zen::horde {

// One provisioned (or in-flight) agent: its worker thread plus a stop flag
// that both the destructor and the thread itself may set.
struct HordeProvisioner::AgentWrapper
{
    std::thread Thread;
    std::atomic<bool> ShouldExit{false};
};

HordeProvisioner::HordeProvisioner(const HordeConfig& Config,
                                   const std::filesystem::path& BinariesPath,
                                   const std::filesystem::path& WorkingDir,
                                   std::string_view OrchestratorEndpoint)
: m_Config(Config)
, m_BinariesPath(BinariesPath)
, m_WorkingDir(WorkingDir)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
, m_Log(zen::logging::Get("horde.provisioner"))
{
}

HordeProvisioner::~HordeProvisioner()
{
    // Signal every agent thread first, then join them all.
    // NOTE(review): the joins happen while m_AgentsLock is held; safe as long as
    // agent threads never acquire m_AgentsLock themselves (they currently don't).
    std::lock_guard<std::mutex> Lock(m_AgentsLock);
    for (auto& Agent : m_Agents)
    {
        Agent->ShouldExit.store(true);
    }
    for (auto& Agent : m_Agents)
    {
        if (Agent->Thread.joinable())
        {
            Agent->Thread.join();
        }
    }
}

/** Set the desired total core count and synchronously request agents until the
 * (speculative) estimate covers it. Also reaps finished agent threads. */
void
HordeProvisioner::SetTargetCoreCount(uint32_t Count)
{
    ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount");

    // Clamp to the configured ceiling.
    m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)));

    // RequestAgent bumps m_EstimatedCoreCount immediately, so this loop terminates.
    while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load())
    {
        if (!m_AskForAgents.load())
        {
            return; // provisioning disabled after a hard failure
        }
        RequestAgent();
    }

    // Clean up finished agent threads
    std::lock_guard<std::mutex> Lock(m_AgentsLock);
    for (auto It = m_Agents.begin(); It != m_Agents.end();)
    {
        if ((*It)->ShouldExit.load())
        {
            if ((*It)->Thread.joinable())
            {
                (*It)->Thread.join();
            }
            It = m_Agents.erase(It);
        }
        else
        {
            ++It;
        }
    }
}

/** Snapshot of the provisioning counters (individually atomic, not a single
 * consistent snapshot). */
ProvisioningStats
HordeProvisioner::GetStats() const
{
    ProvisioningStats Stats;
    Stats.TargetCoreCount = m_TargetCoreCount.load();
    Stats.EstimatedCoreCount = m_EstimatedCoreCount.load();
    Stats.ActiveCoreCount = m_ActiveCoreCount.load();
    Stats.AgentsActive = m_AgentsActive.load();
    Stats.AgentsRequesting = m_AgentsRequesting.load();
    return Stats;
}

uint32_t
HordeProvisioner::GetAgentCount() const
{
    std::lock_guard<std::mutex> Lock(m_AgentsLock);
    return static_cast<uint32_t>(m_Agents.size());
}

/** Spawn one agent worker thread and account for it speculatively. */
void
HordeProvisioner::RequestAgent()
{
    // Speculative: assume a nominal number of cores until the real machine is
    // known. ThreadAgent reconciles this once (or if) provisioning succeeds.
    m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);

    std::lock_guard<std::mutex> Lock(m_AgentsLock);

    auto Wrapper = std::make_unique<AgentWrapper>();
    // NOTE(review): this local named 'Ref' shadows the Ref<T> smart-pointer
    // template used elsewhere in this codebase — consider renaming.
    AgentWrapper& Ref = *Wrapper;
    Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); });

    m_Agents.push_back(std::move(Wrapper));
}

/** Worker: create bundles/client (once), request a machine from Horde, upload
 * binaries, launch the remote zenserver, then poll it until it exits or the
 * provisioner shuts down. Sets Wrapper.ShouldExit on the way out so the wrapper
 * can be reaped. */
void
HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
{
    ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent");

    static std::atomic<uint32_t> ThreadIndex{0};
    const uint32_t CurrentIndex = ThreadIndex.fetch_add(1);

    zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex));

    std::unique_ptr<HordeAgent> Agent;
    uint32_t MachineCoreCount = 0;

    // Outermost guard: always close the agent connection and mark this wrapper
    // reapable, no matter which early-return path is taken below.
    auto _ = MakeGuard([&] {
        if (Agent)
        {
            Agent->CloseConnection();
        }
        Wrapper.ShouldExit.store(true);
    });

    {
        // EstimatedCoreCount is incremented speculatively when the agent is requested
        // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
        // This guard reverses that speculative increment when the inner scope ends
        // (by then either provisioning failed, or the real core count was added).
        auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });

        {
            ZEN_TRACE_CPU("HordeProvisioner::CreateBundles");

            std::lock_guard<std::mutex> BundleLock(m_BundleLock);

            // First agent thread through here builds the upload bundle and the
            // HTTP client; later threads reuse them.
            if (!m_BundlesCreated)
            {
                const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";

                std::vector<BundleFile> Files;

#if ZEN_PLATFORM_WINDOWS
                Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
#elif ZEN_PLATFORM_LINUX
                Files.emplace_back(m_BinariesPath / "zenserver", false);
                Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
#elif ZEN_PLATFORM_MAC
                Files.emplace_back(m_BinariesPath / "zenserver", false);
#endif

                BundleResult Result;
                if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
                {
                    // Hard failure: disable provisioning entirely.
                    ZEN_WARN("failed to create bundle, cannot provision any agents!");
                    m_AskForAgents.store(false);
                    return;
                }

                m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
                m_BundlesCreated = true;
            }

            if (!m_HordeClient)
            {
                m_HordeClient = std::make_unique<HordeClient>(m_Config);
                if (!m_HordeClient->Initialize())
                {
                    // Hard failure: disable provisioning entirely.
                    ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
                    m_AskForAgents.store(false);
                    return;
                }
            }
        }

        if (!m_AskForAgents.load())
        {
            return;
        }

        m_AgentsRequesting.fetch_add(1);
        auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });

        // Simple backoff: if the last machine request failed, wait up to 5 seconds
        // before trying again.
        //
        // Note however that it's possible that multiple threads enter this code at
        // the same time if multiple agents are requested at once, and they will all
        // see the same last failure time and back off accordingly. We might want to
        // use a semaphore or similar to limit the number of concurrent requests.

        if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
        {
            auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
            const uint64_t ElapsedNs = Now - LastFail;
            const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
            if (ElapsedMs < 5000)
            {
                // Sleep in 100ms slices so shutdown stays responsive.
                const uint64_t WaitMs = 5000 - ElapsedMs;
                for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100)
                {
                    std::this_thread::sleep_for(std::chrono::milliseconds(100));
                }

                if (Wrapper.ShouldExit.load())
                {
                    return;
                }
            }
        }

        // The target may already be met by other agents that landed meanwhile.
        if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
        {
            return;
        }

        std::string RequestBody = m_HordeClient->BuildRequestBody();

        // Resolve cluster if needed
        std::string ClusterId = m_Config.Cluster;
        if (ClusterId == HordeConfig::ClusterAuto)
        {
            ClusterInfo Cluster;
            if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
            {
                ZEN_WARN("failed to resolve cluster");
                m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
                return;
            }
            ClusterId = Cluster.ClusterId;
        }

        MachineInfo Machine;
        if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
        {
            // Record the failure time so subsequent attempts back off (see above).
            m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
            return;
        }

        m_LastRequestFailTime.store(0);

        if (Wrapper.ShouldExit.load())
        {
            return;
        }

        // Connect to agent and perform handshake
        Agent = std::make_unique<HordeAgent>(Machine);
        if (!Agent->IsValid())
        {
            ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort());
            return;
        }

        if (!Agent->BeginCommunication())
        {
            ZEN_WARN("BeginCommunication failed");
            return;
        }

        for (auto& [Locator, BundleDir] : m_Bundles)
        {
            if (Wrapper.ShouldExit.load())
            {
                return;
            }

            if (!Agent->UploadBinaries(BundleDir, Locator))
            {
                ZEN_WARN("UploadBinaries failed");
                return;
            }
        }

        if (Wrapper.ShouldExit.load())
        {
            return;
        }

        // Build command line for remote zenserver
        std::vector<std::string> ArgStrings;
        ArgStrings.push_back("compute");
        ArgStrings.push_back("--http=asio");

        // TEMP HACK - these should be made fully dynamic
        // these are currently here to allow spawning the compute agent locally
        // for debugging purposes (i.e with a local Horde Server+Agent setup)
        ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
        ArgStrings.push_back("--data-dir=c:\\temp\\123");

        if (!m_OrchestratorEndpoint.empty())
        {
            ExtendableStringBuilder<256> CoordArg;
            CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
            ArgStrings.emplace_back(CoordArg.ToView());
        }

        {
            ExtendableStringBuilder<128> IdArg;
            IdArg << "--instance-id=horde-" << Machine.LeaseId;
            ArgStrings.emplace_back(IdArg.ToView());
        }

        // Args borrows from ArgStrings, which outlives the Execute call below.
        std::vector<const char*> Args;
        Args.reserve(ArgStrings.size());
        for (const std::string& Arg : ArgStrings)
        {
            Args.push_back(Arg.c_str());
        }

#if ZEN_PLATFORM_WINDOWS
        // Locally-built binary is a Windows exe; non-Windows remotes run it under Wine.
        const bool UseWine = !Machine.IsWindows;
        const char* AppName = "zenserver.exe";
#else
        const bool UseWine = false;
        const char* AppName = "zenserver";
#endif

        // NOTE(review): Execute's result is ignored here; a launch failure is only
        // detected indirectly by the poll loop below — confirm this is intended.
        Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine);

        ZEN_INFO("remote execution started on [{}:{}] lease={}",
                 Machine.GetConnectionAddress(),
                 Machine.GetConnectionPort(),
                 Machine.LeaseId);

        // Replace the speculative estimate with the machine's real core count.
        MachineCoreCount = Machine.LogicalCores;
        m_EstimatedCoreCount.fetch_add(MachineCoreCount);
        m_ActiveCoreCount.fetch_add(MachineCoreCount);
        m_AgentsActive.fetch_add(1);
    }

    // Agent poll loop

    auto ActiveGuard = MakeGuard([&]() {
        m_EstimatedCoreCount.fetch_sub(MachineCoreCount);
        m_ActiveCoreCount.fetch_sub(MachineCoreCount);
        m_AgentsActive.fetch_sub(1);
    });

    while (Agent->IsValid() && !Wrapper.ShouldExit.load())
    {
        const bool LogOutput = false;
        if (!Agent->Poll(LogOutput))
        {
            break;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
}

} // namespace zen::horde

// ----- file: src/zenhorde/hordetransport.cpp -----
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordetransport.h"

#include <zencore/logging.h>
#include <zencore/trace.h>

ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

#if ZEN_PLATFORM_WINDOWS
# undef SendMessage
#endif

namespace zen::horde {

// ComputeTransport base

// Loops over partial sends until all Size bytes are on the wire; a zero-byte
// Send() from the concrete transport is treated as a fatal error.
bool
ComputeTransport::SendMessage(const void* Data, size_t Size)
{
    const uint8_t* Ptr = static_cast<const uint8_t*>(Data);
    size_t Remaining = Size;

    while (Remaining > 0)
    {
        const size_t Sent = Send(Ptr, Remaining);
        if (Sent == 0)
        {
            return false;
        }
        Ptr += Sent;
        Remaining -= Sent;
    }

    return true;
}

// Loops over partial receives until the full Size bytes have arrived; a
// zero-byte Recv() (error or connection closed) aborts with false.
bool
ComputeTransport::RecvMessage(void* Data, size_t Size)
{
    uint8_t* Ptr = static_cast<uint8_t*>(Data);
    size_t Remaining = Size;

    while (Remaining > 0)
    {
        const size_t Received = Recv(Ptr, Remaining);
        if (Received == 0)
        {
            return false;
        }
        Ptr += Received;
        Remaining -= Received;
    }

    return true;
}

// TcpComputeTransport - ASIO pimpl

struct TcpComputeTransport::Impl
{
    asio::io_context IoContext;
    asio::ip::tcp::socket Socket;

    Impl() : Socket(IoContext) {}
};

// Uses ASIO in synchronous mode only — no async operations or io_context::run().
// The io_context is only needed because ASIO sockets require one to be constructed.
+TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) +: m_Impl(std::make_unique<Impl>()) +, m_Log(zen::logging::Get("horde.transport")) +{ + ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + + asio::error_code Ec; + + const asio::ip::address Address = asio::ip::make_address(Info.GetConnectionAddress(), Ec); + if (Ec) + { + ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); + m_HasErrors = true; + return; + } + + const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); + + m_Impl->Socket.connect(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); + m_HasErrors = true; + return; + } + + // Disable Nagle's algorithm for lower latency + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); +} + +TcpComputeTransport::~TcpComputeTransport() +{ + Close(); +} + +bool +TcpComputeTransport::IsValid() const +{ + return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; +} + +size_t +TcpComputeTransport::Send(const void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + m_HasErrors = true; + return 0; + } + + return Sent; +} + +size_t +TcpComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + return 0; + } + + return Received; +} + +void +TcpComputeTransport::MarkComplete() +{ +} + +void +TcpComputeTransport::Close() +{ + if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) + { + asio::error_code Ec; + m_Impl->Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + m_Impl->Socket.close(Ec); + } + m_IsClosed = true; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.h 
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zenhorde/hordeclient.h>

#include <zencore/logbase.h>

#include <cstddef>
#include <cstdint>
#include <memory>

#if ZEN_PLATFORM_WINDOWS
// windows.h defines SendMessage as a macro, which would mangle our method name.
# undef SendMessage
#endif

namespace zen::horde {

/** Abstract base interface for compute transports.
 *
 * Matches the UE FComputeTransport pattern. Concrete implementations handle
 * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides
 * blocking message helpers on top.
 */
class ComputeTransport
{
public:
    virtual ~ComputeTransport() = default;

    virtual bool IsValid() const = 0;

    /** Transfer up to Size bytes; returns bytes transferred, 0 on error. */
    virtual size_t Send(const void* Data, size_t Size) = 0;

    /** Receive up to Size bytes; returns bytes received, 0 on error/close. */
    virtual size_t Recv(void* Data, size_t Size) = 0;

    /** Signal that no further data will be sent (no-op for transports without
     * end-of-stream semantics). */
    virtual void MarkComplete() = 0;

    virtual void Close() = 0;

    /** Blocking send that loops until all bytes are transferred. Returns false on error. */
    bool SendMessage(const void* Data, size_t Size);

    /** Blocking receive that loops until all bytes are transferred. Returns false on error. */
    bool RecvMessage(void* Data, size_t Size);
};

/** TCP socket transport using ASIO.
 *
 * Connects to the Horde compute endpoint specified by MachineInfo and provides
 * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the
 * header clean.
 */
class TcpComputeTransport final : public ComputeTransport
{
public:
    explicit TcpComputeTransport(const MachineInfo& Info);
    ~TcpComputeTransport() override;

    bool IsValid() const override;
    size_t Send(const void* Data, size_t Size) override;
    size_t Recv(void* Data, size_t Size) override;
    void MarkComplete() override;
    void Close() override;

private:
    LoggerRef Log() { return m_Log; }

    struct Impl; // hides the ASIO io_context/socket from this header

    std::unique_ptr<Impl> m_Impl;
    LoggerRef m_Log;
    bool m_IsClosed = false;
    bool m_HasErrors = false;
};

} // namespace zen::horde

// ----- file: src/zenhorde/hordetransportaes.cpp -----
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordetransportaes.h"

#include <zencore/logging.h>
#include <zencore/trace.h>

#include <algorithm>
#include <cstring>
#include <random>

// AES-256-GCM is implemented via BCrypt on Windows and OpenSSL EVP elsewhere.
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
# include <bcrypt.h>
# pragma comment(lib, "Bcrypt.lib")
#else
ZEN_THIRD_PARTY_INCLUDES_START
# include <openssl/evp.h>
# include <openssl/err.h>
ZEN_THIRD_PARTY_INCLUDES_END
#endif

namespace zen::horde {

// Holds the symmetric key, nonces, and (on OpenSSL) reusable cipher contexts.
// Key material is wiped on destruction.
struct AesComputeTransport::CryptoContext
{
    uint8_t Key[KeySize] = {};
    uint8_t EncryptNonce[NonceBytes] = {};
    uint8_t DecryptNonce[NonceBytes] = {};
    bool HasErrors = false;

#if !ZEN_PLATFORM_WINDOWS
    EVP_CIPHER_CTX* EncCtx = nullptr;
    EVP_CIPHER_CTX* DecCtx = nullptr;
#endif

    CryptoContext(const uint8_t (&InKey)[KeySize])
    {
        memcpy(Key, InKey, KeySize);

        // The encrypt nonce is randomly initialized and then deterministically mutated
        // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
        // the wire (each received message carries its own nonce in the header).
        //
        // NOTE(review): mt19937 seeded from random_device is not a CSPRNG; GCM
        // security depends on never reusing a (key, nonce) pair — consider a
        // platform CSPRNG (BCryptGenRandom / RAND_bytes) for the initial nonce.
        std::random_device Rd;
        std::mt19937 Gen(Rd());
        std::uniform_int_distribution<int> Dist(0, 255);
        for (auto& Byte : EncryptNonce)
        {
            Byte = static_cast<uint8_t>(Dist(Gen));
        }

#if !ZEN_PLATFORM_WINDOWS
        // Drain any stale OpenSSL errors
        while (ERR_get_error() != 0)
        {
        }

        // Cipher/key/IV are supplied later, per message, via *Init_ex re-init.
        EncCtx = EVP_CIPHER_CTX_new();
        EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);

        DecCtx = EVP_CIPHER_CTX_new();
        EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
#endif
    }

    ~CryptoContext()
    {
        // Scrub key material before the memory is released.
#if ZEN_PLATFORM_WINDOWS
        SecureZeroMemory(Key, sizeof(Key));
        SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
        SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
#else
        OPENSSL_cleanse(Key, sizeof(Key));
        OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
        OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));

        if (EncCtx)
        {
            EVP_CIPHER_CTX_free(EncCtx);
        }
        if (DecCtx)
        {
            EVP_CIPHER_CTX_free(DecCtx);
        }
#endif
    }

    // Deterministically derive the next per-message nonce from the current one.
    // NOTE(review): reinterpret_cast of a uint8_t array as uint32_t* relies on
    // suitable alignment and is a strict-aliasing gray area — memcpy-based
    // word access would be safer; also makes the nonce endianness-dependent.
    void UpdateNonce()
    {
        uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
        N32[0]++;
        N32[1]--;
        N32[2] = N32[0] ^ N32[1];
    }

    // Returns total encrypted message size, or 0 on failure
    // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
    // Caller must size Out for header + InLength + tag.
    int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
    {
        UpdateNonce();

        // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
        // caching but has some overhead. For our use case (relatively large, infrequent messages)
        // this is acceptable.
#if ZEN_PLATFORM_WINDOWS
        BCRYPT_ALG_HANDLE hAlg = nullptr;
        BCRYPT_KEY_HANDLE hKey = nullptr;

        // NOTE(review): these three setup calls are unchecked; a failure would
        // surface as a failed BCryptEncrypt below — confirm that is acceptable.
        BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
        BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
        BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);

        BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
        BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
        AuthInfo.pbNonce = EncryptNonce;
        AuthInfo.cbNonce = NonceBytes;
        uint8_t Tag[TagBytes] = {};
        AuthInfo.pbTag = Tag;
        AuthInfo.cbTag = TagBytes;

        ULONG CipherLen = 0;
        NTSTATUS Status =
            BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);

        if (!BCRYPT_SUCCESS(Status))
        {
            HasErrors = true;
            BCryptDestroyKey(hKey);
            BCryptCloseAlgorithmProvider(hAlg, 0);
            return 0;
        }

        // Write header: length + nonce
        memcpy(Out, &InLength, 4);
        memcpy(Out + 4, EncryptNonce, NonceBytes);
        // Write tag after ciphertext
        memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);

        BCryptDestroyKey(hKey);
        BCryptCloseAlgorithmProvider(hAlg, 0);

        return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
#else
        // Re-key the cached context with this message's nonce.
        if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
        {
            HasErrors = true;
            return 0;
        }

        int32_t Offset = 0;
        // Write length
        memcpy(Out + Offset, &InLength, 4);
        Offset += 4;
        // Write nonce
        memcpy(Out + Offset, EncryptNonce, NonceBytes);
        Offset += NonceBytes;

        // Encrypt
        int OutLen = 0;
        if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
        {
            HasErrors = true;
            return 0;
        }
        Offset += OutLen;

        // Finalize
        int FinalLen = 0;
        if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
        {
            HasErrors = true;
            return 0;
        }
        Offset += FinalLen;

        // Get tag
        if (EVP_CIPHER_CTX_ctrl(EncCtx,
EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif + } + + // Decrypt a message. Returns decrypted data length, or 0 on failure. + // Input must be [ciphertext][tag], with nonce provided separately. + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) + { +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } + + return static_cast<int32_t>(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + // Set the tag for verification + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) 
+ { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif + } +}; + +AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +: m_Crypto(std::make_unique<CryptoContext>(Key)) +, m_Inner(std::move(InnerTransport)) +{ +} + +AesComputeTransport::~AesComputeTransport() +{ + Close(); +} + +bool +AesComputeTransport::IsValid() const +{ + return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; +} + +size_t +AesComputeTransport::Send(const void* Data, size_t Size) +{ + ZEN_TRACE_CPU("AesComputeTransport::Send"); + + if (!IsValid()) + { + return 0; + } + + std::lock_guard<std::mutex> Lock(m_Lock); + + const int32_t DataLength = static_cast<int32_t>(Size); + const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + if (EncryptedLen == 0) + { + return 0; + } + + if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) + { + return 0; + } + + return Size; +} + +size_t +AesComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes + // than the decrypted message contains. Excess bytes are buffered in m_RemainingData + // and returned on subsequent Recv calls without another decryption round-trip. 
+ ZEN_TRACE_CPU("AesComputeTransport::Recv"); + + std::lock_guard<std::mutex> Lock(m_Lock); + + if (!m_RemainingData.empty()) + { + const size_t Available = m_RemainingData.size() - m_RemainingOffset; + const size_t ToCopy = std::min(Available, Size); + + memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + m_RemainingOffset += ToCopy; + + if (m_RemainingOffset >= m_RemainingData.size()) + { + m_RemainingData.clear(); + m_RemainingOffset = 0; + } + + return ToCopy; + } + + // Receive packet header: [length(4B)][nonce(12B)] + struct PacketHeader + { + int32_t DataLength = 0; + uint8_t Nonce[NonceBytes] = {}; + } Header; + + if (!m_Inner->RecvMessage(&Header, sizeof(Header))) + { + return 0; + } + + // Validate DataLength to prevent OOM from malicious/corrupt peers + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB + + if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) + { + ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); + return 0; + } + + // Receive ciphertext + tag + const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + { + return 0; + } + + // Decrypt + const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); + + // We need a temporary buffer for decryption if we can't decrypt directly into output + std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); + + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + if (Decrypted == 0) + { + return 0; + } + + memcpy(Data, DecryptedBuf.data(), BytesToReturn); + + // Store remaining data if we couldn't return everything + if (static_cast<size_t>(Header.DataLength) > BytesToReturn) + { + m_RemainingOffset = 0; + 
m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + } + + return BytesToReturn; +} + +void +AesComputeTransport::MarkComplete() +{ + if (IsValid()) + { + m_Inner->MarkComplete(); + } +} + +void +AesComputeTransport::Close() +{ + if (!m_IsClosed) + { + if (m_Inner && m_Inner->IsValid()) + { + m_Inner->Close(); + } + m_IsClosed = true; + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h new file mode 100644 index 000000000..efcad9835 --- /dev/null +++ b/src/zenhorde/hordetransportaes.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordetransport.h" + +#include <cstdint> +#include <memory> +#include <mutex> +#include <vector> + +namespace zen::horde { + +/** AES-256-GCM encrypted transport wrapper. + * + * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting + * all incoming data using AES-256-GCM. The nonce is mutated per message using + * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * + * Wire format per encrypted message: + * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] + * + * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). 
+ */ +class AesComputeTransport final : public ComputeTransport +{ +public: + AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport); + ~AesComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size + static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + + struct CryptoContext; + + std::unique_ptr<CryptoContext> m_Crypto; + std::unique_ptr<ComputeTransport> m_Inner; + std::vector<uint8_t> m_EncryptBuffer; + std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv + size_t m_RemainingOffset = 0; + std::mutex m_Lock; + bool m_IsClosed = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h new file mode 100644 index 000000000..201d68b83 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/hordeconfig.h> + +#include <zencore/logbase.h> + +#include <cstdint> +#include <map> +#include <memory> +#include <string> +#include <vector> + +namespace zen { +class HttpClient; +} + +namespace zen::horde { + +static constexpr size_t NonceSize = 64; +static constexpr size_t KeySize = 32; + +/** Port mapping information returned by Horde for a provisioned machine. */ +struct PortInfo +{ + uint16_t Port = 0; + uint16_t AgentPort = 0; +}; + +/** Describes a provisioned compute machine returned by the Horde API. + * + * Contains the network address, encryption credentials, and capabilities + * needed to establish a compute transport connection to the machine. 
+ */ +struct MachineInfo +{ + std::string Ip; + ConnectionMode Mode = ConnectionMode::Direct; + std::string ConnectionAddress; ///< Relay/tunnel address (used when Mode != Direct) + uint16_t Port = 0; + uint16_t LogicalCores = 0; + Encryption EncryptionMode = Encryption::None; + uint8_t Nonce[NonceSize] = {}; ///< 64-byte nonce sent during TCP handshake + uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) + bool IsWindows = false; + std::string LeaseId; + + std::map<std::string, PortInfo> Ports; + + /** Return the address to connect to, accounting for connection mode. */ + const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + + /** Return the port to connect to, accounting for connection mode and port mapping. */ + uint16_t GetConnectionPort() const + { + if (Mode == ConnectionMode::Relay) + { + auto It = Ports.find("_horde_compute"); + if (It != Ports.end()) + { + return It->second.Port; + } + } + return Port; + } + + bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } +}; + +/** Result of cluster auto-resolution via the Horde API. */ +struct ClusterInfo +{ + std::string ClusterId = "default"; +}; + +/** HTTP client for the Horde compute REST API. + * + * Handles cluster resolution and machine provisioning requests. Each call + * is synchronous and returns success/failure. Thread safety: individual + * methods are not thread-safe; callers must synchronize access. + */ +class HordeClient +{ +public: + explicit HordeClient(const HordeConfig& Config); + ~HordeClient(); + + HordeClient(const HordeClient&) = delete; + HordeClient& operator=(const HordeClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the JSON request body for cluster resolution and machine requests. + * Encodes pool, condition, connection mode, encryption, and port requirements. 
*/ + std::string BuildRequestBody() const; + + /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ + bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + + /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. + * On success, populates OutMachine with connection details and credentials. */ + bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + + LoggerRef Log() { return m_Log; } + +private: + bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); + + HordeConfig m_Config; + std::unique_ptr<zen::HttpClient> m_Http; + LoggerRef m_Log; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h new file mode 100644 index 000000000..dd70f9832 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/zenhorde.h> + +#include <string> + +namespace zen::horde { + +/** Transport connection mode for Horde compute agents. */ +enum class ConnectionMode +{ + Direct, ///< Connect directly to the agent IP + Tunnel, ///< Connect through a Horde tunnel relay + Relay, ///< Connect through a Horde relay with port mapping +}; + +/** Transport encryption mode for Horde compute channels. */ +enum class Encryption +{ + None, ///< No encryption + AES, ///< AES-256-GCM encryption (required for Relay mode) +}; + +/** Configuration for connecting to an Epic Horde compute cluster. + * + * Specifies the Horde server URL, authentication token, pool selection, + * connection mode, and resource limits. Used by HordeClient and HordeProvisioner. 
+ */ +struct HordeConfig +{ + static constexpr const char* ClusterDefault = "default"; + static constexpr const char* ClusterAuto = "_auto"; + + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; + + /** Validate the configuration. Returns false if the configuration is invalid + * (e.g. Relay mode without AES encryption). */ + bool Validate() const; +}; + +const char* ToString(ConnectionMode Mode); +const char* ToString(Encryption Enc); + +bool FromString(ConnectionMode& OutMode, std::string_view Str); +bool FromString(Encryption& OutEnc, std::string_view Str); + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h new file mode 100644 index 000000000..4e2e63bbd --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -0,0 +1,110 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhorde/hordeconfig.h> + +#include <zencore/logbase.h> + +#include <atomic> +#include <cstdint> +#include <filesystem> +#include <memory> +#include <mutex> +#include <string> +#include <vector> + +namespace zen::horde { + +class HordeClient; + +/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */ +struct ProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected once pending requests complete + uint32_t ActiveCoreCount = 0; ///< Cores on machines that are currently running zenserver + uint32_t AgentsActive = 0; ///< Number of agents with a running remote process + uint32_t AgentsRequesting = 0; ///< Number of agents currently requesting a machine from Horde +}; + +/** Multi-agent lifecycle manager for Horde worker provisioning. + * + * Provisions remote compute workers by requesting machines from the Horde API, + * connecting via the Horde compute transport protocol, uploading the zenserver + * binary, and executing it remotely. Each provisioned machine runs zenserver + * in compute mode, which announces itself back to the orchestrator. + * + * Spawns one thread per agent. Each thread handles the full lifecycle: + * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> + * channel setup -> binary upload -> remote execution -> poll until exit. + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class HordeProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Horde connection and pool configuration. + * @param BinariesPath Directory containing the zenserver binary to upload. + * @param WorkingDir Local directory for bundle staging and working files. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. 
*/ + HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint); + + /** Signals all agent threads to exit and joins them. */ + ~HordeProvisioner(); + + HordeProvisioner(const HordeProvisioner&) = delete; + HordeProvisioner& operator=(const HordeProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the + * estimated core count is below the target. Also joins any finished + * agent threads. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + ProvisioningStats GetStats() const; + + uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const; + +private: + LoggerRef Log() { return m_Log; } + + struct AgentWrapper; + + void RequestAgent(); + void ThreadAgent(AgentWrapper& Wrapper); + + HordeConfig m_Config; + std::filesystem::path m_BinariesPath; + std::filesystem::path m_WorkingDir; + std::string m_OrchestratorEndpoint; + + std::unique_ptr<HordeClient> m_HordeClient; + + std::mutex m_BundleLock; + std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs + bool m_BundlesCreated = false; + + mutable std::mutex m_AgentsLock; + std::vector<std::unique_ptr<AgentWrapper>> m_Agents; + + std::atomic<uint64_t> m_LastRequestFailTime{0}; + std::atomic<uint32_t> m_TargetCoreCount{0}; + std::atomic<uint32_t> m_EstimatedCoreCount{0}; + std::atomic<uint32_t> m_ActiveCoreCount{0}; + std::atomic<uint32_t> m_AgentsActive{0}; + std::atomic<uint32_t> m_AgentsRequesting{0}; + std::atomic<bool> m_AskForAgents{true}; + + LoggerRef m_Log; + + static constexpr uint32_t EstimatedCoresPerAgent = 32; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/zenhorde.h b/src/zenhorde/include/zenhorde/zenhorde.h new file mode 
100644 index 000000000..35147ff75 --- /dev/null +++ b/src/zenhorde/include/zenhorde/zenhorde.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_HORDE) +# define ZEN_WITH_HORDE 1 +#endif diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua new file mode 100644 index 000000000..48d028e86 --- /dev/null +++ b/src/zenhorde/xmake.lua @@ -0,0 +1,22 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenhorde') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zencompute", "zenutil") + add_packages("asio", "json11") + + if is_plat("windows") then + add_syslinks("Ws2_32", "Bcrypt") + end + + if is_plat("linux") or is_plat("macosx") then + add_packages("openssl") + end + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 62c080a7b..02cccc540 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -103,6 +103,7 @@ public: virtual bool IsLocalMachineRequest() const = 0; virtual std::string_view GetAuthorizationHeader() const = 0; + virtual std::string_view GetRemoteAddress() const { return {}; } /** Respond with payload diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index c4d9ee777..33f182df9 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -544,7 +544,8 @@ public: HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber, - bool IsLocalMachineRequest); + bool IsLocalMachineRequest, + std::string RemoteAddress); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; @@ -552,6 +553,7 @@ public: virtual bool IsLocalMachineRequest() const override; virtual std::string_view 
GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -569,6 +571,7 @@ public: uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers IoBuffer m_PayloadBuffer; bool m_IsLocalMachineRequest; + std::string m_RemoteAddress; std::unique_ptr<HttpResponse> m_Response; }; @@ -1238,9 +1241,15 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); - bool IsLocalConnection = m_Socket->local_endpoint().address() == m_Socket->remote_endpoint().address(); + auto RemoteEndpoint = m_Socket->remote_endpoint(); + bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber, IsLocalConnection); + HttpAsioServerRequest Request(m_RequestData, + *Service, + m_RequestData.Body(), + RequestNumber, + IsLocalConnection, + RemoteEndpoint.address().to_string()); ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); @@ -1725,12 +1734,14 @@ HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber, - bool IsLocalMachineRequest) + bool IsLocalMachineRequest, + std::string RemoteAddress) : HttpServerRequest(Service) , m_Request(Request) , m_RequestNumber(RequestNumber) , m_PayloadBuffer(std::move(PayloadBuffer)) , m_IsLocalMachineRequest(IsLocalMachineRequest) +, m_RemoteAddress(std::move(RemoteAddress)) { const int PrefixLength = Service.UriPrefixLength(); @@ -1809,6 +1820,12 @@ HttpAsioServerRequest::IsLocalMachineRequest() const } std::string_view +HttpAsioServerRequest::GetRemoteAddress() const +{ + return m_RemoteAddress; +} + +std::string_view HttpAsioServerRequest::GetAuthorizationHeader() const { return 
m_Request.AuthorizationHeader(); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index a48f1d316..cf639c114 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -336,8 +336,9 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const; + virtual bool IsLocalMachineRequest() const override; virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -353,11 +354,12 @@ public: HttpSysServerRequest(const HttpSysServerRequest&) = delete; HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable ExtendableStringBuilder<64> m_RemoteAddress; }; /** HTTP transaction @@ -1902,6 +1904,17 @@ HttpSysServerRequest::IsLocalMachineRequest() const } std::string_view +HttpSysServerRequest::GetRemoteAddress() const +{ + if (m_RemoteAddress.Size() == 0) + { + const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false); + } + return m_RemoteAddress.ToView(); +} + +std::string_view HttpSysServerRequest::GetAuthorizationHeader() const { const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h new file mode 100644 index 
000000000..0a3411ace --- /dev/null +++ b/src/zennomad/include/zennomad/nomadclient.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennomad/nomadconfig.h> + +#include <zencore/logbase.h> + +#include <memory> +#include <string> +#include <vector> + +namespace zen { +class HttpClient; +} + +namespace zen::nomad { + +/** Summary of a Nomad job returned by the API. */ +struct NomadJobInfo +{ + std::string Id; + std::string Status; ///< "pending", "running", "dead" + std::string StatusDescription; +}; + +/** Summary of a Nomad allocation returned by the API. */ +struct NomadAllocInfo +{ + std::string Id; + std::string ClientStatus; ///< "pending", "running", "complete", "failed" + std::string TaskState; ///< State of the task within the allocation +}; + +/** HTTP client for the Nomad REST API (v1). + * + * Handles job submission, status polling, and job termination. + * All calls are synchronous. Thread safety: individual methods are + * not thread-safe; callers must synchronize access. + */ +class NomadClient +{ +public: + explicit NomadClient(const NomadConfig& Config); + ~NomadClient(); + + NomadClient(const NomadClient&) = delete; + NomadClient& operator=(const NomadClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint. + * The JSON structure varies based on the configured driver and distribution mode. */ + std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const; + + /** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */ + bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob); + + /** Get the status of a job via GET /v1/job/{jobId}. 
*/ + bool GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob); + + /** Get allocations for a job via GET /v1/job/{jobId}/allocations. */ + bool GetAllocations(const std::string& JobId, std::vector<NomadAllocInfo>& OutAllocs); + + /** Stop a job via DELETE /v1/job/{jobId}. */ + bool StopJob(const std::string& JobId); + + LoggerRef Log() { return m_Log; } + +private: + NomadConfig m_Config; + std::unique_ptr<zen::HttpClient> m_Http; + LoggerRef m_Log; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadconfig.h b/src/zennomad/include/zennomad/nomadconfig.h new file mode 100644 index 000000000..92d2bbaca --- /dev/null +++ b/src/zennomad/include/zennomad/nomadconfig.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennomad/zennomad.h> + +#include <string> + +namespace zen::nomad { + +/** Nomad task driver type. */ +enum class Driver +{ + RawExec, ///< Use Nomad raw_exec driver (direct process execution) + Docker, ///< Use Nomad Docker driver +}; + +/** How the zenserver binary is made available on Nomad clients. */ +enum class BinaryDistribution +{ + PreDeployed, ///< Binary is already present on Nomad client nodes + Artifact, ///< Download binary via Nomad artifact stanza +}; + +/** Configuration for Nomad worker provisioning. + * + * Specifies the Nomad server URL, authentication, resource limits, and + * job configuration. Used by NomadClient and NomadProvisioner. + */ +struct NomadConfig +{ + bool Enabled = false; ///< Whether Nomad provisioning is active + std::string ServerUrl; ///< Nomad HTTP API URL (e.g. 
"http://localhost:4646") + std::string AclToken; ///< Nomad ACL token (sent as X-Nomad-Token header) + std::string Datacenter = "dc1"; ///< Target datacenter + std::string Namespace = "default"; ///< Nomad namespace + std::string Region; ///< Nomad region (empty = server default) + + Driver TaskDriver = Driver::RawExec; ///< Task driver for job execution + BinaryDistribution BinDistribution = BinaryDistribution::PreDeployed; ///< How to distribute the zenserver binary + + std::string BinaryPath; ///< Path to zenserver on Nomad clients (PreDeployed mode) + std::string ArtifactSource; ///< URL to download zenserver binary (Artifact mode) + std::string DockerImage; ///< Docker image name (Docker driver mode) + + int MaxJobs = 64; ///< Maximum concurrent Nomad jobs + int CpuMhz = 1000; ///< CPU MHz allocated per task + int MemoryMb = 2048; ///< Memory MB allocated per task + int CoresPerJob = 32; ///< Estimated cores per job (for scaling calculations) + int MaxCores = 2048; ///< Maximum total cores to provision + + std::string JobPrefix = "zenserver-worker"; ///< Prefix for generated Nomad job IDs + + /** Validate the configuration. Returns false if required fields are missing + * or incompatible options are set. */ + bool Validate() const; +}; + +const char* ToString(Driver D); +const char* ToString(BinaryDistribution Dist); + +bool FromString(Driver& OutDriver, std::string_view Str); +bool FromString(BinaryDistribution& OutDist, std::string_view Str); + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadprocess.h b/src/zennomad/include/zennomad/nomadprocess.h new file mode 100644 index 000000000..a66c2ce41 --- /dev/null +++ b/src/zennomad/include/zennomad/nomadprocess.h @@ -0,0 +1,78 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpclient.h> + +#include <memory> +#include <string> +#include <string_view> +#include <vector> + +namespace zen::nomad { + +struct NomadJobInfo; +struct NomadAllocInfo; + +/** Manages a Nomad agent process running in dev mode for testing. + * + * Spawns `nomad agent -dev` and polls the HTTP API until the agent + * is ready. On destruction or via StopNomadAgent(), the agent + * process is killed. + */ +class NomadProcess +{ +public: + NomadProcess(); + ~NomadProcess(); + + NomadProcess(const NomadProcess&) = delete; + NomadProcess& operator=(const NomadProcess&) = delete; + + /** Spawn a Nomad dev agent and block until the leader endpoint responds (10 s timeout). */ + void SpawnNomadAgent(); + + /** Kill the Nomad agent process. */ + void StopNomadAgent(); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +/** Lightweight HTTP wrapper around the Nomad v1 REST API for use in tests. + * + * Unlike the production NomadClient (which requires a NomadConfig and + * supports all driver/distribution modes), this client exposes a simpler + * interface geared towards test scenarios. + */ +class NomadTestClient +{ +public: + explicit NomadTestClient(std::string_view BaseUri); + ~NomadTestClient(); + + NomadTestClient(const NomadTestClient&) = delete; + NomadTestClient& operator=(const NomadTestClient&) = delete; + + /** Submit a raw_exec batch job. + * Returns the parsed job info on success; Id will be empty on failure. */ + NomadJobInfo SubmitJob(std::string_view JobId, std::string_view Command, const std::vector<std::string>& Args); + + /** Query the status of an existing job. */ + NomadJobInfo GetJobStatus(std::string_view JobId); + + /** Stop (deregister) a running job. */ + void StopJob(std::string_view JobId); + + /** Get allocations for a job. */ + std::vector<NomadAllocInfo> GetAllocations(std::string_view JobId); + + /** List all jobs, optionally filtered by prefix. 
*/ + std::vector<NomadJobInfo> ListJobs(std::string_view Prefix = ""); + +private: + HttpClient m_HttpClient; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h new file mode 100644 index 000000000..750693b3f --- /dev/null +++ b/src/zennomad/include/zennomad/nomadprovisioner.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zennomad/nomadconfig.h> + +#include <zencore/logbase.h> + +#include <atomic> +#include <condition_variable> +#include <cstdint> +#include <memory> +#include <mutex> +#include <string> +#include <thread> +#include <vector> + +namespace zen::nomad { + +class NomadClient; + +/** Snapshot of the current Nomad provisioning state, returned by NomadProvisioner::GetStats(). */ +struct NomadProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected from submitted jobs + uint32_t ActiveJobCount = 0; ///< Number of currently tracked Nomad jobs + uint32_t RunningJobCount = 0; ///< Number of jobs in "running" status +}; + +/** Job lifecycle manager for Nomad worker provisioning. + * + * Provisions remote compute workers by submitting batch jobs to a Nomad + * cluster via the REST API. Each job runs zenserver in compute mode, which + * announces itself back to the orchestrator. + * + * Uses a single management thread that periodically: + * 1. Submits new jobs when estimated cores < target cores + * 2. Polls existing jobs for status changes + * 3. Cleans up dead/failed jobs and adjusts counters + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class NomadProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Nomad connection and job configuration. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. 
*/ + NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint); + + /** Signals the management thread to exit and stops all tracked jobs. */ + ~NomadProvisioner(); + + NomadProvisioner(const NomadProvisioner&) = delete; + NomadProvisioner& operator=(const NomadProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to NomadConfig::MaxCores. The management thread will + * submit new jobs to approach this target. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + NomadProvisioningStats GetStats() const; + +private: + LoggerRef Log() { return m_Log; } + + struct TrackedJob + { + std::string JobId; + std::string Status; ///< "pending", "running", "dead" + int Cores = 0; + }; + + void ManagementThread(); + void SubmitNewJobs(); + void PollExistingJobs(); + void CleanupDeadJobs(); + void StopAllJobs(); + + std::string GenerateJobId(); + + NomadConfig m_Config; + std::string m_OrchestratorEndpoint; + + std::unique_ptr<NomadClient> m_Client; + + mutable std::mutex m_JobsLock; + std::vector<TrackedJob> m_Jobs; + std::atomic<uint32_t> m_JobIndex{0}; + + std::atomic<uint32_t> m_TargetCoreCount{0}; + std::atomic<uint32_t> m_EstimatedCoreCount{0}; + std::atomic<uint32_t> m_RunningJobCount{0}; + + std::thread m_Thread; + std::mutex m_WakeMutex; + std::condition_variable m_WakeCV; + std::atomic<bool> m_ShouldExit{false}; + + uint32_t m_ProcessId = 0; + + LoggerRef m_Log; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/zennomad.h b/src/zennomad/include/zennomad/zennomad.h new file mode 100644 index 000000000..09fb98dfe --- /dev/null +++ b/src/zennomad/include/zennomad/zennomad.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_NOMAD) +# define ZEN_WITH_NOMAD 1 +#endif diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp new file mode 100644 index 000000000..9edcde125 --- /dev/null +++ b/src/zennomad/nomadclient.cpp @@ -0,0 +1,366 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memoryview.h> +#include <zencore/trace.h> +#include <zenhttp/httpclient.h> +#include <zennomad/nomadclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::nomad { + +namespace { + + HttpClient::KeyValueMap MakeNomadHeaders(const NomadConfig& Config) + { + HttpClient::KeyValueMap Headers; + if (!Config.AclToken.empty()) + { + Headers->emplace("X-Nomad-Token", Config.AclToken); + } + return Headers; + } + +} // namespace + +NomadClient::NomadClient(const NomadConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("nomad.client")) +{ +} + +NomadClient::~NomadClient() = default; + +bool +NomadClient::Initialize() +{ + ZEN_TRACE_CPU("NomadClient::Initialize"); + + HttpClientSettings Settings; + Settings.LogCategory = "nomad.http"; + Settings.ConnectTimeout = std::chrono::milliseconds{10000}; + Settings.Timeout = std::chrono::milliseconds{60000}; + Settings.RetryCount = 1; + + // Ensure the base URL ends with a slash so path concatenation works correctly + std::string BaseUrl = m_Config.ServerUrl; + if (!BaseUrl.empty() && BaseUrl.back() != '/') + { + BaseUrl += '/'; + } + + m_Http = std::make_unique<zen::HttpClient>(BaseUrl, Settings); + + return true; +} + +std::string +NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const +{ + ZEN_TRACE_CPU("NomadClient::BuildJobJson"); + + // Build the task config based on driver and distribution mode + json11::Json::object TaskConfig; + + if (m_Config.TaskDriver == 
Driver::RawExec) + { + std::string Command; + if (m_Config.BinDistribution == BinaryDistribution::PreDeployed) + { + Command = m_Config.BinaryPath; + } + else + { + // Artifact mode: binary is downloaded to local/zenserver + Command = "local/zenserver"; + } + + TaskConfig["command"] = Command; + + json11::Json::array Args; + Args.push_back("compute"); + Args.push_back("--http=asio"); + if (!OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint; + Args.push_back(std::string(CoordArg.ToView())); + } + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=nomad-" << JobId; + Args.push_back(std::string(IdArg.ToView())); + } + TaskConfig["args"] = Args; + } + else + { + // Docker driver + TaskConfig["image"] = m_Config.DockerImage; + + json11::Json::array Args; + Args.push_back("compute"); + Args.push_back("--http=asio"); + if (!OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint; + Args.push_back(std::string(CoordArg.ToView())); + } + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=nomad-" << JobId; + Args.push_back(std::string(IdArg.ToView())); + } + TaskConfig["args"] = Args; + } + + // Build resource stanza + json11::Json::object Resources; + Resources["CPU"] = m_Config.CpuMhz; + Resources["MemoryMB"] = m_Config.MemoryMb; + + // Build the task + json11::Json::object Task; + Task["Name"] = "zenserver"; + Task["Driver"] = (m_Config.TaskDriver == Driver::RawExec) ? 
"raw_exec" : "docker"; + Task["Config"] = TaskConfig; + Task["Resources"] = Resources; + + // Add artifact stanza if using artifact distribution + if (m_Config.BinDistribution == BinaryDistribution::Artifact && !m_Config.ArtifactSource.empty()) + { + json11::Json::object Artifact; + Artifact["GetterSource"] = m_Config.ArtifactSource; + + json11::Json::array Artifacts; + Artifacts.push_back(Artifact); + Task["Artifacts"] = Artifacts; + } + + json11::Json::array Tasks; + Tasks.push_back(Task); + + // Build the task group + json11::Json::object Group; + Group["Name"] = "zenserver-group"; + Group["Count"] = 1; + Group["Tasks"] = Tasks; + + json11::Json::array Groups; + Groups.push_back(Group); + + // Build datacenters array + json11::Json::array Datacenters; + Datacenters.push_back(m_Config.Datacenter); + + // Build the job + json11::Json::object Job; + Job["ID"] = JobId; + Job["Name"] = JobId; + Job["Type"] = "batch"; + Job["Datacenters"] = Datacenters; + Job["TaskGroups"] = Groups; + + if (!m_Config.Namespace.empty() && m_Config.Namespace != "default") + { + Job["Namespace"] = m_Config.Namespace; + } + + if (!m_Config.Region.empty()) + { + Job["Region"] = m_Config.Region; + } + + // Wrap in the registration envelope + json11::Json::object Root; + Root["Job"] = Job; + + return json11::Json(Root).dump(); +} + +bool +NomadClient::SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob) +{ + ZEN_TRACE_CPU("NomadClient::SubmitJob"); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{JobJson.data(), JobJson.size()}, ZenContentType::kJSON); + + const HttpClient::Response Response = m_Http->Put("v1/jobs", Payload, MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job submit failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast<int>(Response.StatusCode); + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job submit failed with HTTP/{}", StatusCode); + return false; + } + + 
const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response from Nomad job submit: {}", Err); + return false; + } + + // The response contains EvalID; the job ID is what we submitted + OutJob.Id = Json["JobModifyIndex"].is_number() ? OutJob.Id : ""; + OutJob.Status = "pending"; + + ZEN_INFO("Nomad job submitted: eval_id={}", Json["EvalID"].string_value()); + + return true; +} + +bool +NomadClient::GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob) +{ + ZEN_TRACE_CPU("NomadClient::GetJobStatus"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId; + + const HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job status query failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast<int>(Response.StatusCode); + + if (StatusCode == 404) + { + ZEN_INFO("Nomad job '{}' not found", JobId); + OutJob.Status = "dead"; + return true; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job status query failed with HTTP/{}", StatusCode); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON in Nomad job status response: {}", Err); + return false; + } + + OutJob.Id = Json["ID"].string_value(); + OutJob.Status = Json["Status"].string_value(); + if (const json11::Json Desc = Json["StatusDescription"]; Desc.is_string()) + { + OutJob.StatusDescription = Desc.string_value(); + } + + return true; +} + +bool +NomadClient::GetAllocations(const std::string& JobId, std::vector<NomadAllocInfo>& OutAllocs) +{ + ZEN_TRACE_CPU("NomadClient::GetAllocations"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId << "/allocations"; + + const 
HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad allocation query failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad allocation query failed with HTTP/{}", static_cast<int>(Response.StatusCode)); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON in Nomad allocation response: {}", Err); + return false; + } + + OutAllocs.clear(); + if (!Json.is_array()) + { + return true; + } + + for (const json11::Json& AllocVal : Json.array_items()) + { + NomadAllocInfo Alloc; + Alloc.Id = AllocVal["ID"].string_value(); + Alloc.ClientStatus = AllocVal["ClientStatus"].string_value(); + + // Extract task state if available + if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object()) + { + for (const auto& [TaskName, TaskState] : TaskStates.object_items()) + { + if (TaskState["State"].is_string()) + { + Alloc.TaskState = TaskState["State"].string_value(); + } + } + } + + OutAllocs.push_back(std::move(Alloc)); + } + + return true; +} + +bool +NomadClient::StopJob(const std::string& JobId) +{ + ZEN_TRACE_CPU("NomadClient::StopJob"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId; + + const HttpClient::Response Response = m_Http->Delete(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job stop failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job stop failed with HTTP/{}", static_cast<int>(Response.StatusCode)); + return false; + } + + ZEN_INFO("Nomad job '{}' stopped", JobId); + return true; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadconfig.cpp b/src/zennomad/nomadconfig.cpp new file mode 
100644 index 000000000..d55b3da9a --- /dev/null +++ b/src/zennomad/nomadconfig.cpp @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zennomad/nomadconfig.h> + +namespace zen::nomad { + +bool +NomadConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + if (BinDistribution == BinaryDistribution::PreDeployed && BinaryPath.empty()) + { + return false; + } + + if (BinDistribution == BinaryDistribution::Artifact && ArtifactSource.empty()) + { + return false; + } + + if (TaskDriver == Driver::Docker && DockerImage.empty()) + { + return false; + } + + return true; +} + +const char* +ToString(Driver D) +{ + switch (D) + { + case Driver::RawExec: + return "raw_exec"; + case Driver::Docker: + return "docker"; + } + return "raw_exec"; +} + +const char* +ToString(BinaryDistribution Dist) +{ + switch (Dist) + { + case BinaryDistribution::PreDeployed: + return "predeployed"; + case BinaryDistribution::Artifact: + return "artifact"; + } + return "predeployed"; +} + +bool +FromString(Driver& OutDriver, std::string_view Str) +{ + if (Str == "raw_exec") + { + OutDriver = Driver::RawExec; + return true; + } + if (Str == "docker") + { + OutDriver = Driver::Docker; + return true; + } + return false; +} + +bool +FromString(BinaryDistribution& OutDist, std::string_view Str) +{ + if (Str == "predeployed") + { + OutDist = BinaryDistribution::PreDeployed; + return true; + } + if (Str == "artifact") + { + OutDist = BinaryDistribution::Artifact; + return true; + } + return false; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp new file mode 100644 index 000000000..1ae968fb7 --- /dev/null +++ b/src/zennomad/nomadprocess.cpp @@ -0,0 +1,354 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zennomad/nomadclient.h> +#include <zennomad/nomadprocess.h> + +#include <zenbase/zenbase.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memoryview.h> +#include <zencore/process.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <fmt/format.h> + +namespace zen::nomad { + +////////////////////////////////////////////////////////////////////////// + +struct NomadProcess::Impl +{ + Impl(std::string_view BaseUri) : m_HttpClient(BaseUri) {} + ~Impl() = default; + + void SpawnNomadAgent() + { + ZEN_TRACE_CPU("SpawnNomadAgent"); + + if (m_ProcessHandle.IsValid()) + { + return; + } + + CreateProcOptions Options; + Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + + CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); + + if (Result) + { + m_ProcessHandle.Initialize(Result); + + Stopwatch Timer; + + // Poll to check when the agent is ready + + do + { + Sleep(100); + HttpClient::Response Resp = m_HttpClient.Get("v1/status/leader"); + if (Resp) + { + ZEN_INFO("Nomad agent started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + + return; + } + } while (Timer.GetElapsedTimeMs() < 30000); + } + + // Report failure! 
+ + ZEN_WARN("Nomad agent failed to start within timeout period"); + } + + void StopNomadAgent() + { + if (!m_ProcessHandle.IsValid()) + { + return; + } + + // This waits for the process to exit and also resets the handle + m_ProcessHandle.Kill(); + } + +private: + ProcessHandle m_ProcessHandle; + HttpClient m_HttpClient; +}; + +NomadProcess::NomadProcess() : m_Impl(std::make_unique<Impl>("http://localhost:4646/")) +{ +} + +NomadProcess::~NomadProcess() +{ +} + +void +NomadProcess::SpawnNomadAgent() +{ + m_Impl->SpawnNomadAgent(); +} + +void +NomadProcess::StopNomadAgent() +{ + m_Impl->StopNomadAgent(); +} + +////////////////////////////////////////////////////////////////////////// + +NomadTestClient::NomadTestClient(std::string_view BaseUri) : m_HttpClient(BaseUri) +{ +} + +NomadTestClient::~NomadTestClient() +{ +} + +NomadJobInfo +NomadTestClient::SubmitJob(std::string_view JobId, std::string_view Command, const std::vector<std::string>& Args) +{ + ZEN_TRACE_CPU("SubmitNomadJob"); + + NomadJobInfo Result; + + // Build the job JSON for a raw_exec batch job + json11::Json::object TaskConfig; + TaskConfig["command"] = std::string(Command); + + json11::Json::array JsonArgs; + for (const auto& Arg : Args) + { + JsonArgs.push_back(Arg); + } + TaskConfig["args"] = JsonArgs; + + json11::Json::object Resources; + Resources["CPU"] = 100; + Resources["MemoryMB"] = 64; + + json11::Json::object Task; + Task["Name"] = "test-task"; + Task["Driver"] = "raw_exec"; + Task["Config"] = TaskConfig; + Task["Resources"] = Resources; + + json11::Json::array Tasks; + Tasks.push_back(Task); + + json11::Json::object Group; + Group["Name"] = "test-group"; + Group["Count"] = 1; + Group["Tasks"] = Tasks; + + json11::Json::array Groups; + Groups.push_back(Group); + + json11::Json::array Datacenters; + Datacenters.push_back("dc1"); + + json11::Json::object Job; + Job["ID"] = std::string(JobId); + Job["Name"] = std::string(JobId); + Job["Type"] = "batch"; + Job["Datacenters"] = Datacenters; + 
Job["TaskGroups"] = Groups; + + json11::Json::object Root; + Root["Job"] = Job; + + std::string Body = json11::Json(Root).dump(); + + IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{Body.data(), Body.size()}, ZenContentType::kJSON); + + HttpClient::Response Response = + m_HttpClient.Put("v1/jobs", Payload, {{"Content-Type", "application/json"}, {"Accept", "application/json"}}); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: SubmitJob failed for '{}'", JobId); + return Result; + } + + std::string ResponseBody(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(ResponseBody, Err); + + if (!Err.empty()) + { + ZEN_WARN("NomadTestClient: invalid JSON in SubmitJob response: {}", Err); + return Result; + } + + Result.Id = std::string(JobId); + Result.Status = "pending"; + + ZEN_INFO("NomadTestClient: job '{}' submitted (eval_id={})", JobId, Json["EvalID"].string_value()); + + return Result; +} + +NomadJobInfo +NomadTestClient::GetJobStatus(std::string_view JobId) +{ + ZEN_TRACE_CPU("GetNomadJobStatus"); + + NomadJobInfo Result; + + HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}", JobId)); + + if (Response.Error) + { + ZEN_WARN("NomadTestClient: GetJobStatus failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return Result; + } + + if (static_cast<int>(Response.StatusCode) == 404) + { + Result.Status = "dead"; + return Result; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: GetJobStatus failed with HTTP/{}", static_cast<int>(Response.StatusCode)); + return Result; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("NomadTestClient: invalid JSON in GetJobStatus response: {}", Err); + return Result; + } + + Result.Id = Json["ID"].string_value(); + Result.Status = Json["Status"].string_value(); + if (const json11::Json Desc = 
Json["StatusDescription"]; Desc.is_string()) + { + Result.StatusDescription = Desc.string_value(); + } + + return Result; +} + +void +NomadTestClient::StopJob(std::string_view JobId) +{ + ZEN_TRACE_CPU("StopNomadJob"); + + HttpClient::Response Response = m_HttpClient.Delete(fmt::format("v1/job/{}", JobId)); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: StopJob failed for '{}'", JobId); + return; + } + + ZEN_INFO("NomadTestClient: job '{}' stopped", JobId); +} + +std::vector<NomadAllocInfo> +NomadTestClient::GetAllocations(std::string_view JobId) +{ + ZEN_TRACE_CPU("GetNomadAllocations"); + + std::vector<NomadAllocInfo> Allocs; + + HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}/allocations", JobId)); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: GetAllocations failed for '{}'", JobId); + return Allocs; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty() || !Json.is_array()) + { + return Allocs; + } + + for (const json11::Json& AllocVal : Json.array_items()) + { + NomadAllocInfo Alloc; + Alloc.Id = AllocVal["ID"].string_value(); + Alloc.ClientStatus = AllocVal["ClientStatus"].string_value(); + + if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object()) + { + for (const auto& [TaskName, TaskState] : TaskStates.object_items()) + { + if (TaskState["State"].is_string()) + { + Alloc.TaskState = TaskState["State"].string_value(); + } + } + } + + Allocs.push_back(std::move(Alloc)); + } + + return Allocs; +} + +std::vector<NomadJobInfo> +NomadTestClient::ListJobs(std::string_view Prefix) +{ + ZEN_TRACE_CPU("ListNomadJobs"); + + std::vector<NomadJobInfo> Jobs; + + std::string Url = "v1/jobs"; + if (!Prefix.empty()) + { + Url = fmt::format("v1/jobs?prefix={}", Prefix); + } + + HttpClient::Response Response = m_HttpClient.Get(Url); + + if (!Response || !Response.IsSuccess()) 
+ { + ZEN_WARN("NomadTestClient: ListJobs failed"); + return Jobs; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty() || !Json.is_array()) + { + return Jobs; + } + + for (const json11::Json& JobVal : Json.array_items()) + { + NomadJobInfo Job; + Job.Id = JobVal["ID"].string_value(); + Job.Status = JobVal["Status"].string_value(); + if (const json11::Json Desc = JobVal["StatusDescription"]; Desc.is_string()) + { + Job.StatusDescription = Desc.string_value(); + } + Jobs.push_back(std::move(Job)); + } + + return Jobs; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp new file mode 100644 index 000000000..3fe9c0ac3 --- /dev/null +++ b/src/zennomad/nomadprovisioner.cpp @@ -0,0 +1,264 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zennomad/nomadclient.h> +#include <zennomad/nomadprovisioner.h> + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +#include <chrono> + +namespace zen::nomad { + +NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint) +: m_Config(Config) +, m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId())) +, m_Log(zen::logging::Get("nomad.provisioner")) +{ + ZEN_DEBUG("initializing provisioner (server: {}, driver: {}, max_cores: {}, cores_per_job: {}, max_jobs: {})", + m_Config.ServerUrl, + ToString(m_Config.TaskDriver), + m_Config.MaxCores, + m_Config.CoresPerJob, + m_Config.MaxJobs); + + m_Client = std::make_unique<NomadClient>(m_Config); + if (!m_Client->Initialize()) + { + ZEN_ERROR("failed to initialize Nomad HTTP client"); + return; + } + + ZEN_DEBUG("Nomad HTTP client initialized, starting management thread"); + + m_Thread = 
std::thread([this] { ManagementThread(); }); +} + +NomadProvisioner::~NomadProvisioner() +{ + ZEN_DEBUG("provisioner shutting down"); + + m_ShouldExit.store(true); + m_WakeCV.notify_all(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } + + StopAllJobs(); + + ZEN_DEBUG("provisioner shutdown complete"); +} + +void +NomadProvisioner::SetTargetCoreCount(uint32_t Count) +{ + const uint32_t Clamped = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)); + const uint32_t Previous = m_TargetCoreCount.exchange(Clamped); + + if (Clamped != Previous) + { + ZEN_DEBUG("target core count changed: {} -> {}", Previous, Clamped); + } + + m_WakeCV.notify_all(); +} + +NomadProvisioningStats +NomadProvisioner::GetStats() const +{ + NomadProvisioningStats Stats; + Stats.TargetCoreCount = m_TargetCoreCount.load(); + Stats.EstimatedCoreCount = m_EstimatedCoreCount.load(); + Stats.RunningJobCount = m_RunningJobCount.load(); + + { + std::lock_guard<std::mutex> Lock(m_JobsLock); + Stats.ActiveJobCount = static_cast<uint32_t>(m_Jobs.size()); + } + + return Stats; +} + +std::string +NomadProvisioner::GenerateJobId() +{ + const uint32_t Index = m_JobIndex.fetch_add(1); + + ExtendableStringBuilder<128> Builder; + Builder << m_Config.JobPrefix << "-" << m_ProcessId << "-" << Index; + return std::string(Builder.ToView()); +} + +void +NomadProvisioner::ManagementThread() +{ + ZEN_TRACE_CPU("Nomad_Mgmt"); + zen::SetCurrentThreadName("nomad_mgmt"); + + ZEN_INFO("Nomad management thread started"); + + while (!m_ShouldExit.load()) + { + ZEN_DEBUG("management cycle: target={} estimated={} running={} active={}", + m_TargetCoreCount.load(), + m_EstimatedCoreCount.load(), + m_RunningJobCount.load(), + [this] { + std::lock_guard<std::mutex> Lock(m_JobsLock); + return m_Jobs.size(); + }()); + + SubmitNewJobs(); + PollExistingJobs(); + CleanupDeadJobs(); + + // Wait up to 5 seconds or until woken + std::unique_lock<std::mutex> Lock(m_WakeMutex); + m_WakeCV.wait_for(Lock, 
std::chrono::seconds(5), [this] { return m_ShouldExit.load(); }); + } + + ZEN_INFO("Nomad management thread exiting"); +} + +void +NomadProvisioner::SubmitNewJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::SubmitNewJobs"); + + const uint32_t CoresPerJob = static_cast<uint32_t>(m_Config.CoresPerJob); + + while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + { + { + std::lock_guard<std::mutex> Lock(m_JobsLock); + if (static_cast<int>(m_Jobs.size()) >= m_Config.MaxJobs) + { + ZEN_INFO("Nomad max jobs limit reached ({})", m_Config.MaxJobs); + break; + } + } + + if (m_ShouldExit.load()) + { + break; + } + + const std::string JobId = GenerateJobId(); + + ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load()); + + const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint); + + NomadJobInfo JobInfo; + JobInfo.Id = JobId; + + if (!m_Client->SubmitJob(JobJson, JobInfo)) + { + ZEN_WARN("failed to submit Nomad job '{}'", JobId); + break; + } + + TrackedJob Tracked; + Tracked.JobId = JobId; + Tracked.Status = "pending"; + Tracked.Cores = static_cast<int>(CoresPerJob); + + { + std::lock_guard<std::mutex> Lock(m_JobsLock); + m_Jobs.push_back(std::move(Tracked)); + } + + m_EstimatedCoreCount.fetch_add(CoresPerJob); + + ZEN_INFO("Nomad job '{}' submitted (estimated cores: {})", JobId, m_EstimatedCoreCount.load()); + } +} + +void +NomadProvisioner::PollExistingJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::PollExistingJobs"); + + std::lock_guard<std::mutex> Lock(m_JobsLock); + + for (auto& Job : m_Jobs) + { + if (m_ShouldExit.load()) + { + break; + } + + NomadJobInfo Info; + if (!m_Client->GetJobStatus(Job.JobId, Info)) + { + ZEN_DEBUG("failed to poll status for job '{}'", Job.JobId); + continue; + } + + const std::string PrevStatus = Job.Status; + Job.Status = Info.Status; + + if (PrevStatus != Job.Status) + { + ZEN_INFO("Nomad job '{}' status changed: {} -> {}", Job.JobId, PrevStatus, 
Job.Status); + + if (Job.Status == "running" && PrevStatus != "running") + { + m_RunningJobCount.fetch_add(1); + } + else if (Job.Status != "running" && PrevStatus == "running") + { + m_RunningJobCount.fetch_sub(1); + } + } + } +} + +void +NomadProvisioner::CleanupDeadJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::CleanupDeadJobs"); + + std::lock_guard<std::mutex> Lock(m_JobsLock); + + for (auto It = m_Jobs.begin(); It != m_Jobs.end();) + { + if (It->Status == "dead") + { + ZEN_INFO("Nomad job '{}' is dead, removing from tracked jobs", It->JobId); + m_EstimatedCoreCount.fetch_sub(static_cast<uint32_t>(It->Cores)); + It = m_Jobs.erase(It); + } + else + { + ++It; + } + } +} + +void +NomadProvisioner::StopAllJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::StopAllJobs"); + + std::lock_guard<std::mutex> Lock(m_JobsLock); + + for (const auto& Job : m_Jobs) + { + ZEN_INFO("stopping Nomad job '{}' during shutdown", Job.JobId); + m_Client->StopJob(Job.JobId); + } + + m_Jobs.clear(); + m_EstimatedCoreCount.store(0); + m_RunningJobCount.store(0); +} + +} // namespace zen::nomad diff --git a/src/zennomad/xmake.lua b/src/zennomad/xmake.lua new file mode 100644 index 000000000..ef1a8b201 --- /dev/null +++ b/src/zennomad/xmake.lua @@ -0,0 +1,10 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. 
+ +target('zennomad') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zenutil") + add_packages("json11") diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index f4b4d592b..43a4937f0 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -8186,7 +8186,7 @@ TEST_CASE("buildstorageoperations.partial.block.download" * doctest::skip(true)) Headers); REQUIRE(GetBlobRangesResponse.IsSuccess()); - MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView(); + [[maybe_unused]] MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView(); std::vector<std::pair<uint64_t, uint64_t>> PayloadRanges = GetBlobRangesResponse.GetRanges(Ranges); if (PayloadRanges.empty()) diff --git a/src/zenremotestore/chunking/chunkingcache.cpp b/src/zenremotestore/chunking/chunkingcache.cpp index f4e1c7837..e9b783a00 100644 --- a/src/zenremotestore/chunking/chunkingcache.cpp +++ b/src/zenremotestore/chunking/chunkingcache.cpp @@ -75,13 +75,13 @@ public: { Lock.ReleaseNow(); RwLock::ExclusiveLockScope EditLock(m_Lock); - if (auto RemoveIt = m_PathHashToEntry.find(PathHash); It != m_PathHashToEntry.end()) + if (auto RemoveIt = m_PathHashToEntry.find(PathHash); RemoveIt != m_PathHashToEntry.end()) { - CachedEntry& DeleteEntry = m_Entries[It->second]; + CachedEntry& DeleteEntry = m_Entries[RemoveIt->second]; DeleteEntry.Chunked = {}; DeleteEntry.ModificationTick = 0; - m_FreeEntryIndexes.push_back(It->second); - m_PathHashToEntry.erase(It); + m_FreeEntryIndexes.push_back(RemoveIt->second); + m_PathHashToEntry.erase(RemoveIt); } } } diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp new file mode 100644 index 000000000..c90ac5d8b --- /dev/null +++ 
b/src/zenserver-test/compute-tests.cpp @@ -0,0 +1,1700 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include <zenbase/zenbase.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/filesystem.h> +# include <zencore/guid.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zenhttp/httpclient.h> +# include <zenhttp/httpserver.h> +# include <zencompute/computeservice.h> +# include <zenstore/zenstore.h> +# include <zenutil/zenserverprocess.h> + +# include "zenserver-test.h" + +# include <thread> + +namespace zen::tests::compute { + +using namespace std::literals; + +// BuildSystemVersion and function version GUIDs matching zentest-appstub +static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969"; +static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313"; +static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888"; + +// In-memory implementation of ChunkResolver for test use. +// Stores compressed data keyed by decompressed content hash. +class InMemoryChunkResolver : public ChunkResolver +{ +public: + IoBuffer FindChunkByCid(const IoHash& DecompressedId) override + { + auto It = m_Chunks.find(DecompressedId); + if (It != m_Chunks.end()) + { + return It->second; + } + return {}; + } + + void AddChunk(const IoHash& DecompressedId, IoBuffer Data) { m_Chunks[DecompressedId] = std::move(Data); } + +private: + std::unordered_map<IoHash, IoBuffer> m_Chunks; +}; + +// Read, compress, and register zentest-appstub as a worker. +// Returns the WorkerId (hash of the worker package object). 
+static IoHash
+RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env)
+{
+	std::filesystem::path AppStubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL); // platform exe suffix (".exe" on Windows)
+
+	FileContents AppStubData = zen::ReadFile(AppStubPath);
+	REQUIRE_MESSAGE(!AppStubData.ErrorCode, fmt::format("Failed to read '{}': {}", AppStubPath.string(), AppStubData.ErrorCode.message()));
+
+	IoBuffer AppStubBuffer = AppStubData.Flatten();
+
+	CompressedBuffer AppStubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(AppStubBuffer.GetData(), AppStubBuffer.Size()),
+																	OodleCompressor::Selkie,
+																	OodleCompressionLevel::HyperFast4); // fastest codec/level — test only needs a valid compressed wrapper
+
+	const IoHash	AppStubRawHash = AppStubCompressed.DecodeRawHash(); // hash of the *decompressed* payload, used as the attachment id
+	const uint64_t AppStubRawSize = AppStubBuffer.Size();
+
+	CbAttachment AppStubAttachment(std::move(AppStubCompressed), AppStubRawHash);
+
+	CbObjectWriter WorkerWriter;
+	WorkerWriter << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion);
+	WorkerWriter << "path"sv
+				 << "zentest-appstub"sv;
+
+	WorkerWriter.BeginArray("executables"sv);
+	WorkerWriter.BeginObject();
+	WorkerWriter << "name"sv
+				 << "zentest-appstub"sv;
+	WorkerWriter.AddAttachment("hash"sv, AppStubAttachment); // binds the executable entry to the attachment payload
+	WorkerWriter << "size"sv << AppStubRawSize;
+	WorkerWriter.EndObject();
+	WorkerWriter.EndArray();
+
+	WorkerWriter.BeginArray("functions"sv); // advertise both functions the stub implements
+	WorkerWriter.BeginObject();
+	WorkerWriter << "name"sv
+				 << "Rot13"sv;
+	WorkerWriter << "version"sv << Guid::FromString(kRot13Version);
+	WorkerWriter.EndObject();
+	WorkerWriter.BeginObject();
+	WorkerWriter << "name"sv
+				 << "Sleep"sv;
+	WorkerWriter << "version"sv << Guid::FromString(kSleepVersion);
+	WorkerWriter.EndObject();
+	WorkerWriter.EndArray();
+
+	CbPackage WorkerPackage;
+	WorkerPackage.SetObject(WorkerWriter.Save());
+	WorkerPackage.AddAttachment(AppStubAttachment); // ship descriptor + executable bytes in one POST (no needs/follow-up dance)
+
+	const IoHash WorkerId = WorkerPackage.GetObjectHash(); // worker identity == hash of descriptor object
+
+	const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
+
+	HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage));
+	REQUIRE_MESSAGE(RegisterResp,
+					fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText()));
+
+	return WorkerId;
+}
+
+// Build a Rot13 action CbPackage for the given input string.
+static CbPackage
+BuildRot13ActionPackage(std::string_view Input)
+{
+	CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+																  OodleCompressor::Selkie,
+																  OodleCompressionLevel::HyperFast4);
+
+	const IoHash	InputRawHash = InputCompressed.DecodeRawHash();
+	const uint64_t InputRawSize = Input.size();
+
+	CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+	CbObjectWriter ActionWriter;
+	ActionWriter << "Function"sv
+				 << "Rot13"sv;
+	ActionWriter << "FunctionVersion"sv << Guid::FromString(kRot13Version); // must match the version advertised at registration
+	ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+	ActionWriter.BeginObject("Inputs"sv);
+	ActionWriter.BeginObject("Source"sv); // "Source" is the input slot name the stub reads
+	ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+	ActionWriter << "RawSize"sv << InputRawSize;
+	ActionWriter.EndObject();
+	ActionWriter.EndObject();
+
+	CbPackage ActionPackage;
+	ActionPackage.SetObject(ActionWriter.Save());
+	ActionPackage.AddAttachment(InputAttachment);
+
+	return ActionPackage;
+}
+
+// Build a Sleep action CbPackage. The worker sleeps for SleepTimeMs before returning its input.
+static CbPackage
+BuildSleepActionPackage(std::string_view Input, uint64_t SleepTimeMs)
+{
+	CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+																  OodleCompressor::Selkie,
+																  OodleCompressionLevel::HyperFast4);
+
+	const IoHash	InputRawHash = InputCompressed.DecodeRawHash();
+	const uint64_t InputRawSize = Input.size();
+
+	CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+	CbObjectWriter ActionWriter;
+	ActionWriter << "Function"sv
+				 << "Sleep"sv;
+	ActionWriter << "FunctionVersion"sv << Guid::FromString(kSleepVersion);
+	ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+	ActionWriter.BeginObject("Inputs"sv);
+	ActionWriter.BeginObject("Source"sv);
+	ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+	ActionWriter << "RawSize"sv << InputRawSize;
+	ActionWriter.EndObject();
+	ActionWriter.EndObject();
+	ActionWriter.BeginObject("Constants"sv); // differs from Rot13 only by this constants block
+	ActionWriter << "SleepTimeMs"sv << SleepTimeMs;
+	ActionWriter.EndObject();
+
+	CbPackage ActionPackage;
+	ActionPackage.SetObject(ActionWriter.Save());
+	ActionPackage.AddAttachment(InputAttachment);
+
+	return ActionPackage;
+}
+
+// Build a Sleep action CbObject and populate the chunk resolver with the input attachment.
+static CbObject
+BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemoryChunkResolver& Resolver)
+{
+	CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+																  OodleCompressor::Selkie,
+																  OodleCompressionLevel::HyperFast4);
+
+	const IoHash	InputRawHash = InputCompressed.DecodeRawHash();
+	const uint64_t InputRawSize = Input.size();
+
+	Resolver.AddChunk(InputRawHash, InputCompressed.GetCompressed().Flatten().AsIoBuffer()); // session flow: chunks are pulled via the resolver instead of shipped in a package
+
+	CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+	CbObjectWriter ActionWriter;
+	ActionWriter << "Function"sv
+				 << "Sleep"sv;
+	ActionWriter << "FunctionVersion"sv << Guid::FromString(kSleepVersion);
+	ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+	ActionWriter.BeginObject("Inputs"sv);
+	ActionWriter.BeginObject("Source"sv);
+	ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+	ActionWriter << "RawSize"sv << InputRawSize;
+	ActionWriter.EndObject();
+	ActionWriter.EndObject();
+	ActionWriter.BeginObject("Constants"sv);
+	ActionWriter << "SleepTimeMs"sv << SleepTimeMs;
+	ActionWriter.EndObject();
+
+	return ActionWriter.Save(); // returns the bare object; attachments travel via Resolver
+}
+
+static HttpClient::Response
+PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000)
+{
+	HttpClient::Response Resp;
+	Stopwatch			 Timer;
+
+	while (Timer.GetElapsedTimeMs() < TimeoutMs)
+	{
+		Resp = Client.Get(ResultUrl);
+
+		if (Resp.StatusCode == HttpResponseCode::OK)
+		{
+			break;
+		}
+
+		Sleep(100); // 100 ms poll interval; on timeout the last non-OK response is returned for the caller to report
+	}
+
+	return Resp;
+}
+
+static bool
+PollForLsnInCompleted(HttpClient& Client, const std::string& CompletedUrl, int Lsn, uint64_t TimeoutMs = 30'000)
+{
+	Stopwatch Timer;
+
+	while (Timer.GetElapsedTimeMs() < TimeoutMs)
+	{
+		HttpClient::Response Resp = Client.Get(CompletedUrl);
+
+		if (Resp) // tolerate transient failures; keep polling until timeout
+		{
+			for (auto& Item : Resp.AsObject()["completed"sv])
+			{
+				if (Item.AsInt32() == Lsn)
+				{
+					return true;
+				}
+			}
+		}
+
+		Sleep(100);
+	}
+
+	return false; // LSN never showed up in the completed list within TimeoutMs
+}
+
+static std::string
+GetRot13Output(const CbPackage& ResultPackage)
+{
+	CbObject ResultObj = ResultPackage.GetObject();
+
+	IoHash		OutputHash;
+	CbFieldView ValuesField = ResultObj["Values"sv];
+
+	if (CbFieldViewIterator It = begin(ValuesField); It.HasValue()) // only the first entry of Values is inspected — Rot13 produces a single output
+	{
+		OutputHash = (*It).AsObjectView()["RawHash"sv].AsHash();
+	}
+
+	REQUIRE_MESSAGE(OutputHash != IoHash::Zero, "Expected non-zero output hash in result Values array");
+
+	const CbAttachment* OutputAttachment = ResultPackage.FindAttachment(OutputHash);
+	REQUIRE_MESSAGE(OutputAttachment != nullptr, "Output attachment not found in result package");
+
+	CompressedBuffer OutputCompressed = OutputAttachment->AsCompressedBinary();
+	SharedBuffer	 OutputData = OutputCompressed.Decompress();
+
+	return std::string(static_cast<const char*>(OutputData.GetData()), OutputData.GetSize());
+}
+
+// Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response.
+class MockOrchestratorService : public HttpService
+{
+public:
+	MockOrchestratorService()
+	{
+		// Initialize with empty worker list
+		CbObjectWriter Cbo;
+		Cbo.BeginArray("workers"sv);
+		Cbo.EndArray();
+		m_WorkerList = Cbo.Save();
+	}
+
+	const char* BaseUri() const override { return "/orch/"; }
+
+	void HandleRequest(HttpServerRequest& Request) override
+	{
+		if (Request.RequestVerb() == HttpVerb::kGet && Request.RelativeUri() == "agents"sv)
+		{
+			RwLock::SharedLockScope Lock(m_Lock); // shared lock: many concurrent readers, writer only in SetWorkerList
+			Request.WriteResponse(HttpResponseCode::OK, m_WorkerList);
+			return;
+		}
+		Request.WriteResponse(HttpResponseCode::NotFound); // everything except GET /orch/agents is unsupported
+	}
+
+	void SetWorkerList(CbObject WorkerList)
+	{
+		RwLock::ExclusiveLockScope Lock(m_Lock);
+		m_WorkerList = std::move(WorkerList);
+	}
+
+private:
+	RwLock	 m_Lock;
+	CbObject m_WorkerList;
+};
+
+// Manages in-process ASIO HTTP server lifecycle for mock orchestrator.
+struct MockOrchestratorFixture
+{
+	MockOrchestratorService	  Service;
+	ScopedTemporaryDirectory  TmpDir;
+	Ref<HttpServer>			  Server;
+	std::thread				  ServerThread;
+	uint16_t				  Port = 0;
+
+	MockOrchestratorFixture()
+	{
+		HttpServerConfig Config;
+		Config.ServerClass = "asio";
+		Config.ForceLoopback = true; // never expose the mock on external interfaces
+		Server = CreateHttpServer(Config);
+		Server->RegisterService(Service);
+		Port = static_cast<uint16_t>(Server->Initialize(TestEnv.GetNewPortNumber(), TmpDir.Path()));
+		ZEN_ASSERT(Port != 0);
+		ServerThread = std::thread([this]() { Server->Run(false); }); // server loop runs on its own thread for the fixture's lifetime
+	}
+
+	~MockOrchestratorFixture()
+	{
+		Server->RequestExit(); // signal first, then join, then close — order matters for clean shutdown
+		if (ServerThread.joinable())
+		{
+			ServerThread.join();
+		}
+		Server->Close();
+	}
+
+	std::string GetEndpoint() const { return fmt::format("http://localhost:{}", Port); }
+};
+
+// Build the CbObject response for /orch/agents matching the format UpdateCoordinatorState expects.
+static CbObject
+BuildAgentListResponse(std::initializer_list<std::pair<std::string_view, std::string_view>> Workers)
+{
+	CbObjectWriter Cbo;
+	Cbo.BeginArray("workers"sv);
+	for (const auto& [Id, Uri] : Workers)
+	{
+		Cbo.BeginObject();
+		Cbo << "id"sv << Id;
+		Cbo << "uri"sv << Uri;
+		Cbo << "hostname"sv
+			<< "localhost"sv;
+		Cbo << "reachable"sv << true;
+		Cbo << "dt"sv << uint64_t(0); // "dt" presumably a last-seen delta/timestamp — zero means "just seen"; confirm against UpdateCoordinatorState
+		Cbo.EndObject();
+	}
+	Cbo.EndArray();
+	return Cbo.Save();
+}
+
+// Build the worker CbPackage for zentest-appstub AND populate the chunk resolver.
+// This is the same logic as RegisterWorker() but returns the package instead of POSTing it.
+static CbPackage
+BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver)
+{
+	std::filesystem::path AppStubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL);
+
+	FileContents AppStubData = zen::ReadFile(AppStubPath);
+	REQUIRE_MESSAGE(!AppStubData.ErrorCode, fmt::format("Failed to read '{}': {}", AppStubPath.string(), AppStubData.ErrorCode.message()));
+
+	IoBuffer AppStubBuffer = AppStubData.Flatten();
+
+	CompressedBuffer AppStubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(AppStubBuffer.GetData(), AppStubBuffer.Size()),
+																	OodleCompressor::Selkie,
+																	OodleCompressionLevel::HyperFast4);
+
+	const IoHash	AppStubRawHash = AppStubCompressed.DecodeRawHash();
+	const uint64_t AppStubRawSize = AppStubBuffer.Size();
+
+	// Store compressed data in chunk resolver for when the remote runner needs it
+	Resolver.AddChunk(AppStubRawHash, AppStubCompressed.GetCompressed().Flatten().AsIoBuffer());
+
+	CbAttachment AppStubAttachment(std::move(AppStubCompressed), AppStubRawHash);
+
+	CbObjectWriter WorkerWriter; // NOTE(review): descriptor-building below duplicates RegisterWorker() — consider a shared helper in a follow-up
+	WorkerWriter << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion);
+	WorkerWriter << "path"sv
+				 << "zentest-appstub"sv;
+
+	WorkerWriter.BeginArray("executables"sv);
+	WorkerWriter.BeginObject();
+	WorkerWriter << "name"sv
+				 << "zentest-appstub"sv;
+	WorkerWriter.AddAttachment("hash"sv, AppStubAttachment);
+	WorkerWriter << "size"sv << AppStubRawSize;
+	WorkerWriter.EndObject();
+	WorkerWriter.EndArray();
+
+	WorkerWriter.BeginArray("functions"sv);
+	WorkerWriter.BeginObject();
+	WorkerWriter << "name"sv
+				 << "Rot13"sv;
+	WorkerWriter << "version"sv << Guid::FromString(kRot13Version);
+	WorkerWriter.EndObject();
+	WorkerWriter.BeginObject();
+	WorkerWriter << "name"sv
+				 << "Sleep"sv;
+	WorkerWriter << "version"sv << Guid::FromString(kSleepVersion);
+	WorkerWriter.EndObject();
+	WorkerWriter.EndArray();
+
+	CbPackage WorkerPackage;
+	WorkerPackage.SetObject(WorkerWriter.Save());
+	WorkerPackage.AddAttachment(AppStubAttachment);
+
+	return WorkerPackage;
+}
+
+// Build a Rot13 action CbObject (not CbPackage) and populate the chunk resolver with the input attachment.
+static CbObject
+BuildRot13ActionForSession(std::string_view Input, InMemoryChunkResolver& Resolver)
+{
+	CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()),
+																  OodleCompressor::Selkie,
+																  OodleCompressionLevel::HyperFast4);
+
+	const IoHash	InputRawHash = InputCompressed.DecodeRawHash();
+	const uint64_t InputRawSize = Input.size();
+
+	// Store compressed data in chunk resolver
+	Resolver.AddChunk(InputRawHash, InputCompressed.GetCompressed().Flatten().AsIoBuffer());
+
+	CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash);
+
+	CbObjectWriter ActionWriter;
+	ActionWriter << "Function"sv
+				 << "Rot13"sv;
+	ActionWriter << "FunctionVersion"sv << Guid::FromString(kRot13Version);
+	ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion);
+	ActionWriter.BeginObject("Inputs"sv);
+	ActionWriter.BeginObject("Source"sv);
+	ActionWriter.AddAttachment("RawHash"sv, InputAttachment);
+	ActionWriter << "RawSize"sv << InputRawSize;
+	ActionWriter.EndObject();
+	ActionWriter.EndObject();
+
+	return ActionWriter.Save();
+}
+
+TEST_SUITE_BEGIN("server.function");
+
+TEST_CASE("function.rot13") // end-to-end: register worker, run one Rot13 action via the legacy /jobs API, verify output
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Submit action via legacy /jobs/{worker} endpoint
+	const std::string	 JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+	REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+	const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+	REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission");
+
+	// Poll for result via legacy /jobs/{lsn} endpoint
+	const std::string	 ResultUrl = fmt::format("/jobs/{}", Lsn);
+	HttpClient::Response ResultResp = PollForResult(Client, ResultUrl);
+	REQUIRE_MESSAGE(
+		ResultResp.StatusCode == HttpResponseCode::OK,
+		fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+
+	// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
+	CbPackage ResultPackage = ResultResp.AsPackage();
+	REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Action failed (empty result package)\nServer log:\n{}", Instance.GetLogOutput()));
+
+	CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+}
+
+TEST_CASE("function.workers") // worker registry: listing before/after registration, descriptor round-trip, 404 for unknown id
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	// Before registration, GET /workers should return an empty list
+	HttpClient::Response EmptyListResp = Client.Get("/workers"sv);
+	REQUIRE_MESSAGE(EmptyListResp, "Failed to list workers before registration");
+	CHECK_EQ(EmptyListResp.AsObject()["workers"sv].AsArrayView().Num(), 0);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// GET /workers — the registered worker should appear in the listing
+	HttpClient::Response ListResp = Client.Get("/workers"sv);
+	REQUIRE_MESSAGE(ListResp, "Failed to list workers after registration");
+
+	bool WorkerFound = false;
+	for (auto& Item : ListResp.AsObject()["workers"sv])
+	{
+		if (Item.AsHash() == WorkerId)
+		{
+			WorkerFound = true;
+			break;
+		}
+	}
+
+	REQUIRE_MESSAGE(WorkerFound, fmt::format("Worker {} not found in worker listing", WorkerId.ToHexString()));
+
+	// GET /workers/{worker} — descriptor should match what was registered
+	const std::string	 WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString());
+	HttpClient::Response DescResp = Client.Get(WorkerUrl);
+	REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode)));
+
+	CbObject Desc = DescResp.AsObject();
+	CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion));
+	CHECK_EQ(Desc["path"sv].AsString(), "zentest-appstub"sv);
+
+	bool Rot13Found = false;
+	bool SleepFound = false;
+	for (auto& Item : Desc["functions"sv])
+	{
+		std::string_view Name = Item.AsObjectView()["name"sv].AsString();
+		if (Name == "Rot13"sv)
+		{
+			CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kRot13Version));
+			Rot13Found = true;
+		}
+		else if (Name == "Sleep"sv)
+		{
+			CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kSleepVersion));
+			SleepFound = true;
+		}
+	}
+
+	CHECK_MESSAGE(Rot13Found, "Rot13 function not found in worker descriptor");
+	CHECK_MESSAGE(SleepFound, "Sleep function not found in worker descriptor");
+
+	// GET /workers/{unknown} — should return 404
+	const std::string	 UnknownUrl = fmt::format("/workers/{}", IoHash::Zero.ToHexString());
+	HttpClient::Response NotFoundResp = Client.Get(UnknownUrl);
+	CHECK_EQ(NotFoundResp.StatusCode, HttpResponseCode::NotFound);
+}
+
+TEST_CASE("function.queues.lifecycle") // create queue, run a job through it, verify completed/active/failed counters and "active" state
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+
+	const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
+
+	// Verify the queue appears in the listing
+	HttpClient::Response ListResp = Client.Get("/queues"sv);
+	REQUIRE_MESSAGE(ListResp, "Failed to list queues");
+
+	bool QueueFound = false;
+	for (auto& Item : ListResp.AsObject()["queues"sv])
+	{
+		if (Item.AsObjectView()["queue_id"sv].AsInt32() == QueueId)
+		{
+			QueueFound = true;
+			break;
+		}
+	}
+
+	REQUIRE_MESSAGE(QueueFound, fmt::format("Queue {} not found in queue listing", QueueId));
+
+	// Submit action via queue-scoped endpoint
+	const std::string	 JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+	REQUIRE_MESSAGE(SubmitResp,
+					fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+	const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+	REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission");
+
+	// Poll for completion via queue-scoped /completed endpoint
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+	REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+					fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}",
+								Lsn,
+								QueueId,
+								Instance.GetLogOutput()));
+
+	// Retrieve result via queue-scoped /jobs/{lsn} endpoint
+	const std::string	 ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn);
+	HttpClient::Response ResultResp = Client.Get(ResultUrl);
+	REQUIRE_MESSAGE(
+		ResultResp.StatusCode == HttpResponseCode::OK,
+		fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput()));
+
+	// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
+	CbPackage ResultPackage = ResultResp.AsPackage();
+	REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
+
+	CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+
+	// Verify queue status reflects completion
+	const std::string	 StatusUrl = fmt::format("/queues/{}", QueueId);
+	HttpClient::Response StatusResp = Client.Get(StatusUrl);
+	REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1);
+	CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0);
+	CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 0);
+	CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "active");
+}
+
+TEST_CASE("function.queues.cancel") // DELETE on a queue with a pending job flips state to "cancelled"
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+	const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
+
+	// Submit a job
+	const std::string	 JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+	REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+	// Cancel the queue
+	const std::string	 QueueUrl = fmt::format("/queues/{}", QueueId);
+	HttpClient::Response CancelResp = Client.Delete(QueueUrl);
+	REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
+					fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+
+	// Verify queue status shows cancelled
+	HttpClient::Response StatusResp = Client.Get(QueueUrl); // cancelled queues remain queryable for status
+	REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled");
+}
+
+TEST_CASE("function.queues.remote") // same lifecycle as above but addressed via the opaque queue_token instead of the integer id
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a remote queue — response includes both an integer queue_id and an OID queue_token
+	HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
+	REQUIRE_MESSAGE(CreateResp,
+					fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+
+	CbObject		  CreateObj = CreateResp.AsObject();
+	const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString());
+	REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation");
+
+	// All subsequent requests use the opaque token in place of the integer queue id
+	const std::string	 JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv));
+	REQUIRE_MESSAGE(SubmitResp,
+					fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+	const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+	REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission");
+
+	// Poll for completion via the token-addressed /completed endpoint
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken);
+	REQUIRE_MESSAGE(
+		PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+		fmt::format("LSN {} did not appear in remote queue completed list within timeout\nServer log:\n{}", Lsn, Instance.GetLogOutput()));
+
+	// Retrieve result via the token-addressed /jobs/{lsn} endpoint
+	const std::string	 ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, Lsn);
+	HttpClient::Response ResultResp = Client.Get(ResultUrl);
+	REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK,
+					fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}",
+								int(ResultResp.StatusCode),
+								Instance.GetLogOutput()));
+
+	// Verify result: Rot13("Hello World") == "Uryyb Jbeyq"
+	CbPackage ResultPackage = ResultResp.AsPackage();
+	REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput()));
+
+	CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv);
+}
+
+TEST_CASE("function.queues.cancel_running") // cancelling a queue must interrupt an in-flight (sleeping) worker process
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+	const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation");
+
+	// Submit a Sleep job long enough that it will still be running when we cancel
+	const std::string	 JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
+	REQUIRE_MESSAGE(SubmitResp,
+					fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+	const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+	REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
+
+	// Wait for the worker process to start executing before cancelling
+	Sleep(1'000); // NOTE(review): fixed wait is a potential flake on slow CI — prefer polling active_count if the API exposes it
+
+	// Cancel the queue, which should interrupt the running Sleep job
+	const std::string	 QueueUrl = fmt::format("/queues/{}", QueueId);
+	HttpClient::Response CancelResp = Client.Delete(QueueUrl);
+	REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
+					fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+
+	// The cancelled job should appear in the /completed endpoint once the process exits
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+	REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+					fmt::format("LSN {} did not appear in queue {} completed list after cancel\nServer log:\n{}",
+								Lsn,
+								QueueId,
+								Instance.GetLogOutput()));
+
+	// Verify the queue reflects one cancelled action
+	HttpClient::Response StatusResp = Client.Get(QueueUrl);
+	REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled");
+	CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1); // job counts as cancelled, not completed
+	CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+TEST_CASE("function.queues.remote_cancel") // token-addressed variant of cancel_running
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a remote queue to obtain an OID token for token-addressed cancellation
+	HttpClient::Response CreateResp = Client.Post("/queues/remote"sv);
+	REQUIRE_MESSAGE(CreateResp,
+					fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText()));
+
+	const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString());
+	REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation");
+
+	// Submit a long-running Sleep job via the token-addressed endpoint
+	const std::string	 JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString());
+	HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000));
+	REQUIRE_MESSAGE(SubmitResp,
+					fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText()));
+
+	const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32();
+	REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission");
+
+	// Wait for the worker process to start executing before cancelling
+	Sleep(1'000); // NOTE(review): same fixed-wait flake risk as function.queues.cancel_running
+
+	// Cancel the queue via its OID token
+	const std::string	 QueueUrl = fmt::format("/queues/{}", QueueToken);
+	HttpClient::Response CancelResp = Client.Delete(QueueUrl);
+	REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent,
+					fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText()));
+
+	// The cancelled job should appear in the token-addressed /completed endpoint
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken);
+	REQUIRE_MESSAGE(
+		PollForLsnInCompleted(Client, CompletedUrl, Lsn),
+		fmt::format("LSN {} did not appear in remote queue completed list after cancel\nServer log:\n{}", Lsn, Instance.GetLogOutput()));
+
+	// Verify the queue status reflects the cancellation
+	HttpClient::Response StatusResp = Client.Get(QueueUrl);
+	REQUIRE_MESSAGE(StatusResp, "Failed to get remote queue status after cancel");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled");
+	CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1);
+	CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0);
+}
+
+TEST_CASE("function.queues.drain") // drain: in-flight work finishes, new submissions are rejected with 424
+{
+	ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer);
+	Instance.SetDataDir(TestEnv.CreateNewTestDir());
+	const uint16_t Port = Instance.SpawnServerAndWaitUntilReady();
+	REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput());
+
+	const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port);
+	HttpClient		  Client(ComputeBaseUri);
+
+	const IoHash WorkerId = RegisterWorker(Client, TestEnv);
+
+	// Create a queue
+	HttpClient::Response CreateResp = Client.Post("/queues"sv);
+	REQUIRE_MESSAGE(CreateResp, "Queue creation failed");
+
+	const int		  QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32();
+	const std::string QueueUrl = fmt::format("/queues/{}", QueueId);
+
+	// Submit a long-running job so we can verify it completes even after drain
+	const std::string	 JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString());
+	HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000));
+	REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode)));
+	const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32();
+
+	// Drain the queue
+	const std::string	 DrainUrl = fmt::format("/queues/{}/drain", QueueId);
+	HttpClient::Response DrainResp = Client.Post(DrainUrl);
+	REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText()));
+	CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining");
+
+	// Second submission should be rejected with 424
+	HttpClient::Response Submit2 = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv));
+	CHECK_EQ(Submit2.StatusCode, HttpResponseCode::FailedDependency);
+	CHECK_EQ(std::string(Submit2.AsObject()["error"sv].AsString()), "queue is draining");
+
+	// First job should still complete
+	const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId);
+	REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn1),
+					fmt::format("LSN {} did not complete after drain\nServer log:\n{}", Lsn1, Instance.GetLogOutput()));
+
+	// Queue status should show draining + complete
+	HttpClient::Response StatusResp = Client.Get(QueueUrl);
+	REQUIRE_MESSAGE(StatusResp, "Failed to get queue status");
+
+	CbObject QueueStatus = StatusResp.AsObject();
+	CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "draining"); // drain is terminal-ish: state stays "draining" even when work is done
+	CHECK(QueueStatus["is_complete"sv].AsBool());
+}
+
+TEST_CASE("function.priority")
+{
+	// Spawn server with max-actions=1 to guarantee serialized action execution,
+	// which lets us deterministically verify that higher-priority pending jobs
+	// are scheduled before lower-priority ones.
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1"); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue for all test jobs + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id"); + + // Submit a blocker Sleep job to occupy the single execution slot. + // Once the blocker is running, the scheduler must choose among the pending + // jobs by priority when the slot becomes free. + const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); + HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000)); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode))); + + // Submit 3 low-priority Rot13 jobs + const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); + + HttpClient::Response LowResp1 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low1"sv)); + REQUIRE_MESSAGE(LowResp1, "Low-priority job 1 submission failed"); + const int LsnLow1 = LowResp1.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response LowResp2 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low2"sv)); + REQUIRE_MESSAGE(LowResp2, "Low-priority job 2 submission failed"); + const int LsnLow2 = LowResp2.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response LowResp3 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low3"sv)); + 
REQUIRE_MESSAGE(LowResp3, "Low-priority job 3 submission failed"); + const int LsnLow3 = LowResp3.AsObject()["lsn"sv].AsInt32(); + + // Submit 1 high-priority Rot13 job — should execute before the low-priority ones + const std::string HighJobUrl = fmt::format("/queues/{}/jobs/{}?priority=10", QueueId, WorkerId.ToHexString()); + HttpClient::Response HighResp = Client.Post(HighJobUrl, BuildRot13ActionPackage("high"sv)); + REQUIRE_MESSAGE(HighResp, "High-priority job submission failed"); + const int LsnHigh = HighResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for all 4 priority-test jobs to appear in the queue's completed list. + // This avoids any snapshot-timing race: by the time we compare timestamps, all + // jobs have already finished and their history entries are stable. + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + + { + bool AllCompleted = false; + Stopwatch WaitTimer; + + while (!AllCompleted && WaitTimer.GetElapsedTimeMs() < 30'000) + { + HttpClient::Response Resp = Client.Get(CompletedUrl); + + if (Resp) + { + bool FoundHigh = false; + bool FoundLow1 = false; + bool FoundLow2 = false; + bool FoundLow3 = false; + + CbObject RespObj = Resp.AsObject(); + + for (auto& Item : RespObj["completed"sv]) + { + const int Lsn = Item.AsInt32(); + if (Lsn == LsnHigh) + { + FoundHigh = true; + } + else if (Lsn == LsnLow1) + { + FoundLow1 = true; + } + else if (Lsn == LsnLow2) + { + FoundLow2 = true; + } + else if (Lsn == LsnLow3) + { + FoundLow3 = true; + } + } + + AllCompleted = FoundHigh && FoundLow1 && FoundLow2 && FoundLow3; + } + + if (!AllCompleted) + { + Sleep(100); + } + } + + REQUIRE_MESSAGE( + AllCompleted, + fmt::format( + "Not all priority test jobs completed within timeout (lsnHigh={} lsnLow1={} lsnLow2={} lsnLow3={})\nServer log:\n{}", + LsnHigh, + LsnLow1, + LsnLow2, + LsnLow3, + Instance.GetLogOutput())); + } + + // Query the queue-scoped history to obtain the time_Completed timestamp for each + // job. 
The history endpoint records when each RunnerAction::State transition + // occurred, so time_Completed is the wall-clock tick at which the action finished. + // Using the queue-scoped endpoint avoids exposing history from other queues. + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + CbObject HistoryObj = HistoryResp.AsObject(); + + auto GetCompletedTimestamp = [&](int Lsn) -> uint64_t { + for (auto& Item : HistoryObj["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + return Item.AsObjectView()["time_Completed"sv].AsUInt64(); + } + } + return 0; + }; + + const uint64_t TimeHigh = GetCompletedTimestamp(LsnHigh); + const uint64_t TimeLow1 = GetCompletedTimestamp(LsnLow1); + const uint64_t TimeLow2 = GetCompletedTimestamp(LsnLow2); + const uint64_t TimeLow3 = GetCompletedTimestamp(LsnLow3); + + REQUIRE_MESSAGE(TimeHigh != 0, fmt::format("lsnHigh={} not found in action history", LsnHigh)); + REQUIRE_MESSAGE(TimeLow1 != 0, fmt::format("lsnLow1={} not found in action history", LsnLow1)); + REQUIRE_MESSAGE(TimeLow2 != 0, fmt::format("lsnLow2={} not found in action history", LsnLow2)); + REQUIRE_MESSAGE(TimeLow3 != 0, fmt::format("lsnLow3={} not found in action history", LsnLow3)); + + // The high-priority job must have completed strictly before every low-priority job + CHECK_MESSAGE(TimeHigh < TimeLow1, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow1={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow1, + TimeLow1)); + CHECK_MESSAGE(TimeHigh < TimeLow2, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow2={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow2, + TimeLow2)); + CHECK_MESSAGE(TimeHigh < TimeLow3, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} 
but lsnLow3={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow3, + TimeLow3)); +} + +////////////////////////////////////////////////////////////////////////// +// Remote worker synchronization tests +// +// These tests exercise the orchestrator discovery path where new compute +// nodes appear over time and must receive previously registered workers +// via SyncWorkersToRunner(). + +TEST_CASE("function.remote.worker_sync_on_discovery") +{ + // Spawn real zenserver in compute mode + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t ServerPort = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(ServerPort != 0, Instance.GetLogOutput()); + + const std::string ServerUri = fmt::format("http://localhost:{}", ServerPort); + + // Start mock orchestrator with empty worker list + MockOrchestratorFixture MockOrch; + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session (stored locally, no runners yet) + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Update mock orchestrator to advertise the real server + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}})); + + // Wait for scheduler to discover the runner (~5s throttle + margin) + Sleep(7'000); + + // Submit Rot13 action via session + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Action 
enqueue failed"); + + // Poll for result + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE( + ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.late_runner_discovery") +{ + // Spawn first server + ZenServerInstance Instance1(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance1.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port1 = Instance1.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port1 != 0, Instance1.GetLogOutput()); + + const std::string ServerUri1 = fmt::format("http://localhost:{}", Port1); + + // Start mock orchestrator advertising W1 + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}})); + + // Create session and register worker + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for W1 discovery + Sleep(7'000); + + // Baseline: submit Rot13 action and verify it completes on W1 + { + CbObject ActionObj = BuildRot13ActionForSession("Hello 
World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance1.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + } + + // Spawn second server + ZenServerInstance Instance2(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance2.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port2 = Instance2.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port2 != 0, Instance2.GetLogOutput()); + + const std::string ServerUri2 = fmt::format("http://localhost:{}", Port2); + + // Update mock orchestrator to include both W1 and W2 + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}})); + + // Wait for W2 discovery + Sleep(7'000); + + // Verify W2 received the worker by querying its /compute/workers endpoint directly + { + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2); + HttpClient Client(ComputeBaseUri); + HttpClient::Response ListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list workers on W2"); + + bool WorkerFound = false; + for (auto& Item : ListResp.AsObject()["workers"sv]) + { + if (Item.AsHash() == WorkerPackage.GetObjectHash()) + { + WorkerFound = true; + break; + } + } + + REQUIRE_MESSAGE(WorkerFound, + fmt::format("Worker not found on W2 after discovery — SyncWorkersToRunner may have failed\nW2 log:\n{}", + Instance2.GetLogOutput())); + 
} + + // Submit another action and verify it completes (could run on either W1 or W2) + { + CbObject ActionObj = BuildRot13ActionForSession("Second Test"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Second action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Second action did not complete in time\nW1 log:\n{}\nW2 log:\n{}", + Instance1.GetLogOutput(), + Instance2.GetLogOutput())); + + // Rot13("Second Test") = "Frpbaq Grfg" + CHECK_EQ(GetRot13Output(ResultPackage), "Frpbaq Grfg"sv); + } + + Session.Shutdown(); +} + +TEST_CASE("function.remote.queue_association") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + 
Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit action to it + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Action enqueue to queue failed"); + + // Poll for result + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE( + ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. 
Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + // Verify that a non-implicit remote queue was created on the compute node + HttpClient Client(Instance.GetBaseUri() + "/compute"); + + HttpClient::Response QueuesResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server"); + + bool RemoteQueueFound = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueFound = true; + break; + } + } + + CHECK_MESSAGE(RemoteQueueFound, "Expected a non-implicit remote queue on the compute node"); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.queue_cancel_propagation") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit 
a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Cancel the local queue — this should propagate to the remote + Session.CancelQueue(QueueId); + + // Poll for the action to complete (as cancelled) + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + // Verify the local queue shows cancelled + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK(QueueStatus.State == zen::compute::ComputeServiceSession::QueueState::Cancelled); + + // Verify the remote queue on the compute node is also cancelled + HttpClient Client(Instance.GetBaseUri() + "/compute"); + + HttpClient::Response QueuesResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server"); + + bool RemoteQueueCancelled = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled"; + break; + } + } + + CHECK_MESSAGE(RemoteQueueCancelled, "Expected the remote queue to be cancelled"); + + Session.Shutdown(); +} + +TEST_CASE("function.abandon_running_http") +{ + // Spawn a real zenserver to execute a long-running action, then abandon via HTTP endpoint + ZenServerInstance 
Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue and submit a long-running Sleep job + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id"); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode))); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN"); + + // Wait for the process to start running + Sleep(1'000); + + // Verify the ready endpoint returns OK before abandon + { + HttpClient::Response ReadyResp = Client.Get("/ready"sv); + CHECK(ReadyResp.StatusCode == HttpResponseCode::OK); + } + + // Trigger abandon via the HTTP endpoint + HttpClient::Response AbandonResp = Client.Post("/abandon"sv); + REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK, + fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText())); + + // Ready endpoint should now return 503 + { + HttpClient::Response ReadyResp = Client.Get("/ready"sv); + CHECK(ReadyResp.StatusCode == HttpResponseCode::ServiceUnavailable); + } + + // The abandoned action should appear in the completed endpoint once the process exits + const std::string CompletedUrl 
= fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list after abandon\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the queue reflects one abandoned action + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after abandon"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["abandoned_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); + + // Submitting new work should be rejected + HttpClient::Response RejectedResp = Client.Post(JobUrl, BuildRot13ActionPackage("rejected"sv)); + CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state"); +} + +TEST_CASE("function.session.abandon_pending") +{ + // Create a session with no runners so actions stay pending + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Enqueue several actions — they will stay pending because there are no runners + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + + CbObject ActionObj = BuildRot13ActionForSession("abandon-test"sv, Resolver); + + auto Enqueue1 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + auto Enqueue2 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + auto Enqueue3 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + 
REQUIRE_MESSAGE(Enqueue1, "Failed to enqueue action 1"); + REQUIRE_MESSAGE(Enqueue2, "Failed to enqueue action 2"); + REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3"); + + // Transition to Abandoned — should mark all pending actions as Abandoned + bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK(!Session.IsHealthy()); + + // Give the scheduler thread time to process the state changes + Sleep(2'000); + + // All three actions should now be in the results map as abandoned + for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn}) + { + CbPackage Result; + HttpResponseCode Code = Session.GetActionResult(Lsn, Result); + CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); + } + + // Queue should show 0 active, 3 abandoned + auto Status = Session.GetQueueStatus(QueueResult.QueueId); + CHECK_EQ(Status.ActiveCount, 0); + CHECK_EQ(Status.AbandonedCount, 3); + + // New actions should be rejected + auto Rejected = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected in Abandoned state"); + + // Abandoned → Sunset should be valid + CHECK(Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Sunset)); + + Session.Shutdown(); +} + +TEST_CASE("function.session.abandon_running") +{ + // Spawn a real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + 
MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a queue and submit a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Transition to Abandoned — should abandon the running action + bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + CHECK(!Session.IsHealthy()); + + // Poll for the action to complete (as abandoned) + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput())); + + // Verify the queue shows 
abandoned, not completed + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK_EQ(QueueStatus.ActiveCount, 0); + CHECK_EQ(QueueStatus.AbandonedCount, 1); + CHECK_EQ(QueueStatus.CompletedCount, 0); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.abandon_propagation") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Transition to Abandoned — should abandon the running action and propagate + bool Transitioned = 
Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + + // Poll for the action to complete + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput())); + + // Verify the local queue shows abandoned + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK_EQ(QueueStatus.ActiveCount, 0); + CHECK_EQ(QueueStatus.AbandonedCount, 1); + + // Session should not be healthy + CHECK(!Session.IsHealthy()); + + // The remote compute node should still be healthy (only the parent abandoned) + HttpClient RemoteClient(Instance.GetBaseUri() + "/compute"); + HttpClient::Response ReadyResp = RemoteClient.Get("/ready"sv); + CHECK_MESSAGE(ReadyResp.StatusCode == HttpResponseCode::OK, "Remote compute node should still be healthy"); + + Session.Shutdown(); +} + +TEST_SUITE_END(); + +} // namespace zen::tests::compute + +#endif diff --git a/src/zenserver-test/function-tests.cpp b/src/zenserver-test/function-tests.cpp deleted file mode 100644 index 82848c6ad..000000000 --- a/src/zenserver-test/function-tests.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include <zencore/zencore.h> - -#if ZEN_WITH_TESTS - -# include <zencore/compactbinary.h> -# include <zencore/compactbinarybuilder.h> -# include <zencore/string.h> -# include <zencore/testing.h> -# include <zenutil/zenserverprocess.h> - -# include "zenserver-test.h" - -namespace zen::tests { - -using namespace std::literals; - -TEST_SUITE_BEGIN("server.function"); - -TEST_CASE("function.run") -{ - std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); - - ZenServerInstance Instance(TestEnv); - Instance.SetDataDir(TestDir); - Instance.SpawnServer(13337); - - ZEN_INFO("Waiting..."); - - Instance.WaitUntilReady(); -} - -TEST_SUITE_END(); - -} // namespace zen::tests - -#endif diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp new file mode 100644 index 000000000..fe39e14c0 --- /dev/null +++ b/src/zenserver-test/logging-tests.cpp @@ -0,0 +1,257 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_WITH_TESTS + +# include "zenserver-test.h" + +# include <zencore/filesystem.h> +# include <zencore/logging.h> +# include <zencore/testing.h> +# include <zenutil/zenserverprocess.h> + +namespace zen::tests { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +static bool +LogContains(const std::string& Log, std::string_view Needle) +{ + return Log.find(Needle) != std::string::npos; +} + +static std::string +ReadFileToString(const std::filesystem::path& Path) +{ + FileContents Contents = ReadFile(Path); + if (Contents.ErrorCode) + { + return {}; + } + + IoBuffer Content = Contents.Flatten(); + if (!Content) + { + return {}; + } + + return std::string(static_cast<const char*>(Content.Data()), Content.Size()); +} + +////////////////////////////////////////////////////////////////////////// + +// Verify that a log file is created at the default location (DataDir/logs/zenserver.log) +// even without --abslog. 
The file must contain "server session id" (logged at INFO +// to all registered loggers during init) and "log starting at" (emitted once a file +// sink is first opened). +TEST_CASE("logging.file.default") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::filesystem::path DefaultLogFile = TestDir / "logs" / "zenserver.log"; + CHECK_MESSAGE(std::filesystem::exists(DefaultLogFile), "Default log file was not created"); + const std::string FileLog = ReadFileToString(DefaultLogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog); +} + +// --quiet sets the console sink level to WARN. The formatted "[info] ..." +// entry written by the default logger's console sink must therefore not appear +// in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_* +// macros — may still emit plain-text messages without a level marker, so we +// check for the absence of the full_formatter "[info]" prefix rather than the +// message text itself.) +TEST_CASE("logging.console.quiet") +{ + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--quiet"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::string Log = Instance.GetLogOutput(); + CHECK_MESSAGE(!LogContains(Log, "[info] server session id"), Log); +} + +// --noconsole removes the stdout sink entirely, so the captured console output +// must not contain any log entries from the logging system. 
+TEST_CASE("logging.console.disabled") +{ + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--noconsole"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::string Log = Instance.GetLogOutput(); + CHECK_MESSAGE(!LogContains(Log, "server session id"), Log); +} + +// --abslog <path> creates a rotating log file at the specified path. +// The file must contain "server session id" (logged at INFO to all loggers +// during init) and "log starting at" (emitted once a file sink is active). +TEST_CASE("logging.file.basic") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog); +} + +// --abslog with a .json extension selects the JSON formatter. +// Each log entry must be a JSON object containing at least the "message" +// and "source" fields. 
+TEST_CASE("logging.file.json") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.json"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "\"message\""), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "\"source\": \"zenserver\""), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); +} + +// --log-id <id> is automatically set to the server instance name in test mode. +// The JSON formatter emits this value as the "id" field, so every entry in a +// .json log file must carry a non-empty "id". +TEST_CASE("logging.log_id") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.json"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + // The JSON formatter writes the log-id as: "id": "<value>", + CHECK_MESSAGE(LogContains(FileLog, "\"id\": \""), FileLog); +} + +// --log-warn <logger> raises the level threshold above INFO so that INFO messages +// are filtered. 
"server session id" is broadcast at INFO to all loggers: it must +// appear in the main file sink (default logger unaffected) but must NOT appear in +// http.log where the http_requests logger now has a WARN threshold. +TEST_CASE("logging.level.warn_suppresses_info") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-warn http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); +} + +// --log-info <logger> sets an explicit INFO threshold. The INFO "server session id" +// broadcast must still land in http.log, confirming that INFO messages are not +// filtered when the logger level is exactly INFO. 
+TEST_CASE("logging.level.info_allows_info") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-info http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(LogContains(HttpLog, "server session id"), HttpLog); +} + +// --log-off <logger> silences a named logger entirely. +// "server session id" is broadcast at INFO to all registered loggers via +// spdlog::apply_all during init. When the "http_requests" logger is set to +// OFF its dedicated http.log file must not contain that message. +// The main file sink (via --abslog) must be unaffected. 
+TEST_CASE("logging.level.off_specific_logger") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-off http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + // Main log file must still have the startup message + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + + // http.log is created by the RotatingFileSink but the logger is OFF, so + // the broadcast "server session id" message must not have been written to it + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); +} + +} // namespace zen::tests + +#endif diff --git a/src/zenserver-test/nomad-tests.cpp b/src/zenserver-test/nomad-tests.cpp new file mode 100644 index 000000000..6eb99bc3a --- /dev/null +++ b/src/zenserver-test/nomad-tests.cpp @@ -0,0 +1,126 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#if ZEN_WITH_TESTS && ZEN_WITH_NOMAD +# include "zenserver-test.h" +# include <zencore/filesystem.h> +# include <zencore/logging.h> +# include <zencore/testing.h> +# include <zencore/timer.h> +# include <zenhttp/httpclient.h> +# include <zennomad/nomadclient.h> +# include <zennomad/nomadprocess.h> +# include <zenutil/zenserverprocess.h> + +# include <fmt/format.h> + +namespace zen::tests::nomad_tests { + +using namespace std::literals; + +TEST_CASE("nomad.client.lifecycle" * doctest::skip()) +{ + zen::nomad::NomadProcess NomadProc; + NomadProc.SpawnNomadAgent(); + + zen::nomad::NomadTestClient Client("http://localhost:4646/"); + + // Submit a simple batch job that sleeps briefly +# if ZEN_PLATFORM_WINDOWS + auto Job = Client.SubmitJob("zen-test-job", "cmd.exe", {"/C", "timeout /t 10 /nobreak"}); +# else + auto Job = Client.SubmitJob("zen-test-job", "/bin/sleep", {"10"}); +# endif + REQUIRE(!Job.Id.empty()); + CHECK_EQ(Job.Status, "pending"); + + // Poll until the job is running (or dead) + { + Stopwatch Timer; + bool FoundRunning = false; + while (Timer.GetElapsedTimeMs() < 15000) + { + auto Status = Client.GetJobStatus("zen-test-job"); + if (Status.Status == "running") + { + FoundRunning = true; + break; + } + if (Status.Status == "dead") + { + break; + } + Sleep(500); + } + CHECK(FoundRunning); + } + + // Verify allocations exist + auto Allocs = Client.GetAllocations("zen-test-job"); + CHECK(!Allocs.empty()); + + // Stop the job + Client.StopJob("zen-test-job"); + + // Verify it reaches dead state + { + Stopwatch Timer; + bool FoundDead = false; + while (Timer.GetElapsedTimeMs() < 10000) + { + auto Status = Client.GetJobStatus("zen-test-job"); + if (Status.Status == "dead") + { + FoundDead = true; + break; + } + Sleep(500); + } + CHECK(FoundDead); + } + + NomadProc.StopNomadAgent(); +} + +TEST_CASE("nomad.provisioner.integration" * doctest::skip()) +{ + zen::nomad::NomadProcess NomadProc; + NomadProc.SpawnNomadAgent(); + + // Spawn zenserver in compute mode 
with Nomad provisioning enabled + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + + std::filesystem::path ZenServerPath = TestEnv.ProgramBaseDir() / "zenserver" ZEN_EXE_SUFFIX_LITERAL; + + std::string NomadArgs = fmt::format( + "--nomad-enabled=true" + " --nomad-server=http://localhost:4646" + " --nomad-driver=raw_exec" + " --nomad-binary-path={}" + " --nomad-max-cores=32" + " --nomad-cores-per-job=32", + ZenServerPath.string()); + + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(NomadArgs); + REQUIRE(Port != 0); + + // Give the provisioner time to submit jobs. + // The management thread has a 5s wait between cycles, and the HTTP client has + // a 10s connect timeout, so we need to allow enough time for at least one full cycle. + Sleep(15000); + + // Verify jobs were submitted to Nomad + zen::nomad::NomadTestClient NomadClient("http://localhost:4646/"); + + auto Jobs = NomadClient.ListJobs("zenserver-worker"); + + ZEN_INFO("nomad.provisioner.integration: found {} jobs with prefix 'zenserver-worker'", Jobs.size()); + CHECK_MESSAGE(!Jobs.empty(), Instance.GetLogOutput()); + + Instance.Shutdown(); + NomadProc.StopNomadAgent(); +} + +} // namespace zen::tests::nomad_tests +#endif diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua index 2a269cea1..7b208bbc7 100644 --- a/src/zenserver-test/xmake.lua +++ b/src/zenserver-test/xmake.lua @@ -6,10 +6,15 @@ target("zenserver-test") add_headerfiles("**.h") add_files("*.cpp") add_files("zenserver-test.cpp", {unity_ignored = true }) - add_deps("zencore", "zenremotestore", "zenhttp") + add_deps("zencore", "zenremotestore", "zenhttp", "zencompute", "zenstore") add_deps("zenserver", {inherit=false}) + add_deps("zentest-appstub", {inherit=false}) add_packages("http_parser") + if has_config("zennomad") then + add_deps("zennomad") + end + if is_plat("macosx") then add_ldflags("-framework CoreFoundation") 
add_ldflags("-framework Security") diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 0f9ef0287..802d06caf 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "computeserver.h" -#include <zencompute/httpfunctionservice.h> -#include "computeservice.h" - +#include <zencompute/cloudmetadata.h> +#include <zencompute/httpcomputeservice.h> +#include <zencompute/httporchestrator.h> #if ZEN_WITH_COMPUTE_SERVICES # include <zencore/fmtutils.h> @@ -13,10 +13,20 @@ # include <zencore/scopeguard.h> # include <zencore/sentryintegration.h> # include <zencore/system.h> +# include <zencore/compactbinarybuilder.h> # include <zencore/windows.h> +# include <zenhttp/httpclient.h> # include <zenhttp/httpapiservice.h> # include <zenstore/cidstore.h> # include <zenutil/service.h> +# if ZEN_WITH_HORDE +# include <zenhorde/hordeconfig.h> +# include <zenhorde/hordeprovisioner.h> +# endif +# if ZEN_WITH_NOMAD +# include <zennomad/nomadconfig.h> +# include <zennomad/nomadprovisioner.h> +# endif ZEN_THIRD_PARTY_INCLUDES_START # include <cxxopts.hpp> @@ -29,6 +39,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) { Options.add_option("compute", "", + "max-actions", + "Maximum number of concurrent local actions (0 = auto)", + cxxopts::value<int32_t>(m_ServerOptions.MaxConcurrentActions)->default_value("0"), + ""); + + Options.add_option("compute", + "", "upstream-notification-endpoint", "Endpoint URL for upstream notifications", cxxopts::value<std::string>(m_ServerOptions.UpstreamNotificationEndpoint)->default_value(""), @@ -40,6 +57,236 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Instance ID for use in notifications", cxxopts::value<std::string>(m_ServerOptions.InstanceId)->default_value(""), ""); + + Options.add_option("compute", + "", + "coordinator-endpoint", + 
"Endpoint URL for coordinator service", + cxxopts::value<std::string>(m_ServerOptions.CoordinatorEndpoint)->default_value(""), + ""); + + Options.add_option("compute", + "", + "idms", + "Enable IDMS cloud detection; optionally specify a custom probe endpoint", + cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"), + ""); + + Options.add_option("compute", + "", + "worker-websocket", + "Use WebSocket for worker-orchestrator link (instant reachability detection)", + cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"), + ""); + +# if ZEN_WITH_HORDE + // Horde provisioning options + Options.add_option("horde", + "", + "horde-enabled", + "Enable Horde worker provisioning", + cxxopts::value<bool>(m_ServerOptions.HordeConfig.Enabled)->default_value("false"), + ""); + + Options.add_option("horde", + "", + "horde-server", + "Horde server URL", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.ServerUrl)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-token", + "Horde authentication token", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.AuthToken)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-pool", + "Horde pool name", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Pool)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-cluster", + "Horde cluster ID ('default' or '_auto' for auto-resolve)", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Cluster)->default_value("default"), + ""); + + Options.add_option("horde", + "", + "horde-mode", + "Horde connection mode (direct, tunnel, relay)", + cxxopts::value<std::string>(m_HordeModeStr)->default_value("direct"), + ""); + + Options.add_option("horde", + "", + "horde-encryption", + "Horde transport encryption (none, aes)", + cxxopts::value<std::string>(m_HordeEncryptionStr)->default_value("none"), + ""); + + Options.add_option("horde", + 
"", + "horde-max-cores", + "Maximum number of Horde cores to provision", + cxxopts::value<int>(m_ServerOptions.HordeConfig.MaxCores)->default_value("2048"), + ""); + + Options.add_option("horde", + "", + "horde-host", + "Host address for Horde agents to connect back to", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-condition", + "Additional Horde agent filter condition", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.Condition)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-binaries", + "Path to directory containing zenserver binary for remote upload", + cxxopts::value<std::string>(m_ServerOptions.HordeConfig.BinariesPath)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-zen-service-port", + "Port number for Zen service communication", + cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"), + ""); +# endif + +# if ZEN_WITH_NOMAD + // Nomad provisioning options + Options.add_option("nomad", + "", + "nomad-enabled", + "Enable Nomad worker provisioning", + cxxopts::value<bool>(m_ServerOptions.NomadConfig.Enabled)->default_value("false"), + ""); + + Options.add_option("nomad", + "", + "nomad-server", + "Nomad HTTP API URL", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.ServerUrl)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-token", + "Nomad ACL token", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.AclToken)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-datacenter", + "Nomad target datacenter", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Datacenter)->default_value("dc1"), + ""); + + Options.add_option("nomad", + "", + "nomad-namespace", + "Nomad namespace", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Namespace)->default_value("default"), + ""); + + 
Options.add_option("nomad", + "", + "nomad-region", + "Nomad region (empty for server default)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.Region)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-driver", + "Nomad task driver (raw_exec, docker)", + cxxopts::value<std::string>(m_NomadDriverStr)->default_value("raw_exec"), + ""); + + Options.add_option("nomad", + "", + "nomad-distribution", + "Binary distribution mode (predeployed, artifact)", + cxxopts::value<std::string>(m_NomadDistributionStr)->default_value("predeployed"), + ""); + + Options.add_option("nomad", + "", + "nomad-binary-path", + "Path to zenserver on Nomad clients (predeployed mode)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.BinaryPath)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-artifact-source", + "URL to download zenserver binary (artifact mode)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.ArtifactSource)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-docker-image", + "Docker image for zenserver (docker driver)", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.DockerImage)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-max-jobs", + "Maximum concurrent Nomad jobs", + cxxopts::value<int>(m_ServerOptions.NomadConfig.MaxJobs)->default_value("64"), + ""); + + Options.add_option("nomad", + "", + "nomad-cpu-mhz", + "CPU MHz allocated per Nomad task", + cxxopts::value<int>(m_ServerOptions.NomadConfig.CpuMhz)->default_value("1000"), + ""); + + Options.add_option("nomad", + "", + "nomad-memory-mb", + "Memory MB allocated per Nomad task", + cxxopts::value<int>(m_ServerOptions.NomadConfig.MemoryMb)->default_value("2048"), + ""); + + Options.add_option("nomad", + "", + "nomad-cores-per-job", + "Estimated cores per Nomad job (for scaling)", + cxxopts::value<int>(m_ServerOptions.NomadConfig.CoresPerJob)->default_value("32"), + ""); + + 
Options.add_option("nomad", + "", + "nomad-max-cores", + "Maximum total cores to provision via Nomad", + cxxopts::value<int>(m_ServerOptions.NomadConfig.MaxCores)->default_value("2048"), + ""); + + Options.add_option("nomad", + "", + "nomad-job-prefix", + "Prefix for generated Nomad job IDs", + cxxopts::value<std::string>(m_ServerOptions.NomadConfig.JobPrefix)->default_value("zenserver-worker"), + ""); +# endif } void @@ -63,6 +310,15 @@ ZenComputeServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) void ZenComputeServerConfigurator::ValidateOptions() { +# if ZEN_WITH_HORDE + horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr); + horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr); +# endif + +# if ZEN_WITH_NOMAD + nomad::FromString(m_ServerOptions.NomadConfig.TaskDriver, m_NomadDriverStr); + nomad::FromString(m_ServerOptions.NomadConfig.BinDistribution, m_NomadDistributionStr); +# endif } /////////////////////////////////////////////////////////////////////////// @@ -90,10 +346,14 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ return EffectiveBasePort; } + m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_InstanceId = ServerConfig.InstanceId; + m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; + // This is a workaround to make sure we can have automated tests. Without // this the ranges for different child zen compute processes could overlap with // the main test range. - ZenServerEnvironment::SetBaseChildId(1000); + ZenServerEnvironment::SetBaseChildId(2000); m_DebugOptionForcedCrash = ServerConfig.ShouldCrash; @@ -113,6 +373,46 @@ ZenComputeServer::Cleanup() ZEN_INFO(ZEN_APP_NAME " cleaning up"); try { + // Cancel the maintenance timer so it stops re-enqueuing before we + // tear down the provisioners it references. 
+ m_ProvisionerMaintenanceTimer.cancel(); + m_AnnounceTimer.cancel(); + +# if ZEN_WITH_HORDE + // Shut down Horde provisioner first — this signals all agent threads + // to exit and joins them before we tear down HTTP services. + m_HordeProvisioner.reset(); +# endif + +# if ZEN_WITH_NOMAD + // Shut down Nomad provisioner — stops the management thread and + // sends stop requests for all tracked jobs. + m_NomadProvisioner.reset(); +# endif + + // Close the orchestrator WebSocket client before stopping the io_context + m_WsReconnectTimer.cancel(); + if (m_OrchestratorWsClient) + { + m_OrchestratorWsClient->Close(); + m_OrchestratorWsClient.reset(); + } + m_OrchestratorWsHandler.reset(); + + ResolveCloudMetadata(); + m_CloudMetadata.reset(); + + // Shut down services that own threads or use the io_context before we + // stop the io_context and close the HTTP server. + if (m_OrchestratorService) + { + m_OrchestratorService->Shutdown(); + } + if (m_ComputeService) + { + m_ComputeService->Shutdown(); + } + m_IoContext.stop(); if (m_IoRunner.joinable()) { @@ -139,7 +439,8 @@ ZenComputeServer::InitializeState(const ZenComputeServerConfig& ServerConfig) void ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) { - ZEN_INFO("initializing storage"); + ZEN_TRACE_CPU("ZenComputeServer::InitializeServices"); + ZEN_INFO("initializing compute services"); CidStoreConfiguration Config; Config.RootDirectory = m_DataRoot / "cas"; @@ -147,46 +448,405 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) m_CidStore = std::make_unique<CidStore>(m_GcManager); m_CidStore->Initialize(Config); + if (!ServerConfig.IdmsEndpoint.empty()) + { + ZEN_INFO("detecting cloud environment (async)"); + if (ServerConfig.IdmsEndpoint == "auto") + { + m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir] { + return std::make_unique<zen::compute::CloudMetadata>(DataDir / "cloud"); + }); + } + else + { + 
ZEN_INFO("using custom IDMS endpoint: {}", ServerConfig.IdmsEndpoint); + m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir, Endpoint = ServerConfig.IdmsEndpoint] { + return std::make_unique<zen::compute::CloudMetadata>(DataDir / "cloud", Endpoint); + }); + } + } + ZEN_INFO("instantiating API service"); m_ApiService = std::make_unique<zen::HttpApiService>(*m_Http); - ZEN_INFO("instantiating compute service"); - m_ComputeService = std::make_unique<HttpComputeService>(ServerConfig.DataDir / "compute"); + ZEN_INFO("instantiating orchestrator service"); + m_OrchestratorService = + std::make_unique<zen::compute::HttpOrchestratorService>(ServerConfig.DataDir / "orch", ServerConfig.EnableWorkerWebSocket); + + ZEN_INFO("instantiating function service"); + m_ComputeService = std::make_unique<zen::compute::HttpComputeService>(*m_CidStore, + m_StatsService, + ServerConfig.DataDir / "functions", + ServerConfig.MaxConcurrentActions); - // Ref<zen::compute::FunctionRunner> Runner; - // Runner = zen::compute::CreateLocalRunner(*m_CidStore, ServerConfig.DataDir / "runner"); + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService); - // TODO: (re)implement default configuration here +# if ZEN_WITH_NOMAD + // Nomad provisioner + if (ServerConfig.NomadConfig.Enabled && !ServerConfig.NomadConfig.ServerUrl.empty()) + { + ZEN_INFO("instantiating Nomad provisioner (server: {})", ServerConfig.NomadConfig.ServerUrl); - ZEN_INFO("instantiating function service"); - m_FunctionService = - std::make_unique<zen::compute::HttpFunctionService>(*m_CidStore, m_StatsService, ServerConfig.DataDir / "functions"); + const auto& NomadCfg = ServerConfig.NomadConfig; + + if (!NomadCfg.Validate()) + { + ZEN_ERROR("invalid Nomad configuration"); + } + else + { + ExtendableStringBuilder<256> OrchestratorEndpoint; + OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get()); + if (auto View = OrchestratorEndpoint.ToView(); 
!View.empty() && View.back() != '/') + { + OrchestratorEndpoint << '/'; + } + + m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint); + } + } +# endif + +# if ZEN_WITH_HORDE + // Horde provisioner + if (ServerConfig.HordeConfig.Enabled && !ServerConfig.HordeConfig.ServerUrl.empty()) + { + ZEN_INFO("instantiating Horde provisioner (server: {})", ServerConfig.HordeConfig.ServerUrl); + + const auto& HordeConfig = ServerConfig.HordeConfig; + + if (!HordeConfig.Validate()) + { + ZEN_ERROR("invalid Horde configuration"); + } + else + { + ExtendableStringBuilder<256> OrchestratorEndpoint; + OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get()); + if (auto View = OrchestratorEndpoint.ToView(); !View.empty() && View.back() != '/') + { + OrchestratorEndpoint << '/'; + } + + // If no binaries path is specified, just use the running executable's directory + std::filesystem::path BinariesPath = HordeConfig.BinariesPath.empty() ? GetRunningExecutablePath().parent_path() + : std::filesystem::path(HordeConfig.BinariesPath); + std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde"; + + m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint); + } + } +# endif +} + +void +ZenComputeServer::ResolveCloudMetadata() +{ + if (m_CloudMetadataFuture.valid()) + { + m_CloudMetadata = m_CloudMetadataFuture.get(); + } +} + +std::string +ZenComputeServer::GetInstanceId() const +{ + if (!m_InstanceId.empty()) + { + return m_InstanceId; + } + return fmt::format("{}-{}", GetMachineName(), GetCurrentProcessId()); +} + +std::string +ZenComputeServer::GetAnnounceUrl() const +{ + return m_Http->GetServiceUri(nullptr); } void ZenComputeServer::RegisterServices(const ZenComputeServerConfig& ServerConfig) { + ZEN_TRACE_CPU("ZenComputeServer::RegisterServices"); ZEN_UNUSED(ServerConfig); + m_Http->RegisterService(m_StatsService); + + if (m_ApiService) + { + 
m_Http->RegisterService(*m_ApiService); + } + + if (m_OrchestratorService) + { + m_Http->RegisterService(*m_OrchestratorService); + } + if (m_ComputeService) { m_Http->RegisterService(*m_ComputeService); } - if (m_ApiService) + if (m_FrontendService) { - m_Http->RegisterService(*m_ApiService); + m_Http->RegisterService(*m_FrontendService); + } +} + +CbObject +ZenComputeServer::BuildAnnounceBody() +{ + CbObjectWriter AnnounceBody; + AnnounceBody << "id" << GetInstanceId(); + AnnounceBody << "uri" << GetAnnounceUrl(); + AnnounceBody << "hostname" << GetMachineName(); + AnnounceBody << "platform" << GetRuntimePlatformName(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + AnnounceBody.BeginObject("metrics"); + Describe(Sm, AnnounceBody); + AnnounceBody.EndObject(); + + AnnounceBody << "cpu_usage" << Sm.CpuUsagePercent; + AnnounceBody << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + AnnounceBody << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + + AnnounceBody << "bytes_received" << m_Http->GetTotalBytesReceived(); + AnnounceBody << "bytes_sent" << m_Http->GetTotalBytesSent(); + + auto Actions = m_ComputeService->GetActionCounts(); + AnnounceBody << "actions_pending" << Actions.Pending; + AnnounceBody << "actions_running" << Actions.Running; + AnnounceBody << "actions_completed" << Actions.Completed; + AnnounceBody << "active_queues" << Actions.ActiveQueues; + + // Derive provisioner from instance ID prefix (e.g. 
"horde-xxx" or "nomad-xxx") + if (m_InstanceId.starts_with("horde-")) + { + AnnounceBody << "provisioner" + << "horde"; + } + else if (m_InstanceId.starts_with("nomad-")) + { + AnnounceBody << "provisioner" + << "nomad"; + } + + ResolveCloudMetadata(); + if (m_CloudMetadata) + { + m_CloudMetadata->Describe(AnnounceBody); + } + + return AnnounceBody.Save(); +} + +void +ZenComputeServer::PostAnnounce() +{ + ZEN_TRACE_CPU("ZenComputeServer::PostAnnounce"); + + if (!m_ComputeService || m_CoordinatorEndpoint.empty()) + { + return; + } + + ZEN_INFO("notifying coordinator at '{}' of our availability at '{}'", m_CoordinatorEndpoint, GetAnnounceUrl()); + + try + { + CbObject Body = BuildAnnounceBody(); + + // If we have an active WebSocket connection, send via that instead of HTTP POST + if (m_OrchestratorWsClient && m_OrchestratorWsClient->IsOpen()) + { + MemoryView View = Body.GetView(); + m_OrchestratorWsClient->SendBinary(std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(View.GetData()), View.GetSize())); + ZEN_INFO("announced to coordinator via WebSocket"); + return; + } + + HttpClient CoordinatorHttp(m_CoordinatorEndpoint); + HttpClient::Response Result = CoordinatorHttp.Post("announce", std::move(Body)); + + if (Result.Error) + { + ZEN_ERROR("failed to notify coordinator at '{}': HTTP error {} - {}", + m_CoordinatorEndpoint, + Result.Error->ErrorCode, + Result.Error->ErrorMessage); + } + else if (!IsHttpOk(Result.StatusCode)) + { + ZEN_ERROR("failed to notify coordinator at '{}': unexpected HTTP status code {}", + m_CoordinatorEndpoint, + static_cast<int>(Result.StatusCode)); + } + else + { + ZEN_INFO("successfully notified coordinator at '{}'", m_CoordinatorEndpoint); + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("failed to notify coordinator at '{}': {}", m_CoordinatorEndpoint, Ex.what()); + } +} + +void +ZenComputeServer::EnqueueAnnounceTimer() +{ + if (!m_ComputeService || m_CoordinatorEndpoint.empty()) + { + return; + } + + 
m_AnnounceTimer.expires_after(std::chrono::seconds(15)); + m_AnnounceTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec) + { + PostAnnounce(); + EnqueueAnnounceTimer(); + } + }); + EnsureIoRunner(); +} + +void +ZenComputeServer::InitializeOrchestratorWebSocket() +{ + if (!m_EnableWorkerWebSocket || m_CoordinatorEndpoint.empty()) + { + return; + } + + // Convert http://host:port → ws://host:port/orch/ws + std::string WsUrl = m_CoordinatorEndpoint; + if (WsUrl.starts_with("http://")) + { + WsUrl = "ws://" + WsUrl.substr(7); + } + else if (WsUrl.starts_with("https://")) + { + WsUrl = "wss://" + WsUrl.substr(8); + } + if (!WsUrl.empty() && WsUrl.back() != '/') + { + WsUrl += '/'; + } + WsUrl += "orch/ws"; + + ZEN_INFO("establishing WebSocket link to orchestrator at {}", WsUrl); + + m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this); + m_OrchestratorWsClient = + std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, m_IoContext, HttpWsClientSettings{.LogCategory = "orch_ws"}); + + m_OrchestratorWsClient->Connect(); + EnsureIoRunner(); +} + +void +ZenComputeServer::EnqueueWsReconnect() +{ + m_WsReconnectTimer.expires_after(std::chrono::seconds(5)); + m_WsReconnectTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec && m_OrchestratorWsClient) + { + ZEN_INFO("attempting WebSocket reconnect to orchestrator"); + m_OrchestratorWsClient->Connect(); + } + }); + EnsureIoRunner(); +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsOpen() +{ + ZEN_INFO("WebSocket link to orchestrator established"); + + // Send initial announce immediately over the WebSocket + Server.PostAnnounce(); +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg) +{ + // Orchestrator does not push messages to workers; ignore +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_WARN("WebSocket link to 
orchestrator closed (code {}), falling back to HTTP announce", Code); + + // Trigger an immediate HTTP announce so the orchestrator has fresh state, + // then schedule a reconnect attempt. + Server.PostAnnounce(); + Server.EnqueueWsReconnect(); +} + +void +ZenComputeServer::ProvisionerMaintenanceTick() +{ +# if ZEN_WITH_HORDE + if (m_HordeProvisioner) + { + m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_HordeProvisioner->GetStats(); + ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}", + Stats.TargetCoreCount, + Stats.EstimatedCoreCount, + Stats.ActiveCoreCount); + } +# endif + +# if ZEN_WITH_NOMAD + if (m_NomadProvisioner) + { + m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_NomadProvisioner->GetStats(); + ZEN_DEBUG("Nomad maintenance: target={}, estimated={}, running jobs={}", + Stats.TargetCoreCount, + Stats.EstimatedCoreCount, + Stats.RunningJobCount); } +# endif +} + +void +ZenComputeServer::EnqueueProvisionerMaintenanceTimer() +{ + bool HasProvisioner = false; +# if ZEN_WITH_HORDE + HasProvisioner = HasProvisioner || (m_HordeProvisioner != nullptr); +# endif +# if ZEN_WITH_NOMAD + HasProvisioner = HasProvisioner || (m_NomadProvisioner != nullptr); +# endif - if (m_FunctionService) + if (!HasProvisioner) { - m_Http->RegisterService(*m_FunctionService); + return; } + + m_ProvisionerMaintenanceTimer.expires_after(std::chrono::seconds(15)); + m_ProvisionerMaintenanceTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec) + { + ProvisionerMaintenanceTick(); + EnqueueProvisionerMaintenanceTimer(); + } + }); + EnsureIoRunner(); } void ZenComputeServer::Run() { + ZEN_TRACE_CPU("ZenComputeServer::Run"); + if (m_ProcessMonitor.IsActive()) { CheckOwnerPid(); @@ -236,6 +896,35 @@ ZenComputeServer::Run() OnReady(); + PostAnnounce(); + EnqueueAnnounceTimer(); + InitializeOrchestratorWebSocket(); + +# if ZEN_WITH_HORDE + // Start Horde provisioning if configured — request maximum allowed cores. 
+ // SetTargetCoreCount clamps to HordeConfig::MaxCores internally. + if (m_HordeProvisioner) + { + ZEN_INFO("Horde provisioning starting"); + m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_HordeProvisioner->GetStats(); + ZEN_INFO("Horde provisioning started (target cores: {})", Stats.TargetCoreCount); + } +# endif + +# if ZEN_WITH_NOMAD + // Start Nomad provisioning if configured — request maximum allowed cores. + // SetTargetCoreCount clamps to NomadConfig::MaxCores internally. + if (m_NomadProvisioner) + { + m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_NomadProvisioner->GetStats(); + ZEN_INFO("Nomad provisioning started (target cores: {})", Stats.TargetCoreCount); + } +# endif + + EnqueueProvisionerMaintenanceTimer(); + m_Http->Run(IsInteractiveMode); SetNewState(kShuttingDown); @@ -254,6 +943,8 @@ ZenComputeServerMain::ZenComputeServerMain(ZenComputeServerConfig& ServerOptions void ZenComputeServerMain::DoRun(ZenServerState::ZenServerEntry* Entry) { + ZEN_TRACE_CPU("ZenComputeServerMain::DoRun"); + ZenComputeServer Server; Server.SetDataRoot(m_ServerOptions.DataDir); Server.SetContentRoot(m_ServerOptions.ContentDir); diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 625140b23..e4a6b01d5 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -6,7 +6,11 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include <future> +# include <zencore/system.h> +# include <zenhttp/httpwsclient.h> # include <zenstore/gc.h> +# include "frontend/frontend.h" namespace cxxopts { class Options; @@ -16,19 +20,46 @@ struct Options; } namespace zen::compute { -class HttpFunctionService; -} +class CloudMetadata; +class HttpComputeService; +class HttpOrchestratorService; +} // namespace zen::compute + +# if ZEN_WITH_HORDE +# include <zenhorde/hordeconfig.h> +namespace zen::horde { +class HordeProvisioner; +} // namespace zen::horde +# endif + +# if ZEN_WITH_NOMAD +# 
include <zennomad/nomadconfig.h> +namespace zen::nomad { +class NomadProvisioner; +} // namespace zen::nomad +# endif namespace zen { class CidStore; class HttpApiService; -class HttpComputeService; struct ZenComputeServerConfig : public ZenServerConfig { std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications + std::string CoordinatorEndpoint; + std::string IdmsEndpoint; + int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) + bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link + +# if ZEN_WITH_HORDE + horde::HordeConfig HordeConfig; +# endif + +# if ZEN_WITH_NOMAD + nomad::NomadConfig NomadConfig; +# endif }; struct ZenComputeServerConfigurator : public ZenServerConfiguratorBase @@ -49,6 +80,16 @@ private: virtual void ValidateOptions() override; ZenComputeServerConfig& m_ServerOptions; + +# if ZEN_WITH_HORDE + std::string m_HordeModeStr = "direct"; + std::string m_HordeEncryptionStr = "none"; +# endif + +# if ZEN_WITH_NOMAD + std::string m_NomadDriverStr = "raw_exec"; + std::string m_NomadDistributionStr = "predeployed"; +# endif }; class ZenComputeServerMain : public ZenServerMain @@ -88,17 +129,59 @@ public: void Cleanup(); private: - HttpStatsService m_StatsService; - GcManager m_GcManager; - GcScheduler m_GcScheduler{m_GcManager}; - std::unique_ptr<CidStore> m_CidStore; - std::unique_ptr<HttpComputeService> m_ComputeService; - std::unique_ptr<HttpApiService> m_ApiService; - std::unique_ptr<zen::compute::HttpFunctionService> m_FunctionService; - - void InitializeState(const ZenComputeServerConfig& ServerConfig); - void InitializeServices(const ZenComputeServerConfig& ServerConfig); - void RegisterServices(const ZenComputeServerConfig& ServerConfig); + HttpStatsService m_StatsService; + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<CidStore> m_CidStore; + std::unique_ptr<HttpApiService> m_ApiService; + 
std::unique_ptr<zen::compute::HttpComputeService> m_ComputeService; + std::unique_ptr<zen::compute::HttpOrchestratorService> m_OrchestratorService; + std::unique_ptr<zen::compute::CloudMetadata> m_CloudMetadata; + std::future<std::unique_ptr<zen::compute::CloudMetadata>> m_CloudMetadataFuture; + std::unique_ptr<HttpFrontendService> m_FrontendService; +# if ZEN_WITH_HORDE + std::unique_ptr<zen::horde::HordeProvisioner> m_HordeProvisioner; +# endif +# if ZEN_WITH_NOMAD + std::unique_ptr<zen::nomad::NomadProvisioner> m_NomadProvisioner; +# endif + SystemMetricsTracker m_MetricsTracker; + std::string m_CoordinatorEndpoint; + std::string m_InstanceId; + + asio::steady_timer m_AnnounceTimer{m_IoContext}; + asio::steady_timer m_ProvisionerMaintenanceTimer{m_IoContext}; + + void InitializeState(const ZenComputeServerConfig& ServerConfig); + void InitializeServices(const ZenComputeServerConfig& ServerConfig); + void RegisterServices(const ZenComputeServerConfig& ServerConfig); + void ResolveCloudMetadata(); + void PostAnnounce(); + void EnqueueAnnounceTimer(); + void EnqueueProvisionerMaintenanceTimer(); + void ProvisionerMaintenanceTick(); + std::string GetAnnounceUrl() const; + std::string GetInstanceId() const; + CbObject BuildAnnounceBody(); + + // Worker→orchestrator WebSocket client + struct OrchestratorWsHandler : public IWsClientHandler + { + ZenComputeServer& Server; + explicit OrchestratorWsHandler(ZenComputeServer& S) : Server(S) {} + + void OnWsOpen() override; + void OnWsMessage(const WebSocketMessage& Msg) override; + void OnWsClose(uint16_t Code, std::string_view Reason) override; + }; + + std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler; + std::unique_ptr<HttpWsClient> m_OrchestratorWsClient; + asio::steady_timer m_WsReconnectTimer{m_IoContext}; + bool m_EnableWorkerWebSocket = false; + + void InitializeOrchestratorWebSocket(); + void EnqueueWsReconnect(); }; } // namespace zen diff --git a/src/zenserver/compute/computeservice.cpp 
b/src/zenserver/compute/computeservice.cpp deleted file mode 100644 index 2c0bc0ae9..000000000 --- a/src/zenserver/compute/computeservice.cpp +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "computeservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include <zencore/compactbinarybuilder.h> -# include <zencore/filesystem.h> -# include <zencore/fmtutils.h> -# include <zencore/logging.h> -# include <zencore/system.h> -# include <zenutil/zenserverprocess.h> - -ZEN_THIRD_PARTY_INCLUDES_START -# include <EASTL/fixed_vector.h> -# include <asio.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -# include <unordered_map> - -namespace zen { - -////////////////////////////////////////////////////////////////////////// - -struct ResourceMetrics -{ - uint64_t DiskUsageBytes = 0; - uint64_t MemoryUsageBytes = 0; -}; - -////////////////////////////////////////////////////////////////////////// - -struct HttpComputeService::Impl -{ - Impl(const Impl&) = delete; - Impl& operator=(const Impl&) = delete; - - Impl(); - ~Impl(); - - void Initialize(std::filesystem::path BaseDir) { ZEN_UNUSED(BaseDir); } - - void Cleanup() {} - -private: -}; - -HttpComputeService::Impl::Impl() -{ -} - -HttpComputeService::Impl::~Impl() -{ -} - -/////////////////////////////////////////////////////////////////////////// - -HttpComputeService::HttpComputeService(std::filesystem::path BaseDir) : m_Impl(std::make_unique<Impl>()) -{ - using namespace std::literals; - - m_Impl->Initialize(BaseDir); - - m_Router.RegisterRoute( - "status", - [this](HttpRouterRequest& Req) { - CbObjectWriter Obj; - Obj.BeginArray("modules"); - Obj.EndArray(); - Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "stats", - [this](HttpRouterRequest& Req) { - CbObjectWriter Obj; - Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); - }, - HttpVerb::kGet); -} - -HttpComputeService::~HttpComputeService() -{ -} - 
-const char* -HttpComputeService::BaseUri() const -{ - return "/compute/"; -} - -void -HttpComputeService::HandleRequest(zen::HttpServerRequest& Request) -{ - m_Router.HandleRequest(Request); -} - -} // namespace zen -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/computeservice.h b/src/zenserver/compute/computeservice.h deleted file mode 100644 index 339200dd8..000000000 --- a/src/zenserver/compute/computeservice.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zenhttp/httpserver.h> - -#if ZEN_WITH_COMPUTE_SERVICES -namespace zen { - -/** ZenServer Compute Service - * - * Manages a set of compute workers for use in UEFN content worker - * - */ -class HttpComputeService : public zen::HttpService -{ -public: - HttpComputeService(std::filesystem::path BaseDir); - ~HttpComputeService(); - - HttpComputeService(const HttpComputeService&) = delete; - HttpComputeService& operator=(const HttpComputeService&) = delete; - - virtual const char* BaseUri() const override; - virtual void HandleRequest(zen::HttpServerRequest& Request) override; - -private: - HttpRequestRouter m_Router; - - struct Impl; - - std::unique_ptr<Impl> m_Impl; -}; - -} // namespace zen -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip Binary files differindex 4767029c0..c167cc70e 100644 --- a/src/zenserver/frontend/html.zip +++ b/src/zenserver/frontend/html.zip diff --git a/src/zenserver/frontend/html/404.html b/src/zenserver/frontend/html/404.html new file mode 100644 index 000000000..829ef2097 --- /dev/null +++ b/src/zenserver/frontend/html/404.html @@ -0,0 +1,486 @@ +<!DOCTYPE html> +<html lang="en"> +<head> +<meta charset="UTF-8"> +<meta name="viewport" content="width=device-width, initial-scale=1.0"> +<title>Ooops</title> +<style> + * { margin: 0; padding: 0; box-sizing: border-box; } + + :root { + --deep-space: #00000f; + --nebula-blue: #0a0a2e; 
+ --star-white: #ffffff; + --star-blue: #c8d8ff; + --star-yellow: #fff3c0; + --star-red: #ffd0c0; + --nebula-glow: rgba(60, 80, 180, 0.12); + } + + body { + background: var(--deep-space); + min-height: 100vh; + display: flex; + align-items: center; + justify-content: center; + font-family: 'Courier New', monospace; + overflow: hidden; + } + + starfield-bg { + display: block; + position: fixed; + inset: 0; + z-index: 0; + } + + canvas { + display: block; + width: 100%; + height: 100%; + } + + .page-content { + position: relative; + z-index: 1; + text-align: center; + color: rgba(200, 216, 255, 0.85); + letter-spacing: 0.25em; + text-transform: uppercase; + pointer-events: none; + user-select: none; + } + + .page-content h1 { + font-size: clamp(1.2rem, 4vw, 2.4rem); + font-weight: 300; + letter-spacing: 0.6em; + text-shadow: 0 0 40px rgba(120, 160, 255, 0.6), 0 0 80px rgba(80, 120, 255, 0.3); + animation: pulse 6s ease-in-out infinite; + } + + .page-content p { + margin-top: 1.2rem; + font-size: clamp(0.55rem, 1.5vw, 0.75rem); + letter-spacing: 0.4em; + opacity: 0.45; + } + + @keyframes pulse { + 0%, 100% { opacity: 0.7; } + 50% { opacity: 1; } + } + + .globe-link { + display: block; + margin: 0 auto 2rem; + width: 160px; + height: 160px; + pointer-events: auto; + cursor: pointer; + border-radius: 50%; + position: relative; + } + + .globe-link:hover .globe-glow { + opacity: 0.6; + } + + .globe-glow { + position: absolute; + inset: -18px; + border-radius: 50%; + background: radial-gradient(circle, rgba(80, 140, 255, 0.35) 0%, transparent 70%); + opacity: 0.35; + transition: opacity 0.4s; + pointer-events: none; + } + + .globe-link canvas { + display: block; + width: 160px; + height: 160px; + border-radius: 50%; + } +</style> +</head> +<body> + +<starfield-bg + star-count="380" + speed="0.6" + depth="true" + nebula="true" + shooting-stars="true" +></starfield-bg> + +<div class="page-content"> + <a class="globe-link" href="/dashboard/" title="Back to Dashboard"> + <div 
class="globe-glow"></div> + <canvas id="globe" width="320" height="320"></canvas> + </a> + <h1>404 NOT FOUND</h1> +</div> + +<script> +class StarfieldBg extends HTMLElement { + constructor() { + super(); + this.attachShadow({ mode: 'open' }); + } + + connectedCallback() { + this.shadowRoot.innerHTML = ` + <style> + :host { display: block; position: absolute; inset: 0; overflow: hidden; } + canvas { width: 100%; height: 100%; display: block; } + </style> + <canvas></canvas> + `; + + this.canvas = this.shadowRoot.querySelector('canvas'); + this.ctx = this.canvas.getContext('2d'); + + this.starCount = parseInt(this.getAttribute('star-count') || '350'); + this.speed = parseFloat(this.getAttribute('speed') || '0.6'); + this.useDepth = this.getAttribute('depth') !== 'false'; + this.useNebula = this.getAttribute('nebula') !== 'false'; + this.useShooting = this.getAttribute('shooting-stars') !== 'false'; + + this.stars = []; + this.shooters = []; + this.nebulaTime = 0; + this.frame = 0; + + this.resize(); + this.init(); + + this._ro = new ResizeObserver(() => { this.resize(); this.init(); }); + this._ro.observe(this); + + this.raf = requestAnimationFrame(this.tick.bind(this)); + } + + disconnectedCallback() { + cancelAnimationFrame(this.raf); + this._ro.disconnect(); + } + + resize() { + const dpr = window.devicePixelRatio || 1; + const rect = this.getBoundingClientRect(); + this.W = rect.width || window.innerWidth; + this.H = rect.height || window.innerHeight; + this.canvas.width = this.W * dpr; + this.canvas.height = this.H * dpr; + this.ctx.setTransform(dpr, 0, 0, dpr, 0, 0); + } + + init() { + const COLORS = ['#ffffff', '#c8d8ff', '#d0e8ff', '#fff3c0', '#ffd0c0', '#e0f0ff']; + this.stars = Array.from({ length: this.starCount }, () => ({ + x: Math.random() * this.W, + y: Math.random() * this.H, + z: this.useDepth ? 
Math.random() : 1, // depth: 0=far, 1=near + r: Math.random() * 1.4 + 0.2, + color: COLORS[Math.floor(Math.random() * COLORS.length)], + twinkleOffset: Math.random() * Math.PI * 2, + twinkleSpeed: 0.008 + Math.random() * 0.012, + })); + } + + spawnShooter() { + const edge = Math.random() < 0.7 ? 'top' : 'left'; + const angle = (Math.random() * 30 + 15) * (Math.PI / 180); + this.shooters.push({ + x: edge === 'top' ? Math.random() * this.W : -10, + y: edge === 'top' ? -10 : Math.random() * this.H * 0.5, + vx: Math.cos(angle) * (6 + Math.random() * 6), + vy: Math.sin(angle) * (6 + Math.random() * 6), + len: 80 + Math.random() * 120, + life: 1, + decay: 0.012 + Math.random() * 0.018, + }); + } + + tick() { + this.raf = requestAnimationFrame(this.tick.bind(this)); + this.frame++; + const ctx = this.ctx; + const W = this.W, H = this.H; + + // Background + ctx.fillStyle = '#00000f'; + ctx.fillRect(0, 0, W, H); + + // Nebula clouds (subtle) + if (this.useNebula) { + this.nebulaTime += 0.003; + this.drawNebula(ctx, W, H); + } + + // Stars + for (const s of this.stars) { + const twinkle = 0.55 + 0.45 * Math.sin(this.frame * s.twinkleSpeed + s.twinkleOffset); + const radius = s.r * (this.useDepth ? (0.3 + s.z * 0.7) : 1); + const alpha = (this.useDepth ? (0.25 + s.z * 0.75) : 1) * twinkle; + + // Tiny drift + s.x += (s.z * this.speed * 0.08) * (this.useDepth ? 1 : 0); + s.y += (s.z * this.speed * 0.04) * (this.useDepth ? 
1 : 0); + if (s.x > W + 2) s.x = -2; + if (s.y > H + 2) s.y = -2; + + // Glow for bright stars + if (radius > 1.1 && alpha > 0.6) { + const grd = ctx.createRadialGradient(s.x, s.y, 0, s.x, s.y, radius * 3.5); + grd.addColorStop(0, s.color.replace(')', `, ${alpha * 0.5})`).replace('rgb', 'rgba')); + grd.addColorStop(1, 'transparent'); + ctx.beginPath(); + ctx.arc(s.x, s.y, radius * 3.5, 0, Math.PI * 2); + ctx.fillStyle = grd; + ctx.fill(); + } + + ctx.beginPath(); + ctx.arc(s.x, s.y, radius, 0, Math.PI * 2); + ctx.fillStyle = hexToRgba(s.color, alpha); + ctx.fill(); + } + + // Shooting stars + if (this.useShooting) { + if (this.frame % 140 === 0 && Math.random() < 0.65) this.spawnShooter(); + for (let i = this.shooters.length - 1; i >= 0; i--) { + const s = this.shooters[i]; + const tailX = s.x - s.vx * (s.len / Math.hypot(s.vx, s.vy)); + const tailY = s.y - s.vy * (s.len / Math.hypot(s.vx, s.vy)); + + const grd = ctx.createLinearGradient(tailX, tailY, s.x, s.y); + grd.addColorStop(0, `rgba(255,255,255,0)`); + grd.addColorStop(0.7, `rgba(200,220,255,${s.life * 0.5})`); + grd.addColorStop(1, `rgba(255,255,255,${s.life})`); + + ctx.beginPath(); + ctx.moveTo(tailX, tailY); + ctx.lineTo(s.x, s.y); + ctx.strokeStyle = grd; + ctx.lineWidth = 1.5 * s.life; + ctx.lineCap = 'round'; + ctx.stroke(); + + // Head dot + ctx.beginPath(); + ctx.arc(s.x, s.y, 1.5 * s.life, 0, Math.PI * 2); + ctx.fillStyle = `rgba(255,255,255,${s.life})`; + ctx.fill(); + + s.x += s.vx; + s.y += s.vy; + s.life -= s.decay; + + if (s.life <= 0 || s.x > W + 200 || s.y > H + 200) { + this.shooters.splice(i, 1); + } + } + } + } + + drawNebula(ctx, W, H) { + const t = this.nebulaTime; + const blobs = [ + { x: W * 0.25, y: H * 0.3, rx: W * 0.35, ry: H * 0.25, color: '40,60,180', a: 0.055 }, + { x: W * 0.75, y: H * 0.65, rx: W * 0.30, ry: H * 0.22, color: '100,40,160', a: 0.04 }, + { x: W * 0.5, y: H * 0.5, rx: W * 0.45, ry: H * 0.35, color: '20,50,120', a: 0.035 }, + ]; + ctx.save(); + for (const b of 
blobs) { + const ox = Math.sin(t * 0.7 + b.x) * 30; + const oy = Math.cos(t * 0.5 + b.y) * 20; + const grd = ctx.createRadialGradient(b.x + ox, b.y + oy, 0, b.x + ox, b.y + oy, Math.max(b.rx, b.ry)); + grd.addColorStop(0, `rgba(${b.color}, ${b.a})`); + grd.addColorStop(0.5, `rgba(${b.color}, ${b.a * 0.4})`); + grd.addColorStop(1, `rgba(${b.color}, 0)`); + ctx.save(); + ctx.scale(b.rx / Math.max(b.rx, b.ry), b.ry / Math.max(b.rx, b.ry)); + ctx.beginPath(); + const scale = Math.max(b.rx, b.ry); + ctx.arc((b.x + ox) / (b.rx / scale), (b.y + oy) / (b.ry / scale), scale, 0, Math.PI * 2); + ctx.fillStyle = grd; + ctx.fill(); + ctx.restore(); + } + ctx.restore(); + } +} + +function hexToRgba(hex, alpha) { + // Handle named-ish values or full hex + const c = hex.startsWith('#') ? hex : '#ffffff'; + const r = parseInt(c.slice(1,3), 16); + const g = parseInt(c.slice(3,5), 16); + const b = parseInt(c.slice(5,7), 16); + return `rgba(${r},${g},${b},${alpha.toFixed(3)})`; +} + +customElements.define('starfield-bg', StarfieldBg); +</script> + +<script> +(function() { + const canvas = document.getElementById('globe'); + const ctx = canvas.getContext('2d'); + const W = canvas.width, H = canvas.height; + const R = W * 0.44; + const cx = W / 2, cy = H / 2; + + // Simplified continent outlines as lon/lat polygon chains (degrees). + // Each continent is an array of [lon, lat] points. 
+ const continents = [ + // North America + [[-130,50],[-125,55],[-120,60],[-115,65],[-100,68],[-85,70],[-75,65],[-60,52],[-65,45],[-70,42],[-75,35],[-80,30],[-85,28],[-90,28],[-95,25],[-100,20],[-105,20],[-110,25],[-115,30],[-120,35],[-125,42],[-130,50]], + // South America + [[-80,10],[-75,5],[-70,5],[-65,0],[-60,-5],[-55,-5],[-50,-10],[-45,-15],[-40,-20],[-40,-25],[-42,-30],[-48,-32],[-52,-34],[-55,-38],[-60,-42],[-65,-50],[-68,-55],[-70,-48],[-72,-40],[-75,-30],[-78,-15],[-80,-5],[-80,5],[-80,10]], + // Europe + [[-10,36],[-5,38],[0,40],[2,43],[5,44],[8,46],[10,48],[15,50],[18,54],[20,56],[25,58],[28,60],[30,62],[35,65],[40,68],[38,60],[35,55],[30,50],[28,48],[25,45],[22,40],[20,38],[15,36],[10,36],[5,36],[0,36],[-5,36],[-10,36]], + // Africa + [[-15,14],[-17,16],[-15,22],[-12,28],[-5,32],[0,35],[5,37],[10,35],[15,32],[20,30],[25,30],[30,28],[35,25],[38,18],[40,12],[42,5],[44,0],[42,-5],[40,-12],[38,-18],[35,-25],[32,-30],[30,-34],[25,-33],[20,-30],[15,-28],[12,-20],[10,-10],[8,-5],[5,0],[2,5],[0,5],[-5,5],[-10,6],[-15,10],[-15,14]], + // Asia (simplified) + [[30,35],[35,38],[40,40],[45,42],[50,45],[55,48],[60,50],[65,55],[70,60],[75,65],[80,68],[90,70],[100,68],[110,65],[120,60],[125,55],[130,50],[135,45],[140,40],[138,35],[130,30],[120,25],[110,20],[105,15],[100,10],[95,12],[90,20],[85,22],[80,25],[75,28],[70,30],[65,35],[55,35],[45,35],[40,35],[35,35],[30,35]], + // Australia + [[115,-12],[120,-14],[125,-15],[130,-14],[135,-13],[138,-16],[140,-18],[145,-20],[148,-22],[150,-25],[152,-28],[150,-33],[148,-35],[145,-37],[140,-38],[135,-36],[130,-33],[125,-30],[120,-25],[118,-22],[116,-20],[114,-18],[115,-15],[115,-12]], + ]; + + function project(lon, lat, rotation) { + // Convert to radians and apply rotation + var lonR = (lon + rotation) * Math.PI / 180; + var latR = lat * Math.PI / 180; + + var x3 = Math.cos(latR) * Math.sin(lonR); + var y3 = -Math.sin(latR); + var z3 = Math.cos(latR) * Math.cos(lonR); + + // Only visible if facing us + if (z3 < 0) return 
null; + + return { x: cx + x3 * R, y: cy + y3 * R, z: z3 }; + } + + var rotation = 0; + + function draw() { + requestAnimationFrame(draw); + rotation += 0.15; + ctx.clearRect(0, 0, W, H); + + // Atmosphere glow + var atm = ctx.createRadialGradient(cx, cy, R * 0.85, cx, cy, R * 1.15); + atm.addColorStop(0, 'rgba(60,130,255,0.12)'); + atm.addColorStop(0.5, 'rgba(60,130,255,0.06)'); + atm.addColorStop(1, 'rgba(60,130,255,0)'); + ctx.beginPath(); + ctx.arc(cx, cy, R * 1.15, 0, Math.PI * 2); + ctx.fillStyle = atm; + ctx.fill(); + + // Ocean sphere + var oceanGrad = ctx.createRadialGradient(cx - R * 0.3, cy - R * 0.3, R * 0.1, cx, cy, R); + oceanGrad.addColorStop(0, '#1a4a8a'); + oceanGrad.addColorStop(0.5, '#0e2d5e'); + oceanGrad.addColorStop(1, '#071838'); + ctx.beginPath(); + ctx.arc(cx, cy, R, 0, Math.PI * 2); + ctx.fillStyle = oceanGrad; + ctx.fill(); + + // Draw continents + for (var c = 0; c < continents.length; c++) { + var pts = continents[c]; + var projected = []; + var allVisible = true; + + for (var i = 0; i < pts.length; i++) { + var p = project(pts[i][0], pts[i][1], rotation); + if (!p) { allVisible = false; break; } + projected.push(p); + } + + if (!allVisible || projected.length < 3) continue; + + ctx.beginPath(); + ctx.moveTo(projected[0].x, projected[0].y); + for (var i = 1; i < projected.length; i++) { + ctx.lineTo(projected[i].x, projected[i].y); + } + ctx.closePath(); + + // Shade based on average depth + var avgZ = 0; + for (var i = 0; i < projected.length; i++) avgZ += projected[i].z; + avgZ /= projected.length; + var brightness = 0.3 + avgZ * 0.7; + + var r = Math.round(30 * brightness); + var g = Math.round(100 * brightness); + var b = Math.round(50 * brightness); + ctx.fillStyle = 'rgb(' + r + ',' + g + ',' + b + ')'; + ctx.fill(); + } + + // Grid lines (longitude) + ctx.strokeStyle = 'rgba(100,160,255,0.08)'; + ctx.lineWidth = 0.7; + for (var lon = -180; lon < 180; lon += 30) { + ctx.beginPath(); + var started = false; + for (var lat = -90; lat 
<= 90; lat += 3) { + var p = project(lon, lat, rotation); + if (p) { + if (!started) { ctx.moveTo(p.x, p.y); started = true; } + else ctx.lineTo(p.x, p.y); + } else { + started = false; + } + } + ctx.stroke(); + } + + // Grid lines (latitude) + for (var lat = -60; lat <= 60; lat += 30) { + ctx.beginPath(); + var started = false; + for (var lon = -180; lon <= 180; lon += 3) { + var p = project(lon, lat, rotation); + if (p) { + if (!started) { ctx.moveTo(p.x, p.y); started = true; } + else ctx.lineTo(p.x, p.y); + } else { + started = false; + } + } + ctx.stroke(); + } + + // Specular highlight + var spec = ctx.createRadialGradient(cx - R * 0.35, cy - R * 0.35, 0, cx - R * 0.35, cy - R * 0.35, R * 0.8); + spec.addColorStop(0, 'rgba(180,210,255,0.18)'); + spec.addColorStop(0.4, 'rgba(120,160,255,0.05)'); + spec.addColorStop(1, 'rgba(0,0,0,0)'); + ctx.beginPath(); + ctx.arc(cx, cy, R, 0, Math.PI * 2); + ctx.fillStyle = spec; + ctx.fill(); + + // Rim light + ctx.beginPath(); + ctx.arc(cx, cy, R, 0, Math.PI * 2); + ctx.strokeStyle = 'rgba(80,140,255,0.2)'; + ctx.lineWidth = 1.5; + ctx.stroke(); + } + + draw(); +})(); +</script> +</body> +</html> diff --git a/src/zenserver/frontend/html/compute/banner.js b/src/zenserver/frontend/html/compute/banner.js new file mode 100644 index 000000000..61c7ce21f --- /dev/null +++ b/src/zenserver/frontend/html/compute/banner.js @@ -0,0 +1,321 @@ +/** + * zen-banner.js — Zen Compute dashboard banner Web Component + * + * Usage: + * <script src="/components/zen-banner.js" defer></script> + * + * <zen-banner></zen-banner> + * <zen-banner variant="compact"></zen-banner> + * <zen-banner cluster-status="degraded" load="78"></zen-banner> + * + * Attributes: + * variant "full" (default) | "compact" + * cluster-status "nominal" (default) | "degraded" | "offline" + * load 0–100 integer, shown as a percentage (default: hidden) + * tagline custom tagline text (default: "Orchestrator Overview" / "Orchestrator") + * subtitle text after "ZEN" in the 
wordmark (default: "COMPUTE") + */ + +class ZenBanner extends HTMLElement { + + static get observedAttributes() { + return ['variant', 'cluster-status', 'load', 'tagline', 'subtitle']; + } + + attributeChangedCallback() { + if (this.shadowRoot) this._render(); + } + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + // ───────────────────────────────────────────── + // Derived values + // ───────────────────────────────────────────── + + get _variant() { return this.getAttribute('variant') || 'full'; } + get _status() { return (this.getAttribute('cluster-status') || 'nominal').toLowerCase(); } + get _load() { return this.getAttribute('load'); } // null → hidden + get _tagline() { return this.getAttribute('tagline'); } // null → default + get _subtitle() { return this.getAttribute('subtitle'); } // null → "COMPUTE" + + get _statusColor() { + return { nominal: '#7ecfb8', degraded: '#d4a84b', offline: '#c0504d' }[this._status] ?? '#7ecfb8'; + } + + get _statusLabel() { + return { nominal: 'NOMINAL', degraded: 'DEGRADED', offline: 'OFFLINE' }[this._status] ?? 'NOMINAL'; + } + + get _loadColor() { + const v = parseInt(this._load, 10); + if (isNaN(v)) return '#7ecfb8'; + if (v >= 85) return '#c0504d'; + if (v >= 60) return '#d4a84b'; + return '#7ecfb8'; + } + + // ───────────────────────────────────────────── + // Render + // ───────────────────────────────────────────── + + _render() { + const compact = this._variant === 'compact'; + this.shadowRoot.innerHTML = ` + <style>${this._css(compact)}</style> + ${this._html(compact)} + `; + } + + // ───────────────────────────────────────────── + // CSS + // ───────────────────────────────────────────── + + _css(compact) { + const height = compact ? '60px' : '100px'; + const padding = compact ? '0 24px' : '0 32px'; + const gap = compact ? '16px' : '24px'; + const markSize = compact ? '34px' : '52px'; + const divH = compact ? 
'32px' : '48px'; + const nameSize = compact ? '15px' : '22px'; + const tagSize = compact ? '9px' : '11px'; + const sc = this._statusColor; + const lc = this._loadColor; + + return ` + @import url('https://fonts.googleapis.com/css2?family=Noto+Serif+JP:wght@300;400&family=Space+Mono:wght@400;700&display=swap'); + + *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + + :host { + display: block; + font-family: 'Space Mono', monospace; + } + + .banner { + width: 100%; + height: ${height}; + background: #0b0d10; + border: 1px solid #1e2330; + border-radius: 6px; + display: flex; + align-items: center; + padding: ${padding}; + gap: ${gap}; + position: relative; + overflow: hidden; + } + + /* scan-line texture */ + .banner::before { + content: ''; + position: absolute; + inset: 0; + background: repeating-linear-gradient( + 0deg, + transparent, transparent 3px, + rgba(255,255,255,0.012) 3px, rgba(255,255,255,0.012) 4px + ); + pointer-events: none; + } + + /* ambient glow */ + .banner::after { + content: ''; + position: absolute; + right: -60px; + top: 50%; + transform: translateY(-50%); + width: 280px; + height: 280px; + background: radial-gradient(circle, rgba(130,200,180,0.06) 0%, transparent 70%); + pointer-events: none; + } + + .logo-mark { + flex-shrink: 0; + width: ${markSize}; + height: ${markSize}; + } + + .logo-mark svg { width: 100%; height: 100%; } + + .divider { + width: 1px; + height: ${divH}; + background: linear-gradient(to bottom, transparent, #2a3040, transparent); + flex-shrink: 0; + } + + .text-block { + display: flex; + flex-direction: column; + gap: 4px; + } + + .wordmark { + font-weight: 700; + font-size: ${nameSize}; + letter-spacing: 0.12em; + color: #e8e4dc; + text-transform: uppercase; + line-height: 1; + } + + .wordmark span { color: #7ecfb8; } + + .tagline { + font-family: 'Noto Serif JP', serif; + font-weight: 300; + font-size: ${tagSize}; + letter-spacing: 0.3em; + color: #4a5a68; + text-transform: uppercase; + } + + 
.spacer { flex: 1; } + + /* ── right-side decorative circuit ── */ + .circuit { flex-shrink: 0; opacity: 0.22; } + + /* ── status cluster ── */ + .status-cluster { + display: flex; + flex-direction: column; + align-items: flex-end; + gap: 6px; + } + + .status-row { + display: flex; + align-items: center; + gap: 8px; + } + + .status-lbl { + font-size: 9px; + letter-spacing: 0.18em; + color: #3a4555; + text-transform: uppercase; + } + + .pill { + display: flex; + align-items: center; + gap: 5px; + border-radius: 20px; + padding: 2px 10px; + font-size: 10px; + letter-spacing: 0.1em; + } + + .pill.cluster { + color: ${sc}; + background: color-mix(in srgb, ${sc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${sc} 28%, transparent); + } + + .pill.load-pill { + color: ${lc}; + background: color-mix(in srgb, ${lc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${lc} 28%, transparent); + } + + .dot { + width: 5px; + height: 5px; + border-radius: 50%; + animation: pulse 2.4s ease-in-out infinite; + } + + .dot.cluster { background: ${sc}; } + .dot.load-dot { background: ${lc}; animation-delay: 0.5s; } + + @keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.25; } + } + `; + } + + // ───────────────────────────────────────────── + // HTML template + // ───────────────────────────────────────────── + + _html(compact) { + const loadAttr = this._load; + const showStatus = !compact; + + const rightSide = showStatus ? 
` + <svg class="circuit" width="60" height="60" viewBox="0 0 60 60" fill="none"> + <path d="M5 30 H22 L28 18 H60" stroke="#7ecfb8" stroke-width="0.8"/> + <path d="M5 38 H18 L24 46 H60" stroke="#7ecfb8" stroke-width="0.8"/> + <circle cx="22" cy="30" r="2" fill="none" stroke="#7ecfb8" stroke-width="0.8"/> + <circle cx="18" cy="38" r="2" fill="none" stroke="#7ecfb8" stroke-width="0.8"/> + <circle cx="10" cy="30" r="1.2" fill="#7ecfb8"/> + <circle cx="10" cy="38" r="1.2" fill="#7ecfb8"/> + </svg> + + <div class="status-cluster"> + <div class="status-row"> + <span class="status-lbl">Cluster</span> + <div class="pill cluster"> + <div class="dot cluster"></div> + ${this._statusLabel} + </div> + </div> + ${loadAttr !== null ? ` + <div class="status-row"> + <span class="status-lbl">Load</span> + <div class="pill load-pill"> + <div class="dot load-dot"></div> + ${parseInt(loadAttr, 10)} % + </div> + </div>` : ''} + </div> + ` : ''; + + return ` + <div class="banner"> + <div class="logo-mark">${this._svgMark()}</div> + <div class="divider"></div> + <div class="text-block"> + <div class="wordmark">ZEN<span> ${this._subtitle ?? 'COMPUTE'}</span></div> + <div class="tagline">${this._tagline ?? (compact ? 
'Orchestrator' : 'Orchestrator Overview')}</div> + </div> + <div class="spacer"></div> + ${rightSide} + </div> + `; + } + + // ───────────────────────────────────────────── + // SVG logo mark + // ───────────────────────────────────────────── + + _svgMark() { + return ` + <svg viewBox="0 0 52 52" fill="none" xmlns="http://www.w3.org/2000/svg"> + <circle cx="26" cy="26" r="22" stroke="#2a3a48" stroke-width="1.5"/> + <path d="M26 4 A22 22 0 1 1 12 43.1" stroke="#7ecfb8" stroke-width="2" stroke-linecap="round" fill="none"/> + <circle cx="17" cy="17" r="1.6" fill="#7ecfb8" /> + <circle cx="26" cy="17" r="1.6" fill="#7ecfb8" /> + <circle cx="35" cy="17" r="1.6" fill="#7ecfb8" /> + <circle cx="17" cy="26" r="1.6" fill="#7ecfb8" opacity="0.6"/> + <circle cx="26" cy="26" r="2.2" fill="#7ecfb8"/> + <circle cx="35" cy="26" r="1.6" fill="#7ecfb8" opacity="0.6"/> + <circle cx="17" cy="35" r="1.6" fill="#7ecfb8"/> + <circle cx="26" cy="35" r="1.6" fill="#7ecfb8"/> + <circle cx="35" cy="35" r="1.6" fill="#7ecfb8"/> + <line x1="17" y1="17" x2="35" y2="17" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.25"/> + <line x1="35" y1="17" x2="17" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.25"/> + <line x1="17" y1="35" x2="35" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.2"/> + <line x1="26" y1="17" x2="26" y2="35" stroke="#7ecfb8" stroke-width="0.7" stroke-opacity="0.2"/> + </svg> + `; + } +} + +customElements.define('zen-banner', ZenBanner); diff --git a/src/zenserver/frontend/html/compute.html b/src/zenserver/frontend/html/compute/compute.html index 668189fe5..1e101d839 100644 --- a/src/zenserver/frontend/html/compute.html +++ b/src/zenserver/frontend/html/compute/compute.html @@ -5,6 +5,8 @@ <meta name="viewport" content="width=device-width, initial-scale=1.0"> <title>Zen Compute Dashboard</title> <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script> + <script src="banner.js" defer></script> + <script 
src="nav.js" defer></script> <style> * { margin: 0; @@ -291,16 +293,12 @@ </head> <body> <div class="container"> - <div class="header"> - <div> - <h1>Zen Compute Dashboard</h1> - <div class="timestamp">Last updated: <span id="last-update">Never</span></div> - </div> - <div class="health-indicator" id="health-indicator"> - <div class="status-dot"></div> - <span id="health-text">Checking...</span> - </div> - </div> + <zen-banner cluster-status="nominal" load="0" tagline="Node Overview"></zen-banner> + <zen-nav> + <a href="compute.html">Node</a> + <a href="orchestrator.html">Orchestrator</a> + </zen-nav> + <div class="timestamp">Last updated: <span id="last-update">Never</span></div> <div id="error-container"></div> @@ -388,6 +386,30 @@ </div> </div> + <!-- Queues --> + <div class="section-title">Queues</div> + <div class="card" style="margin-bottom: 30px;"> + <div class="card-title">Queue Status</div> + <div id="queue-list-empty" style="color: #6e7681; font-size: 13px;">No queues.</div> + <div id="queue-list-container" style="display: none;"> + <table id="queue-list-table" style="width: 100%; border-collapse: collapse; font-size: 13px;"> + <thead> + <tr> + <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px; width: 60px;">ID</th> + <th style="text-align: center; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px; width: 80px;">Status</th> + <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Active</th> + <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Completed</th> + <th 
style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Failed</th> + <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Abandoned</th> + <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Cancelled</th> + <th style="text-align: left; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Token</th> + </tr> + </thead> + <tbody id="queue-list-body"></tbody> + </table> + </div> + </div> + <!-- Action History --> <div class="section-title">Recent Actions</div> <div class="card" style="margin-bottom: 30px;"> @@ -398,6 +420,7 @@ <thead> <tr> <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px; width: 60px;">LSN</th> + <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px; width: 60px;">Queue</th> <th style="text-align: center; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px; width: 70px;">Status</th> <th style="text-align: left; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; font-size: 11px;">Function</th> <th style="text-align: right; color: #8b949e; padding: 6px 8px; border-bottom: 1px solid #30363d; font-weight: 600; text-transform: uppercase; 
letter-spacing: 0.5px; font-size: 11px; width: 80px;">Started</th> @@ -576,6 +599,12 @@ }); // Helper functions + function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + function formatBytes(bytes) { if (bytes === 0) return '0 B'; const k = 1024; @@ -590,7 +619,7 @@ function showError(message) { const container = document.getElementById('error-container'); - container.innerHTML = `<div class="error">Error: ${message}</div>`; + container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`; } function clearError() { @@ -617,35 +646,27 @@ async function fetchHealth() { try { - const response = await fetch(`${BASE_URL}/apply/ready`); + const response = await fetch(`${BASE_URL}/compute/ready`); const isHealthy = response.status === 200; - const indicator = document.getElementById('health-indicator'); - const text = document.getElementById('health-text'); + const banner = document.querySelector('zen-banner'); if (isHealthy) { - indicator.classList.add('healthy'); - indicator.classList.remove('unhealthy'); - text.textContent = 'Healthy'; + banner.setAttribute('cluster-status', 'nominal'); } else { - indicator.classList.add('unhealthy'); - indicator.classList.remove('healthy'); - text.textContent = 'Unhealthy'; + banner.setAttribute('cluster-status', 'degraded'); } return isHealthy; } catch (error) { - const indicator = document.getElementById('health-indicator'); - const text = document.getElementById('health-text'); - indicator.classList.add('unhealthy'); - indicator.classList.remove('healthy'); - text.textContent = 'Error'; + const banner = document.querySelector('zen-banner'); + banner.setAttribute('cluster-status', 'degraded'); throw error; } } async function fetchStats() { - const data = await fetchJSON('/stats/apply'); + const data = await fetchJSON('/stats/compute'); //
Update action counts document.getElementById('actions-pending').textContent = data.actions_pending || 0; @@ -684,13 +708,16 @@ } async function fetchSysInfo() { - const data = await fetchJSON('/apply/sysinfo'); + const data = await fetchJSON('/compute/sysinfo'); // Update CPU const cpuUsage = data.cpu_usage || 0; document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%'; document.getElementById('cpu-progress').style.width = cpuUsage + '%'; + const banner = document.querySelector('zen-banner'); + banner.setAttribute('load', cpuUsage.toFixed(1)); + history.cpu.push(cpuUsage); if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift(); cpuChart.data.labels = history.cpu.map(() => ''); @@ -741,7 +768,7 @@ const functions = desc.functions || []; const functionsHtml = functions.length === 0 ? '<span style="color:#6e7681;font-size:12px;">none</span>' : `<table class="detail-table">${functions.map(f => - `<tr><td>${f.name || '-'}</td><td class="detail-mono">${f.version || '-'}</td></tr>` + `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>` ).join('')}</table>`; // Executables @@ -756,8 +783,8 @@ </tr> ${executables.map(e => `<tr> - <td>${e.name || '-'}</td> - <td class="detail-mono">${e.hash || '-'}</td> + <td>${escapeHtml(e.name || '-')}</td> + <td class="detail-mono">${escapeHtml(e.hash || '-')}</td> <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td> </tr>` ).join('')} @@ -772,26 +799,26 @@ const files = desc.files || []; const filesHtml = files.length === 0 ? '<span style="color:#6e7681;font-size:12px;">none</span>' : `<table class="detail-table">${files.map(f => - `<tr><td>${f.name || f}</td><td class="detail-mono">${f.hash || ''}</td></tr>` + `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>` ).join('')}</table>`; // Dirs const dirs = desc.dirs || []; const dirsHtml = dirs.length === 0 ? 
'<span style="color:#6e7681;font-size:12px;">none</span>' : - dirs.map(d => `<span class="detail-tag">${d}</span>`).join(''); + dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join(''); // Environment const env = desc.environment || []; const envHtml = env.length === 0 ? '<span style="color:#6e7681;font-size:12px;">none</span>' : - env.map(e => `<span class="detail-tag">${e}</span>`).join(''); + env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join(''); panel.innerHTML = ` - <div class="worker-detail-title">${desc.name || id}</div> + <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div> <div class="detail-section"> <table class="detail-table"> - ${field('Worker ID', `<span class="detail-mono">${id}</span>`)} - ${field('Path', desc.path)} - ${field('Platform', desc.host)} + ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)} + ${field('Path', escapeHtml(desc.path || '-'))} + ${field('Platform', escapeHtml(desc.host || '-'))} ${monoField('Build System', desc.buildsystem_version)} ${field('Cores', desc.cores)} ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)} @@ -822,7 +849,7 @@ } async function fetchWorkers() { - const data = await fetchJSON('/apply/workers'); + const data = await fetchJSON('/compute/workers'); const workerIds = data.workers || []; document.getElementById('worker-count').textContent = workerIds.length; @@ -837,7 +864,7 @@ } const descriptors = await Promise.all( - workerIds.map(id => fetchJSON(`/apply/workers/${id}`).catch(() => null)) + workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null)) ); // Build a map for quick lookup by ID @@ -857,12 +884,12 @@ tr.className = 'worker-row' + (id === selectedWorkerId ? 
' selected' : ''); tr.dataset.workerId = id; tr.innerHTML = ` - <td style="padding: 6px 8px; color: #f0f6fc; border-bottom: 1px solid #21262d;">${name}</td> - <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d;">${host}</td> - <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${cores}</td> - <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${timeout}</td> - <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${functions}</td> - <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; font-family: monospace; font-size: 11px;">${id}</td> + <td style="padding: 6px 8px; color: #f0f6fc; border-bottom: 1px solid #21262d;">${escapeHtml(name)}</td> + <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d;">${escapeHtml(host)}</td> + <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${escapeHtml(String(cores))}</td> + <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${escapeHtml(String(timeout))}</td> + <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${escapeHtml(String(functions))}</td> + <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td> `; tr.addEventListener('click', () => { document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected')); @@ -914,8 +941,55 @@ return `${m}m ${s}s`; } + async function fetchQueues() { + const data = await fetchJSON('/compute/queues'); + const queues = data.queues || []; + + const empty = document.getElementById('queue-list-empty'); + const container = document.getElementById('queue-list-container'); + const tbody = document.getElementById('queue-list-body'); + + if 
(queues.length === 0) { + empty.style.display = ''; + container.style.display = 'none'; + return; + } + + empty.style.display = 'none'; + tbody.innerHTML = ''; + + for (const q of queues) { + const id = q.queue_id ?? '-'; + const badge = q.state === 'cancelled' + ? '<span class="status-badge failure">cancelled</span>' + : q.state === 'draining' + ? '<span class="status-badge" style="background:rgba(210,153,34,0.15);color:#d29922;">draining</span>' + : q.is_complete + ? '<span class="status-badge success">complete</span>' + : '<span class="status-badge" style="background:rgba(88,166,255,0.15);color:#58a6ff;">active</span>'; + const token = q.queue_token + ? `<span class="detail-mono">${escapeHtml(q.queue_token)}</span>` + : '<span style="color:#6e7681;">-</span>'; + + const tr = document.createElement('tr'); + tr.innerHTML = ` + <td style="padding: 6px 8px; color: #f0f6fc; border-bottom: 1px solid #21262d; text-align: right; font-family: monospace;">${escapeHtml(String(id))}</td> + <td style="padding: 6px 8px; border-bottom: 1px solid #21262d; text-align: center;">${badge}</td> + <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px solid #21262d; text-align: right;">${q.active_count ?? 0}</td> + <td style="padding: 6px 8px; color: #3fb950; border-bottom: 1px solid #21262d; text-align: right;">${q.completed_count ?? 0}</td> + <td style="padding: 6px 8px; color: #f85149; border-bottom: 1px solid #21262d; text-align: right;">${q.failed_count ?? 0}</td> + <td style="padding: 6px 8px; color: #d29922; border-bottom: 1px solid #21262d; text-align: right;">${q.abandoned_count ?? 0}</td> + <td style="padding: 6px 8px; color: #f0883e; border-bottom: 1px solid #21262d; text-align: right;">${q.cancelled_count ?? 
0}</td> + <td style="padding: 6px 8px; border-bottom: 1px solid #21262d;">${token}</td> + `; + tbody.appendChild(tr); + } + + container.style.display = 'block'; + } + async function fetchActionHistory() { - const data = await fetchJSON('/apply/jobs/history?limit=50'); + const data = await fetchJSON('/compute/jobs/history?limit=50'); const entries = data.history || []; const empty = document.getElementById('action-history-empty'); @@ -948,16 +1022,22 @@ const startDate = filetimeToDate(entry.time_Running); const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); + const queueId = entry.queueId || 0; + const queueCell = queueId + ? `<a href="/compute/queues/${queueId}" style="color: #58a6ff; text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>` + : '<span style="color: #6e7681;">-</span>'; + const tr = document.createElement('tr'); tr.innerHTML = ` - <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; text-align: right; font-family: monospace;">${lsn}</td> + <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; text-align: right; font-family: monospace;">${escapeHtml(String(lsn))}</td> + <td style="padding: 6px 8px; border-bottom: 1px solid #21262d; text-align: right;">${queueCell}</td> <td style="padding: 6px 8px; border-bottom: 1px solid #21262d; text-align: center;">${badge}</td> - <td style="padding: 6px 8px; color: #f0f6fc; border-bottom: 1px solid #21262d;">${fn}</td> + <td style="padding: 6px 8px; color: #f0f6fc; border-bottom: 1px solid #21262d;">${escapeHtml(fn)}</td> <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; text-align: right; font-size: 12px; white-space: nowrap;">${formatTime(startDate)}</td> <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; text-align: right; font-size: 12px; white-space: nowrap;">${formatTime(endDate)}</td> <td style="padding: 6px 8px; color: #c9d1d9; border-bottom: 1px 
solid #21262d; text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td> - <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; font-family: monospace; font-size: 11px;">${workerId}</td> - <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; font-family: monospace; font-size: 11px;">${actionId}</td> + <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; font-family: monospace; font-size: 11px;">${escapeHtml(workerId)}</td> + <td style="padding: 6px 8px; color: #8b949e; border-bottom: 1px solid #21262d; font-family: monospace; font-size: 11px;">${escapeHtml(actionId)}</td> `; tbody.appendChild(tr); } @@ -972,6 +1052,7 @@ fetchStats(), fetchSysInfo(), fetchWorkers(), + fetchQueues(), fetchActionHistory() ]); diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html new file mode 100644 index 000000000..f66ba94d5 --- /dev/null +++ b/src/zenserver/frontend/html/compute/hub.html @@ -0,0 +1,310 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <script src="banner.js" defer></script> + <script src="nav.js" defer></script> + <title>Zen Hub Dashboard</title> + <style> + * { + margin: 0; + padding: 0; + box-sizing: border-box; + } + + body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + background: #0d1117; + color: #c9d1d9; + padding: 20px; + } + + .container { + max-width: 1400px; + margin: 0 auto; + } + + .timestamp { + font-size: 12px; + color: #6e7681; + } + + .grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 20px; + margin-bottom: 30px; + } + + .card { + background: #161b22; + border: 1px solid #30363d; + border-radius: 6px; + padding: 20px; + } + + .card-title { + font-size: 14px; + font-weight: 600; 
+ color: #8b949e; + margin-bottom: 12px; + text-transform: uppercase; + letter-spacing: 0.5px; + } + + .metric-value { + font-size: 36px; + font-weight: 600; + color: #f0f6fc; + line-height: 1; + } + + .metric-label { + font-size: 12px; + color: #8b949e; + margin-top: 4px; + } + + .progress-bar { + width: 100%; + height: 8px; + background: #21262d; + border-radius: 4px; + overflow: hidden; + margin-top: 8px; + } + + .progress-fill { + height: 100%; + background: #58a6ff; + transition: width 0.3s ease; + } + + .error { + color: #f85149; + padding: 12px; + background: #1c1c1c; + border-radius: 6px; + margin: 20px 0; + font-size: 13px; + } + + .section-title { + font-size: 20px; + font-weight: 600; + margin-bottom: 20px; + color: #f0f6fc; + } + + table { + width: 100%; + border-collapse: collapse; + font-size: 13px; + } + + th { + text-align: left; + color: #8b949e; + padding: 8px 12px; + border-bottom: 1px solid #30363d; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; + font-size: 11px; + } + + td { + padding: 8px 12px; + border-bottom: 1px solid #21262d; + color: #c9d1d9; + } + + tr:last-child td { + border-bottom: none; + } + + .status-badge { + display: inline-block; + padding: 2px 8px; + border-radius: 4px; + font-size: 11px; + font-weight: 600; + } + + .status-badge.active { + background: rgba(63, 185, 80, 0.15); + color: #3fb950; + } + + .status-badge.inactive { + background: rgba(139, 148, 158, 0.15); + color: #8b949e; + } + + .empty-state { + color: #6e7681; + font-size: 13px; + padding: 20px 0; + text-align: center; + } + </style> +</head> +<body> + <div class="container"> + <zen-banner cluster-status="nominal" subtitle="HUB" tagline="Overview"></zen-banner> + <zen-nav> + <a href="hub.html">Hub</a> + </zen-nav> + <div class="timestamp">Last updated: <span id="last-update">Never</span></div> + + <div id="error-container"></div> + + <div class="section-title">Capacity</div> + <div class="grid"> + <div class="card"> + <div 
class="card-title">Active Modules</div> + <div class="metric-value" id="instance-count">-</div> + <div class="metric-label">Currently provisioned</div> + </div> + <div class="card"> + <div class="card-title">Peak Modules</div> + <div class="metric-value" id="max-instance-count">-</div> + <div class="metric-label">High watermark</div> + </div> + <div class="card"> + <div class="card-title">Instance Limit</div> + <div class="metric-value" id="instance-limit">-</div> + <div class="metric-label">Maximum allowed</div> + <div class="progress-bar"> + <div class="progress-fill" id="capacity-progress" style="width: 0%"></div> + </div> + </div> + </div> + + <div class="section-title">Modules</div> + <div class="card"> + <div class="card-title">Storage Server Instances</div> + <div id="empty-state" class="empty-state">No modules provisioned.</div> + <table id="module-table" style="display: none;"> + <thead> + <tr> + <th>Module ID</th> + <th style="text-align: center;">Status</th> + </tr> + </thead> + <tbody id="module-table-body"></tbody> + </table> + </div> + </div> + + <script> + const BASE_URL = window.location.origin; + const REFRESH_INTERVAL = 2000; + + function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + function showError(message) { + document.getElementById('error-container').innerHTML = + '<div class="error">Error: ' + escapeHtml(message) + '</div>'; + } + + function clearError() { + document.getElementById('error-container').innerHTML = ''; + } + + async function fetchJSON(endpoint) { + var response = await fetch(BASE_URL + endpoint, { + headers: { 'Accept': 'application/json' } + }); + if (!response.ok) { + throw new Error('HTTP ' + response.status + ': ' + response.statusText); + } + return await response.json(); + } + + async function fetchStats() { + var data = await fetchJSON('/hub/stats'); + + var current = data.currentInstanceCount || 0; + var max = data.maxInstanceCount || 0; + var 
limit = data.instanceLimit || 0; + + document.getElementById('instance-count').textContent = current; + document.getElementById('max-instance-count').textContent = max; + document.getElementById('instance-limit').textContent = limit; + + var pct = limit > 0 ? Math.min(100, (current / limit) * 100) : 0; + document.getElementById('capacity-progress').style.width = pct + '%'; + + var banner = document.querySelector('zen-banner'); + if (current === 0) { + banner.setAttribute('cluster-status', 'nominal'); + } else if (limit > 0 && current >= limit * 0.9) { + banner.setAttribute('cluster-status', 'degraded'); + } else { + banner.setAttribute('cluster-status', 'nominal'); + } + } + + async function fetchModules() { + var data = await fetchJSON('/hub/status'); + var modules = data.modules || []; + + var emptyState = document.getElementById('empty-state'); + var table = document.getElementById('module-table'); + var tbody = document.getElementById('module-table-body'); + + if (modules.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + table.style.display = ''; + + tbody.innerHTML = ''; + for (var i = 0; i < modules.length; i++) { + var m = modules[i]; + var moduleId = m.moduleId || ''; + var provisioned = m.provisioned; + + var badge = provisioned + ?
'<span class="status-badge active">Provisioned</span>' + : '<span class="status-badge inactive">Inactive</span>'; + + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="font-family: monospace; font-size: 12px;">' + escapeHtml(moduleId) + '</td>' + + '<td style="text-align: center;">' + badge + '</td>'; + tbody.appendChild(tr); + } + } + + async function updateDashboard() { + var banner = document.querySelector('zen-banner'); + try { + await Promise.all([ + fetchStats(), + fetchModules() + ]); + + clearError(); + document.getElementById('last-update').textContent = new Date().toLocaleTimeString(); + } catch (error) { + console.error('Error updating dashboard:', error); + showError(error.message); + banner.setAttribute('cluster-status', 'offline'); + } + } + + updateDashboard(); + setInterval(updateDashboard, REFRESH_INTERVAL); + </script> +</body> +</html> diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html new file mode 100644 index 000000000..9597fd7f3 --- /dev/null +++ b/src/zenserver/frontend/html/compute/index.html @@ -0,0 +1 @@ +<meta http-equiv="refresh" content="0; url=compute.html" />
\ No newline at end of file diff --git a/src/zenserver/frontend/html/compute/nav.js b/src/zenserver/frontend/html/compute/nav.js new file mode 100644 index 000000000..8ec42abd0 --- /dev/null +++ b/src/zenserver/frontend/html/compute/nav.js @@ -0,0 +1,86 @@ +/** + * nav.js — Zen dashboard navigation bar Web Component + * + * Usage: + * <script src="nav.js" defer></script> + * + * <zen-nav> + * <a href="compute.html">Node</a> + * <a href="orchestrator.html">Orchestrator</a> + * </zen-nav> + * + * Each child <a> becomes a nav link. The current page is + * highlighted automatically based on the href. + */ + +class ZenNav extends HTMLElement { + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + _render() { + const currentPath = window.location.pathname; + const items = Array.from(this.querySelectorAll(':scope > a')); + + const links = items.map(a => { + // Escape before interpolating into shadow-DOM innerHTML so link + // text/href cannot inject markup (consistent with escapeHtml() + // used by the dashboard pages). + const esc = s => s.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;'); + const rawHref = a.getAttribute('href') || ''; + // Match on a path-segment boundary so e.g. "mycompute.html" does + // not highlight the "compute.html" link. + const active = currentPath === rawHref || currentPath.endsWith('/' + rawHref); + const href = esc(rawHref); + const label = esc(a.textContent.trim()); + return `<a class="nav-link${active ?
' active' : ''}" href="${href}">${label}</a>`; + }).join(''); + + this.shadowRoot.innerHTML = ` + <style> + *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + + :host { + display: block; + margin-bottom: 16px; + } + + .nav-bar { + display: flex; + align-items: center; + gap: 4px; + padding: 4px; + background: #161b22; + border: 1px solid #30363d; + border-radius: 6px; + } + + .nav-link { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + font-size: 13px; + font-weight: 500; + color: #8b949e; + text-decoration: none; + padding: 6px 14px; + border-radius: 4px; + transition: color 0.15s, background 0.15s; + } + + .nav-link:hover { + color: #c9d1d9; + background: #21262d; + } + + .nav-link.active { + color: #f0f6fc; + background: #30363d; + } + </style> + <nav class="nav-bar">${links}</nav> + `; + } +} + +customElements.define('zen-nav', ZenNav); diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html new file mode 100644 index 000000000..2ee57b6b3 --- /dev/null +++ b/src/zenserver/frontend/html/compute/orchestrator.html @@ -0,0 +1,831 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <script src="banner.js" defer></script> + <script src="nav.js" defer></script> + <title>Zen Orchestrator Dashboard</title> + <style> + * { + margin: 0; + padding: 0; + box-sizing: border-box; + } + + body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + background: #0d1117; + color: #c9d1d9; + padding: 20px; + } + + .container { + max-width: 1400px; + margin: 0 auto; + } + + h1 { + font-size: 32px; + font-weight: 600; + margin-bottom: 10px; + color: #f0f6fc; + } + + .header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 30px; + } + + 
.timestamp { + font-size: 12px; + color: #6e7681; + } + + .agent-count { + display: flex; + align-items: center; + gap: 8px; + font-size: 14px; + padding: 8px 16px; + border-radius: 6px; + background: #161b22; + border: 1px solid #30363d; + } + + .agent-count .count { + font-size: 20px; + font-weight: 600; + color: #f0f6fc; + } + + .card { + background: #161b22; + border: 1px solid #30363d; + border-radius: 6px; + padding: 20px; + } + + .card-title { + font-size: 14px; + font-weight: 600; + color: #8b949e; + margin-bottom: 12px; + text-transform: uppercase; + letter-spacing: 0.5px; + } + + .error { + color: #f85149; + padding: 12px; + background: #1c1c1c; + border-radius: 6px; + margin: 20px 0; + font-size: 13px; + } + + table { + width: 100%; + border-collapse: collapse; + font-size: 13px; + } + + th { + text-align: left; + color: #8b949e; + padding: 8px 12px; + border-bottom: 1px solid #30363d; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; + font-size: 11px; + } + + td { + padding: 8px 12px; + border-bottom: 1px solid #21262d; + color: #c9d1d9; + } + + tr:last-child td { + border-bottom: none; + } + + .total-row td { + border-top: 2px solid #30363d; + font-weight: 600; + color: #f0f6fc; + } + + .health-dot { + display: inline-block; + width: 10px; + height: 10px; + border-radius: 50%; + } + + .health-green { + background: #3fb950; + } + + .health-yellow { + background: #d29922; + } + + .health-red { + background: #f85149; + } + + a { + color: #58a6ff; + text-decoration: none; + } + + a:hover { + text-decoration: underline; + } + + .empty-state { + color: #6e7681; + font-size: 13px; + padding: 20px 0; + text-align: center; + } + + .history-tabs { + display: flex; + gap: 4px; + background: #0d1117; + border-radius: 6px; + padding: 2px; + } + + .history-tab { + background: transparent; + border: none; + color: #8b949e; + font-size: 12px; + font-weight: 600; + padding: 4px 12px; + border-radius: 4px; + cursor: pointer; + text-transform: 
uppercase; + letter-spacing: 0.5px; + } + + .history-tab:hover { + color: #c9d1d9; + } + + .history-tab.active { + background: #30363d; + color: #f0f6fc; + } + </style> +</head> +<body> + <div class="container"> + <zen-banner cluster-status="nominal" load="0"></zen-banner> + <zen-nav> + <a href="compute.html">Node</a> + <a href="orchestrator.html">Orchestrator</a> + </zen-nav> + <div class="header"> + <div> + <div class="timestamp">Last updated: <span id="last-update">Never</span></div> + </div> + <div class="agent-count"> + <span>Agents:</span> + <span class="count" id="agent-count">-</span> + </div> + </div> + + <div id="error-container"></div> + + <div class="card"> + <div class="card-title">Compute Agents</div> + <div id="empty-state" class="empty-state">No agents registered.</div> + <table id="agent-table" style="display: none;"> + <thead> + <tr> + <th style="width: 40px; text-align: center;">Health</th> + <th>Hostname</th> + <th style="text-align: right;">CPUs</th> + <th style="text-align: right;">CPU Usage</th> + <th style="text-align: right;">Memory</th> + <th style="text-align: right;">Queues</th> + <th style="text-align: right;">Pending</th> + <th style="text-align: right;">Running</th> + <th style="text-align: right;">Completed</th> + <th style="text-align: right;">Traffic</th> + <th style="text-align: right;">Last Seen</th> + </tr> + </thead> + <tbody id="agent-table-body"></tbody> + </table> + </div> + <div class="card" style="margin-top: 20px;"> + <div class="card-title">Connected Clients</div> + <div id="clients-empty" class="empty-state">No clients connected.</div> + <table id="clients-table" style="display: none;"> + <thead> + <tr> + <th style="width: 40px; text-align: center;">Health</th> + <th>Client ID</th> + <th>Hostname</th> + <th>Address</th> + <th style="text-align: right;">Last Seen</th> + </tr> + </thead> + <tbody id="clients-table-body"></tbody> + </table> + </div> + <div class="card" style="margin-top: 20px;"> + <div style="display: 
flex; align-items: center; gap: 12px; margin-bottom: 12px;"> + <div class="card-title" style="margin-bottom: 0;">Event History</div> + <div class="history-tabs"> + <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button> + <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button> + </div> + </div> + <div id="history-panel-workers"> + <div id="history-empty" class="empty-state">No provisioning events recorded.</div> + <table id="history-table" style="display: none;"> + <thead> + <tr> + <th>Time</th> + <th>Event</th> + <th>Worker</th> + <th>Hostname</th> + </tr> + </thead> + <tbody id="history-table-body"></tbody> + </table> + </div> + <div id="history-panel-clients" style="display: none;"> + <div id="client-history-empty" class="empty-state">No client events recorded.</div> + <table id="client-history-table" style="display: none;"> + <thead> + <tr> + <th>Time</th> + <th>Event</th> + <th>Client</th> + <th>Hostname</th> + </tr> + </thead> + <tbody id="client-history-table-body"></tbody> + </table> + </div> + </div> + </div> + + <script> + const BASE_URL = window.location.origin; + const REFRESH_INTERVAL = 2000; + + function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + function showError(message) { + document.getElementById('error-container').innerHTML = + '<div class="error">Error: ' + escapeHtml(message) + '</div>'; + } + + function clearError() { + document.getElementById('error-container').innerHTML = ''; + } + + function formatLastSeen(dtMs) { + if (dtMs == null) return '-'; + var seconds = Math.floor(dtMs / 1000); + if (seconds < 60) return seconds + 's ago'; + var minutes = Math.floor(seconds / 60); + if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago'; + var hours = Math.floor(minutes / 60); + return hours + 'h ' + (minutes % 60) + 'm ago'; + } + + function healthClass(dtMs, 
reachable) { + if (reachable === false) return 'health-red'; + if (dtMs == null) return 'health-red'; + var seconds = dtMs / 1000; + if (seconds < 30 && reachable === true) return 'health-green'; + if (seconds < 120) return 'health-yellow'; + return 'health-red'; + } + + function healthTitle(dtMs, reachable) { + var seenStr = dtMs != null ? 'Last seen ' + formatLastSeen(dtMs) : 'Never seen'; + if (reachable === true) return seenStr + ' · Reachable'; + if (reachable === false) return seenStr + ' · Unreachable'; + return seenStr + ' · Reachability unknown'; + } + + function formatCpuUsage(percent) { + if (percent == null || percent === 0) return '-'; + return percent.toFixed(1) + '%'; + } + + function formatMemory(usedBytes, totalBytes) { + if (!totalBytes) return '-'; + var usedGiB = usedBytes / (1024 * 1024 * 1024); + var totalGiB = totalBytes / (1024 * 1024 * 1024); + return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB'; + } + + function formatBytes(bytes) { + if (!bytes) return '-'; + if (bytes < 1024) return bytes + ' B'; + if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB'; + if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB'; + if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB'; + return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB'; + } + + function formatTraffic(recv, sent) { + if (!recv && !sent) return '-'; + return formatBytes(recv) + ' / ' + formatBytes(sent); + } + + function parseIpFromUri(uri) { + try { + var url = new URL(uri); + var host = url.hostname; + // Strip IPv6 brackets + if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1); + // Only handle IPv4 + var parts = host.split('.'); + if (parts.length !== 4) return null; + var octets = parts.map(Number); + if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null; + return octets; + } catch (e) { + return null; + } + } + + function 
computeCidr(ips) { + if (ips.length === 0) return null; + if (ips.length === 1) return ips[0].join('.') + '/32'; + + // Convert each IP to a 32-bit integer + var ints = ips.map(function(o) { + return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0; + }); + + // Find common prefix length by ANDing all identical high bits + var common = ~0 >>> 0; + for (var i = 1; i < ints.length; i++) { + // XOR to find differing bits, then mask away everything from the first difference down + var diff = (ints[0] ^ ints[i]) >>> 0; + if (diff !== 0) { + var bit = 31 - Math.floor(Math.log2(diff)); + var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0; + common = (common & mask) >>> 0; + } + } + + // Count leading ones in the common mask + var prefix = 0; + for (var b = 31; b >= 0; b--) { + if ((common >>> b) & 1) prefix++; + else break; + } + + // Network address + var net = (ints[0] & common) >>> 0; + var a = (net >>> 24) & 0xff; + var bv = (net >>> 16) & 0xff; + var c = (net >>> 8) & 0xff; + var d = net & 0xff; + return a + '.' + bv + '.' + c + '.' 
+ d + '/' + prefix; + } + + function renderDashboard(data) { + var banner = document.querySelector('zen-banner'); + if (data.hostname) { + banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname); + } + var workers = data.workers || []; + + document.getElementById('agent-count').textContent = workers.length; + + if (workers.length === 0) { + banner.setAttribute('cluster-status', 'degraded'); + banner.setAttribute('load', '0'); + } else { + banner.setAttribute('cluster-status', 'nominal'); + } + + var emptyState = document.getElementById('empty-state'); + var table = document.getElementById('agent-table'); + var tbody = document.getElementById('agent-table-body'); + + if (workers.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + } else { + emptyState.style.display = 'none'; + table.style.display = ''; + + tbody.innerHTML = ''; + var totalCpus = 0; + var totalWeightedCpuUsage = 0; + var totalMemUsed = 0; + var totalMemTotal = 0; + var totalQueues = 0; + var totalPending = 0; + var totalRunning = 0; + var totalCompleted = 0; + var totalBytesRecv = 0; + var totalBytesSent = 0; + var allIps = []; + for (var i = 0; i < workers.length; i++) { + var w = workers[i]; + var uri = w.uri || ''; + var dt = w.dt; + var dashboardUrl = uri + '/dashboard/compute/'; + + var id = w.id || ''; + + var hostname = w.hostname || ''; + var cpus = w.cpus || 0; + totalCpus += cpus; + if (cpus > 0 && typeof w.cpu_usage === 'number') { + totalWeightedCpuUsage += w.cpu_usage * cpus; + } + + var memTotal = w.memory_total || 0; + var memUsed = w.memory_used || 0; + totalMemTotal += memTotal; + totalMemUsed += memUsed; + + var activeQueues = w.active_queues || 0; + totalQueues += activeQueues; + + var actionsPending = w.actions_pending || 0; + var actionsRunning = w.actions_running || 0; + var actionsCompleted = w.actions_completed || 0; + totalPending += actionsPending; + totalRunning += actionsRunning; + totalCompleted += actionsCompleted; + + var 
bytesRecv = w.bytes_received || 0; + var bytesSent = w.bytes_sent || 0; + totalBytesRecv += bytesRecv; + totalBytesSent += bytesSent; + + var ip = parseIpFromUri(uri); + if (ip) allIps.push(ip); + + var reachable = w.reachable; + var hClass = healthClass(dt, reachable); + var hTitle = healthTitle(dt, reachable); + + var platform = w.platform || ''; + var badges = ''; + if (platform) { + var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' }; + var platColor = platColors[platform] || '#8b949e'; + badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>'; + } + var provisioner = w.provisioner || ''; + if (provisioner) { + var provColors = { horde: '#8957e5', nomad: '#3fb950' }; + var provColor = provColors[provisioner] || '#8b949e'; + badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>'; + } + + var tr = document.createElement('tr'); + tr.title = id; + tr.innerHTML = + '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + + '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' + + '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' + + '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' + + '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' + + '<td style="text-align: right;">' + (activeQueues > 0 ? 
activeQueues : '-') + '</td>' + + '<td style="text-align: right;">' + actionsPending + '</td>' + + '<td style="text-align: right;">' + actionsRunning + '</td>' + + '<td style="text-align: right;">' + actionsCompleted + '</td>' + + '<td style="text-align: right; color: #8b949e; font-size: 11px;">' + formatTraffic(bytesRecv, bytesSent) + '</td>' + + '<td style="text-align: right; color: #8b949e;">' + formatLastSeen(dt) + '</td>'; + tbody.appendChild(tr); + } + + var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0; + banner.setAttribute('load', clusterLoad.toFixed(1)); + + // Total row + var cidr = computeCidr(allIps); + var totalTr = document.createElement('tr'); + totalTr.className = 'total-row'; + totalTr.innerHTML = + '<td></td>' + + '<td style="text-align: right; color: #8b949e; text-transform: uppercase; font-size: 11px;">Total' + (cidr ? ' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' + + '<td style="text-align: right;">' + totalCpus + '</td>' + + '<td></td>' + + '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' + + '<td style="text-align: right;">' + totalQueues + '</td>' + + '<td style="text-align: right;">' + totalPending + '</td>' + + '<td style="text-align: right;">' + totalRunning + '</td>' + + '<td style="text-align: right;">' + totalCompleted + '</td>' + + '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' + + '<td></td>'; + tbody.appendChild(totalTr); + } + + clearError(); + document.getElementById('last-update').textContent = new Date().toLocaleTimeString(); + + // Render provisioning history if present in WebSocket payload + if (data.events) { + renderProvisioningHistory(data.events); + } + + // Render connected clients if present + if (data.clients) { + renderClients(data.clients); + } + + // Render client history if present + if (data.client_events) { + 
renderClientHistory(data.client_events); + } + } + + function eventBadge(type) { + var colors = { joined: '#3fb950', left: '#f85149', returned: '#d29922' }; + var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' }; + var color = colors[type] || '#8b949e'; + var label = labels[type] || type; + return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:#0d1117;background:' + color + ';">' + escapeHtml(label) + '</span>'; + } + + function formatTimestamp(ts) { + if (!ts) return '-'; + // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string + var date; + if (typeof ts === 'number') { + // .NET-style ticks: convert to Unix ms + var unixMs = (ts - 621355968000000000) / 10000; + date = new Date(unixMs); + } else { + date = new Date(ts); + } + if (isNaN(date.getTime())) return '-'; + return date.toLocaleTimeString(); + } + + var activeHistoryTab = 'workers'; + + function switchHistoryTab(tab) { + activeHistoryTab = tab; + var tabs = document.querySelectorAll('.history-tab'); + for (var i = 0; i < tabs.length; i++) { + tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab); + } + document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none'; + document.getElementById('history-panel-clients').style.display = tab === 'clients' ? 
'' : 'none'; + } + + function renderProvisioningHistory(events) { + var emptyState = document.getElementById('history-empty'); + var table = document.getElementById('history-table'); + var tbody = document.getElementById('history-table-body'); + + if (!events || events.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + table.style.display = ''; + tbody.innerHTML = ''; + + for (var i = 0; i < events.length; i++) { + var evt = events[i]; + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="color: #8b949e;">' + formatTimestamp(evt.ts) + '</td>' + + '<td>' + eventBadge(evt.type) + '</td>' + + '<td>' + escapeHtml(evt.worker_id || '') + '</td>' + + '<td>' + escapeHtml(evt.hostname || '') + '</td>'; + tbody.appendChild(tr); + } + } + + function clientHealthClass(dtMs) { + if (dtMs == null) return 'health-red'; + var seconds = dtMs / 1000; + if (seconds < 30) return 'health-green'; + if (seconds < 120) return 'health-yellow'; + return 'health-red'; + } + + function renderClients(clients) { + var emptyState = document.getElementById('clients-empty'); + var table = document.getElementById('clients-table'); + var tbody = document.getElementById('clients-table-body'); + + if (!clients || clients.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + table.style.display = ''; + tbody.innerHTML = ''; + + for (var i = 0; i < clients.length; i++) { + var c = clients[i]; + var dt = c.dt; + var hClass = clientHealthClass(dt); + var hTitle = dt != null ? 
'Last seen ' + formatLastSeen(dt) : 'Never seen'; + + var sessionBadge = ''; + if (c.session_id) { + sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:#6e7681;" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>'; + } + + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + + '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' + + '<td>' + escapeHtml(c.hostname || '') + '</td>' + + '<td style="font-family: monospace; font-size: 12px; color: #8b949e;">' + escapeHtml(c.address || '') + '</td>' + + '<td style="text-align: right; color: #8b949e;">' + formatLastSeen(dt) + '</td>'; + tbody.appendChild(tr); + } + } + + function clientEventBadge(type) { + var colors = { connected: '#3fb950', disconnected: '#f85149', updated: '#d29922' }; + var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' }; + var color = colors[type] || '#8b949e'; + var label = labels[type] || type; + return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:#0d1117;background:' + color + ';">' + escapeHtml(label) + '</span>'; + } + + function renderClientHistory(events) { + var emptyState = document.getElementById('client-history-empty'); + var table = document.getElementById('client-history-table'); + var tbody = document.getElementById('client-history-table-body'); + + if (!events || events.length === 0) { + emptyState.style.display = ''; + table.style.display = 'none'; + return; + } + + emptyState.style.display = 'none'; + table.style.display = ''; + tbody.innerHTML = ''; + + for (var i = 0; i < events.length; i++) { + var evt = events[i]; + var tr = document.createElement('tr'); + tr.innerHTML = + '<td style="color: #8b949e;">' + formatTimestamp(evt.ts) + '</td>' + + '<td>' + clientEventBadge(evt.type) 
+ '</td>' + + '<td>' + escapeHtml(evt.client_id || '') + '</td>' + + '<td>' + escapeHtml(evt.hostname || '') + '</td>'; + tbody.appendChild(tr); + } + } + + // Fetch-based polling fallback + var pollTimer = null; + + async function fetchProvisioningHistory() { + try { + var response = await fetch(BASE_URL + '/orch/history?limit=50', { + headers: { 'Accept': 'application/json' } + }); + if (response.ok) { + var data = await response.json(); + renderProvisioningHistory(data.events || []); + } + } catch (e) { + console.error('Error fetching provisioning history:', e); + } + } + + async function fetchClients() { + try { + var response = await fetch(BASE_URL + '/orch/clients', { + headers: { 'Accept': 'application/json' } + }); + if (response.ok) { + var data = await response.json(); + renderClients(data.clients || []); + } + } catch (e) { + console.error('Error fetching clients:', e); + } + } + + async function fetchClientHistory() { + try { + var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', { + headers: { 'Accept': 'application/json' } + }); + if (response.ok) { + var data = await response.json(); + renderClientHistory(data.client_events || []); + } + } catch (e) { + console.error('Error fetching client history:', e); + } + } + + async function fetchDashboard() { + var banner = document.querySelector('zen-banner'); + try { + var response = await fetch(BASE_URL + '/orch/agents', { + headers: { 'Accept': 'application/json' } + }); + + if (!response.ok) { + banner.setAttribute('cluster-status', 'degraded'); + throw new Error('HTTP ' + response.status + ': ' + response.statusText); + } + + renderDashboard(await response.json()); + fetchProvisioningHistory(); + fetchClients(); + fetchClientHistory(); + } catch (error) { + console.error('Error updating dashboard:', error); + showError(error.message); + banner.setAttribute('cluster-status', 'offline'); + } + } + + function startPolling() { + if (pollTimer) return; + fetchDashboard(); + pollTimer = 
setInterval(fetchDashboard, REFRESH_INTERVAL); + } + + function stopPolling() { + if (pollTimer) { + clearInterval(pollTimer); + pollTimer = null; + } + } + + // WebSocket connection with automatic reconnect and polling fallback + var ws = null; + + function connectWebSocket() { + var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws'); + + ws.onopen = function() { + stopPolling(); + clearError(); + }; + + ws.onmessage = function(event) { + try { + renderDashboard(JSON.parse(event.data)); + } catch (e) { + console.error('WebSocket message parse error:', e); + } + }; + + ws.onclose = function() { + ws = null; + startPolling(); + setTimeout(connectWebSocket, 3000); + }; + + ws.onerror = function() { + // onclose will fire after onerror + }; + } + + // Fetch orchestrator hostname for the banner + fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } }) + .then(function(r) { return r.ok ? 
r.json() : null; }) + .then(function(d) { + if (d && d.hostname) { + document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname); + } + }) + .catch(function() {}); + + // Initial load via fetch, then try WebSocket + fetchDashboard(); + connectWebSocket(); + </script> +</body> +</html> diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index 3c2d3619a..592b699dc 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -3,6 +3,7 @@ "use strict"; import { WidgetHost } from "../util/widgets.js" +import { Fetcher } from "../util/fetcher.js" //////////////////////////////////////////////////////////////////////////////// export class PageBase extends WidgetHost @@ -63,6 +64,7 @@ export class ZenPage extends PageBase super(parent, ...args); super.set_title("zen"); this.add_branding(parent); + this.add_service_nav(parent); this.generate_crumbs(); } @@ -78,6 +80,40 @@ export class ZenPage extends PageBase root.tag("img").attr("src", "epicgames.ico").id("epic_logo"); } + add_service_nav(parent) + { + const nav = parent.tag().id("service_nav"); + + // Map service base URIs to dashboard links, this table is also used to determine + // which links to show based on the services that are currently registered. 
+ + const service_dashboards = [ + { base_uri: "/compute/", label: "Compute", href: "/dashboard/compute/compute.html" }, + { base_uri: "/orch/", label: "Orchestrator", href: "/dashboard/compute/orchestrator.html" }, + { base_uri: "/hub/", label: "Hub", href: "/dashboard/compute/hub.html" }, + ]; + + new Fetcher().resource("/api/").json().then((data) => { + const services = data.services || []; + const uris = new Set(services.map(s => s.base_uri)); + + const links = service_dashboards.filter(d => uris.has(d.base_uri)); + + if (links.length === 0) + { + nav.inner().style.display = "none"; + return; + } + + for (const link of links) + { + nav.tag("a").text(link.label).attr("href", link.href); + } + }).catch(() => { + nav.inner().style.display = "none"; + }); + } + set_title(...args) { super.set_title(...args); diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index 702bf9aa6..a80a1a4f6 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -80,6 +80,33 @@ input { } } +/* service nav -------------------------------------------------------------- */ + +#service_nav { + display: flex; + justify-content: center; + gap: 0.3em; + margin-bottom: 1.5em; + padding: 0.3em; + background-color: var(--theme_g3); + border: 1px solid var(--theme_g2); + border-radius: 0.4em; + + a { + padding: 0.3em 0.9em; + border-radius: 0.3em; + font-size: 0.85em; + color: var(--theme_g1); + text-decoration: none; + } + + a:hover { + background-color: var(--theme_p4); + color: var(--theme_g0); + text-decoration: none; + } +} + /* links -------------------------------------------------------------------- */ a { diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index bf0e294c5..a757cd594 100644 --- a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -845,7 +845,7 @@ HttpHubService::HttpHubService(std::filesystem::path HubBaseDir, std::filesystem Obj << 
"currentInstanceCount" << m_Impl->GetInstanceCount(); Obj << "maxInstanceCount" << m_Impl->GetMaxInstanceCount(); Obj << "instanceLimit" << m_Impl->GetInstanceLimit(); - Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); }, HttpVerb::kGet); } diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index d0a0db417..c63c618df 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -143,6 +143,8 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) ZEN_INFO("instantiating hub service"); m_HubService = std::make_unique<HttpHubService>(ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers"); m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId); + + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatusService); } void @@ -159,6 +161,11 @@ ZenHubServer::RegisterServices(const ZenHubServerConfig& ServerConfig) { m_Http->RegisterService(*m_ApiService); } + + if (m_FrontendService) + { + m_Http->RegisterService(*m_FrontendService); + } } void diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index ac14362f0..4c56fdce5 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -2,6 +2,7 @@ #pragma once +#include "frontend/frontend.h" #include "zenserver.h" namespace cxxopts { @@ -81,8 +82,9 @@ private: std::filesystem::path m_ContentRoot; bool m_DebugOptionForcedCrash = false; - std::unique_ptr<HttpHubService> m_HubService; - std::unique_ptr<HttpApiService> m_ApiService; + std::unique_ptr<HttpHubService> m_HubService; + std::unique_ptr<HttpApiService> m_ApiService; + std::unique_ptr<HttpFrontendService> m_FrontendService; void InitializeState(const ZenHubServerConfig& ServerConfig); void InitializeServices(const ZenHubServerConfig& ServerConfig); diff --git 
a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index 3d81db656..bca26e87a 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -183,10 +183,15 @@ ZenStorageServer::RegisterServices() m_Http->RegisterService(*m_AdminService); + if (m_ApiService) + { + m_Http->RegisterService(*m_ApiService); + } + #if ZEN_WITH_COMPUTE_SERVICES - if (m_HttpFunctionService) + if (m_HttpComputeService) { - m_Http->RegisterService(*m_HttpFunctionService); + m_Http->RegisterService(*m_HttpComputeService); } #endif } @@ -279,8 +284,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions { ZEN_OTEL_SPAN("InitializeComputeService"); - m_HttpFunctionService = - std::make_unique<compute::HttpFunctionService>(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); + m_HttpComputeService = + std::make_unique<compute::HttpComputeService>(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); } #endif @@ -316,6 +321,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions .AttachmentPassCount = ServerOptions.GcConfig.AttachmentPassCount}; m_GcScheduler.Initialize(GcConfig); + m_ApiService = std::make_unique<HttpApiService>(*m_Http); + // Create and register admin interface last to make sure all is properly initialized m_AdminService = std::make_unique<HttpAdminService>( m_GcScheduler, @@ -832,7 +839,7 @@ ZenStorageServer::Cleanup() Flush(); #if ZEN_WITH_COMPUTE_SERVICES - m_HttpFunctionService.reset(); + m_HttpComputeService.reset(); #endif m_AdminService.reset(); diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h index 456447a2a..5b163fc8e 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -25,7 +25,7 @@ #include "workspaces/httpworkspaces.h" #if ZEN_WITH_COMPUTE_SERVICES -# include <zencompute/httpfunctionservice.h> +# 
include <zencompute/httpcomputeservice.h> #endif namespace zen { @@ -93,7 +93,7 @@ private: std::unique_ptr<HttpApiService> m_ApiService; #if ZEN_WITH_COMPUTE_SERVICES - std::unique_ptr<compute::HttpFunctionService> m_HttpFunctionService; + std::unique_ptr<compute::HttpComputeService> m_HttpComputeService; #endif }; diff --git a/src/zenserver/trace/tracerecorder.cpp b/src/zenserver/trace/tracerecorder.cpp new file mode 100644 index 000000000..5dec20e18 --- /dev/null +++ b/src/zenserver/trace/tracerecorder.cpp @@ -0,0 +1,565 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "tracerecorder.h" + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/uid.h> + +#include <asio.hpp> + +#include <atomic> +#include <cstring> +#include <memory> +#include <mutex> +#include <thread> + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// + +struct TraceSession : public std::enable_shared_from_this<TraceSession> +{ + TraceSession(asio::ip::tcp::socket&& Socket, const std::filesystem::path& OutputDir) + : m_Socket(std::move(Socket)) + , m_OutputDir(OutputDir) + , m_SessionId(Oid::NewOid()) + { + try + { + m_RemoteAddress = m_Socket.remote_endpoint().address().to_string(); + } + catch (...) 
+ { + m_RemoteAddress = "unknown"; + } + + ZEN_INFO("Trace session {} started from {}", m_SessionId, m_RemoteAddress); + } + + ~TraceSession() + { + if (m_TraceFile.IsOpen()) + { + m_TraceFile.Close(); + } + + ZEN_INFO("Trace session {} ended, {} bytes recorded to '{}'", m_SessionId, m_TotalBytesRecorded, m_TraceFilePath); + } + + void Start() { ReadPreambleHeader(); } + + bool IsActive() const { return m_Socket.is_open(); } + + TraceSessionInfo GetInfo() const + { + TraceSessionInfo Info; + Info.SessionGuid = m_SessionGuid; + Info.TraceGuid = m_TraceGuid; + Info.ControlPort = m_ControlPort; + Info.TransportVersion = m_TransportVersion; + Info.ProtocolVersion = m_ProtocolVersion; + Info.RemoteAddress = m_RemoteAddress; + Info.BytesRecorded = m_TotalBytesRecorded; + Info.TraceFilePath = m_TraceFilePath; + return Info; + } + +private: + // Preamble format: + // [magic: 4 bytes][metadata_size: 2 bytes][metadata fields: variable][version: 2 bytes] + // + // Magic bytes: [0]=version_char ('2'-'9'), [1]='C', [2]='R', [3]='T' + // + // Metadata fields (repeated): + // [size: 1 byte][id: 1 byte][data: <size> bytes] + // Field 0: ControlPort (uint16) + // Field 1: SessionGuid (16 bytes) + // Field 2: TraceGuid (16 bytes) + // + // Version: [transport: 1 byte][protocol: 1 byte] + + static constexpr size_t kMagicSize = 4; + static constexpr size_t kMetadataSizeFieldSize = 2; + static constexpr size_t kPreambleHeaderSize = kMagicSize + kMetadataSizeFieldSize; + static constexpr size_t kVersionSize = 2; + static constexpr size_t kPreambleBufferSize = 256; + static constexpr size_t kReadBufferSize = 64 * 1024; + + void ReadPreambleHeader() + { + auto Self = shared_from_this(); + + // Read the first 6 bytes: 4 magic + 2 metadata size + asio::async_read(m_Socket, + asio::buffer(m_PreambleBuffer, kPreambleHeaderSize), + [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) { + if (Ec) + { + HandleReadError("preamble header", Ec); + return; + } + + if 
(!ValidateMagic()) + { + ZEN_WARN("Trace session {}: invalid trace magic header", m_SessionId); + CloseSocket(); + return; + } + + ReadPreambleMetadata(); + }); + } + + bool ValidateMagic() + { + const uint8_t* Cursor = m_PreambleBuffer; + + // Validate magic: bytes are version, 'C', 'R', 'T' + if (Cursor[3] != 'T' || Cursor[2] != 'R' || Cursor[1] != 'C') + { + return false; + } + + if (Cursor[0] < '2' || Cursor[0] > '9') + { + return false; + } + + // Extract the metadata fields size (does not include the trailing version bytes) + std::memcpy(&m_MetadataFieldsSize, Cursor + kMagicSize, sizeof(m_MetadataFieldsSize)); + + if (m_MetadataFieldsSize + kVersionSize > kPreambleBufferSize - kPreambleHeaderSize) + { + return false; + } + + return true; + } + + void ReadPreambleMetadata() + { + auto Self = shared_from_this(); + size_t ReadSize = m_MetadataFieldsSize + kVersionSize; + + // Read metadata fields + 2 version bytes + asio::async_read(m_Socket, + asio::buffer(m_PreambleBuffer + kPreambleHeaderSize, ReadSize), + [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) { + if (Ec) + { + HandleReadError("preamble metadata", Ec); + return; + } + + if (!ParseMetadata()) + { + ZEN_WARN("Trace session {}: malformed trace metadata", m_SessionId); + CloseSocket(); + return; + } + + if (!CreateTraceFile()) + { + CloseSocket(); + return; + } + + // Write the full preamble to the trace file so it remains a valid .utrace + size_t PreambleSize = kPreambleHeaderSize + m_MetadataFieldsSize + kVersionSize; + std::error_code WriteEc; + m_TraceFile.Write(m_PreambleBuffer, PreambleSize, 0, WriteEc); + + if (WriteEc) + { + ZEN_ERROR("Trace session {}: failed to write preamble: {}", m_SessionId, WriteEc.message()); + CloseSocket(); + return; + } + + m_TotalBytesRecorded = PreambleSize; + + ZEN_INFO("Trace session {}: metadata - TransportV{} ProtocolV{} ControlPort:{} SessionGuid:{} TraceGuid:{}", + m_SessionId, + m_TransportVersion, + m_ProtocolVersion, + m_ControlPort, + 
m_SessionGuid, + m_TraceGuid); + + // Begin streaming trace data to disk + ReadMore(); + }); + } + + bool ParseMetadata() + { + const uint8_t* Cursor = m_PreambleBuffer + kPreambleHeaderSize; + int32_t Remaining = static_cast<int32_t>(m_MetadataFieldsSize); + + while (Remaining >= 2) + { + uint8_t FieldSize = Cursor[0]; + uint8_t FieldId = Cursor[1]; + Cursor += 2; + Remaining -= 2; + + if (Remaining < FieldSize) + { + return false; + } + + switch (FieldId) + { + case 0: // ControlPort + if (FieldSize >= sizeof(uint16_t)) + { + std::memcpy(&m_ControlPort, Cursor, sizeof(uint16_t)); + } + break; + case 1: // SessionGuid + if (FieldSize >= sizeof(Guid)) + { + std::memcpy(&m_SessionGuid, Cursor, sizeof(Guid)); + } + break; + case 2: // TraceGuid + if (FieldSize >= sizeof(Guid)) + { + std::memcpy(&m_TraceGuid, Cursor, sizeof(Guid)); + } + break; + } + + Cursor += FieldSize; + Remaining -= FieldSize; + } + + // Metadata should be fully consumed + if (Remaining != 0) + { + return false; + } + + // Version bytes follow immediately after the metadata fields + const uint8_t* VersionPtr = m_PreambleBuffer + kPreambleHeaderSize + m_MetadataFieldsSize; + m_TransportVersion = VersionPtr[0]; + m_ProtocolVersion = VersionPtr[1]; + + return true; + } + + bool CreateTraceFile() + { + m_TraceFilePath = m_OutputDir / fmt::format("{}.utrace", m_SessionId); + + try + { + m_TraceFile.Open(m_TraceFilePath, BasicFile::Mode::kTruncate); + ZEN_INFO("Trace session {} writing to '{}'", m_SessionId, m_TraceFilePath); + return true; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Trace session {}: failed to create trace file '{}': {}", m_SessionId, m_TraceFilePath, Ex.what()); + return false; + } + } + + void ReadMore() + { + auto Self = shared_from_this(); + + m_Socket.async_read_some(asio::buffer(m_ReadBuffer, kReadBufferSize), + [this, Self](const asio::error_code& Ec, std::size_t BytesRead) { + if (!Ec) + { + if (BytesRead > 0 && m_TraceFile.IsOpen()) + { + std::error_code WriteEc; 
+ const uint64_t FileOffset = m_TotalBytesRecorded; + m_TraceFile.Write(m_ReadBuffer, BytesRead, FileOffset, WriteEc); + + if (WriteEc) + { + ZEN_ERROR("Trace session {}: write error: {}", m_SessionId, WriteEc.message()); + CloseSocket(); + return; + } + + m_TotalBytesRecorded += BytesRead; + } + + ReadMore(); + } + else if (Ec == asio::error::eof) + { + ZEN_DEBUG("Trace session {} connection closed by peer", m_SessionId); + CloseSocket(); + } + else if (Ec == asio::error::operation_aborted) + { + ZEN_DEBUG("Trace session {} operation aborted", m_SessionId); + } + else + { + ZEN_WARN("Trace session {} read error: {}", m_SessionId, Ec.message()); + CloseSocket(); + } + }); + } + + void HandleReadError(const char* Phase, const asio::error_code& Ec) + { + if (Ec == asio::error::eof) + { + ZEN_DEBUG("Trace session {}: connection closed during {}", m_SessionId, Phase); + } + else if (Ec == asio::error::operation_aborted) + { + ZEN_DEBUG("Trace session {}: operation aborted during {}", m_SessionId, Phase); + } + else + { + ZEN_WARN("Trace session {}: error during {}: {}", m_SessionId, Phase, Ec.message()); + } + + CloseSocket(); + } + + void CloseSocket() + { + std::error_code Ec; + m_Socket.close(Ec); + + if (m_TraceFile.IsOpen()) + { + m_TraceFile.Close(); + } + } + + asio::ip::tcp::socket m_Socket; + std::filesystem::path m_OutputDir; + std::filesystem::path m_TraceFilePath; + BasicFile m_TraceFile; + Oid m_SessionId; + std::string m_RemoteAddress; + + // Preamble parsing + uint8_t m_PreambleBuffer[kPreambleBufferSize] = {}; + uint16_t m_MetadataFieldsSize = 0; + + // Extracted metadata + Guid m_SessionGuid{}; + Guid m_TraceGuid{}; + uint16_t m_ControlPort = 0; + uint8_t m_TransportVersion = 0; + uint8_t m_ProtocolVersion = 0; + + // Streaming + uint8_t m_ReadBuffer[kReadBufferSize]; + uint64_t m_TotalBytesRecorded = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TraceRecorder::Impl +{ + Impl() : m_IoContext(), 
m_Acceptor(m_IoContext) {} + + ~Impl() { Shutdown(); } + + void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir) + { + std::lock_guard<std::mutex> Lock(m_Mutex); + + if (m_IsRunning) + { + ZEN_WARN("TraceRecorder already initialized"); + return; + } + + m_OutputDir = OutputDir; + + try + { + // Create output directory if it doesn't exist + CreateDirectories(m_OutputDir); + + // Configure acceptor + m_Acceptor.open(asio::ip::tcp::v4()); + m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::tcp::v4(), InPort)); + m_Acceptor.listen(); + + m_Port = m_Acceptor.local_endpoint().port(); + + ZEN_INFO("TraceRecorder listening on port {}, output directory: '{}'", m_Port, m_OutputDir); + + m_IsRunning = true; + + // Start accepting connections + StartAccept(); + + // Start IO thread + m_IoThread = std::thread([this]() { + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("TraceRecorder IO thread exception: {}", Ex.what()); + } + }); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to initialize TraceRecorder: {}", Ex.what()); + m_IsRunning = false; + throw; + } + } + + void Shutdown() + { + std::lock_guard<std::mutex> Lock(m_Mutex); + + if (!m_IsRunning) + { + return; + } + + ZEN_INFO("TraceRecorder shutting down"); + + m_IsRunning = false; + + std::error_code Ec; + m_Acceptor.close(Ec); + + m_IoContext.stop(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + + { + std::lock_guard<std::mutex> SessionLock(m_SessionsMutex); + m_Sessions.clear(); + } + + ZEN_INFO("TraceRecorder shutdown complete"); + } + + bool IsRunning() const { return m_IsRunning; } + + uint16_t GetPort() const { return m_Port; } + + std::vector<TraceSessionInfo> GetActiveSessions() const + { + std::lock_guard<std::mutex> Lock(m_SessionsMutex); + + std::vector<TraceSessionInfo> Result; + for (const auto& WeakSession : m_Sessions) + { + if (auto Session = 
WeakSession.lock()) + { + if (Session->IsActive()) + { + Result.push_back(Session->GetInfo()); + } + } + } + return Result; + } + +private: + void StartAccept() + { + auto Socket = std::make_shared<asio::ip::tcp::socket>(m_IoContext); + + m_Acceptor.async_accept(*Socket, [this, Socket](const asio::error_code& Ec) { + if (!Ec) + { + auto Session = std::make_shared<TraceSession>(std::move(*Socket), m_OutputDir); + + { + std::lock_guard<std::mutex> Lock(m_SessionsMutex); + + // Prune expired sessions while adding the new one + std::erase_if(m_Sessions, [](const std::weak_ptr<TraceSession>& Wp) { return Wp.expired(); }); + m_Sessions.push_back(Session); + } + + Session->Start(); + } + else if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("Accept error: {}", Ec.message()); + } + + // Continue accepting if still running + if (m_IsRunning) + { + StartAccept(); + } + }); + } + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + std::thread m_IoThread; + std::filesystem::path m_OutputDir; + std::mutex m_Mutex; + std::atomic<bool> m_IsRunning{false}; + uint16_t m_Port = 0; + + mutable std::mutex m_SessionsMutex; + std::vector<std::weak_ptr<TraceSession>> m_Sessions; +}; + +//////////////////////////////////////////////////////////////////////////////// + +TraceRecorder::TraceRecorder() : m_Impl(std::make_unique<Impl>()) +{ +} + +TraceRecorder::~TraceRecorder() +{ + Shutdown(); +} + +void +TraceRecorder::Initialize(uint16_t InPort, const std::filesystem::path& OutputDir) +{ + m_Impl->Initialize(InPort, OutputDir); +} + +void +TraceRecorder::Shutdown() +{ + m_Impl->Shutdown(); +} + +bool +TraceRecorder::IsRunning() const +{ + return m_Impl->IsRunning(); +} + +uint16_t +TraceRecorder::GetPort() const +{ + return m_Impl->GetPort(); +} + +std::vector<TraceSessionInfo> +TraceRecorder::GetActiveSessions() const +{ + return m_Impl->GetActiveSessions(); +} + +} // namespace zen diff --git a/src/zenserver/trace/tracerecorder.h 
b/src/zenserver/trace/tracerecorder.h new file mode 100644 index 000000000..48857aec8 --- /dev/null +++ b/src/zenserver/trace/tracerecorder.h @@ -0,0 +1,46 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/guid.h> +#include <zencore/zencore.h> + +#include <filesystem> +#include <memory> +#include <string> +#include <vector> + +namespace zen { + +struct TraceSessionInfo +{ + Guid SessionGuid{}; + Guid TraceGuid{}; + uint16_t ControlPort = 0; + uint8_t TransportVersion = 0; + uint8_t ProtocolVersion = 0; + std::string RemoteAddress; + uint64_t BytesRecorded = 0; + std::filesystem::path TraceFilePath; +}; + +class TraceRecorder +{ +public: + TraceRecorder(); + ~TraceRecorder(); + + void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir); + void Shutdown(); + + bool IsRunning() const; + uint16_t GetPort() const; + + std::vector<TraceSessionInfo> GetActiveSessions() const; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen
\ No newline at end of file diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 9ab51beb2..915b6a3b1 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -27,6 +27,7 @@ target("zenserver") add_packages("json11") add_packages("lua") add_packages("consul") + add_packages("nomad") if has_config("zenmimalloc") then add_packages("mimalloc") @@ -36,6 +37,14 @@ target("zenserver") add_packages("sentry-native") end + if has_config("zenhorde") then + add_deps("zenhorde") + end + + if has_config("zennomad") then + add_deps("zennomad") + end + if is_mode("release") then set_optimize("fastest") end @@ -145,4 +154,14 @@ target("zenserver") end copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin) end + + local nomad_pkg = target:pkg("nomad") + if nomad_pkg then + local installdir = nomad_pkg:installdir() + local nomad_bin = "nomad" + if is_plat("windows") then + nomad_bin = "nomad.exe" + end + copy_if_newer(path.join(installdir, "bin", nomad_bin), path.join(target:targetdir(), nomad_bin), nomad_bin) + end end) diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 67fbef532..509629739 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -106,6 +106,11 @@ DescribeFunctions() << "Reverse"sv; Versions << "Version"sv << Guid::FromString("31313131-3131-3131-3131-313131313131"sv); Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Sleep"sv; + Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv); + Versions.EndObject(); Versions.EndArray(); return Versions.Save(); @@ -190,6 +195,12 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) { return Apply(NullFunction); } + else if (Function == "Sleep"sv) + { + uint64_t SleepTimeMs = Action["Constants"sv].AsObjectView()["SleepTimeMs"sv].AsUInt64(); + 
zen::Sleep(static_cast<int>(SleepTimeMs)); + return Apply(IdentityFunction); + } else { return {}; diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 7dc68c126..5f74fa82b 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ b/src/zenutil/include/zenutil/consoletui.h @@ -2,6 +2,7 @@ #pragma once +#include <cstdint> #include <span> #include <string> #include <string_view> diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index e81b154e8..2a8617162 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -84,6 +84,7 @@ struct ZenServerInstance { kStorageServer, // default kHubServer, + kComputeServer, }; ZenServerInstance(ZenServerEnvironment& TestEnvironment, ServerMode Mode = ServerMode::kStorageServer); diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index e127a92d7..b09c2d89a 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -787,6 +787,8 @@ ToString(ZenServerInstance::ServerMode Mode) return "storage"sv; case ZenServerInstance::ServerMode::kHubServer: return "hub"sv; + case ZenServerInstance::ServerMode::kComputeServer: + return "compute"sv; default: return "invalid"sv; } @@ -808,6 +810,10 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { CommandLine << " hub"; } + else if (m_ServerMode == ServerMode::kComputeServer) + { + CommandLine << " compute"; + } CommandLine << " --child-id " << ChildEventName; @@ -74,6 +74,7 @@ add_defines("EASTL_STD_ITERATOR_CATEGORY_ENABLED", "EASTL_DEPRECATIONS_FOR_2024_ add_requires("eastl", {system = false}) add_requires("consul", {system = false}) -- for hub tests +add_requires("nomad", {system = false}) -- for nomad provisioner tests if has_config("zenmimalloc") and not use_asan then add_requires("mimalloc", {system = false}) @@ -244,13 +245,29 @@ else 
add_defines("ZEN_WITH_HTTPSYS=0") end +local compute_default = false + option("zencompute") - set_default(false) + set_default(compute_default) set_showmenu(true) set_description("Enable compute services endpoint") option_end() add_define_by_config("ZEN_WITH_COMPUTE_SERVICES", "zencompute") +option("zenhorde") + set_default(compute_default) + set_showmenu(true) + set_description("Enable Horde worker provisioning") +option_end() +add_define_by_config("ZEN_WITH_HORDE", "zenhorde") + +option("zennomad") + set_default(compute_default) + set_showmenu(true) + set_description("Enable Nomad worker provisioning") +option_end() +add_define_by_config("ZEN_WITH_NOMAD", "zennomad") + if is_os("windows") then add_defines("UE_MEMORY_TRACE_AVAILABLE=1") @@ -304,6 +321,12 @@ includes("src/zenhttp", "src/zenhttp-test") includes("src/zennet", "src/zennet-test") includes("src/zenremotestore", "src/zenremotestore-test") includes("src/zencompute", "src/zencompute-test") +if has_config("zenhorde") then + includes("src/zenhorde") +end +if has_config("zennomad") then + includes("src/zennomad") +end includes("src/zenstore", "src/zenstore-test") includes("src/zentelemetry", "src/zentelemetry-test") includes("src/zenutil", "src/zenutil-test") |