From 795345e5fd7974a1f5227d507a58bb3ed75eafd5 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 13 Apr 2026 16:38:16 +0200 Subject: Compute OIDC auth, async Horde agents, and orchestrator improvements (#913) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rework of the Horde agent subsystem from synchronous per-thread I/O to an async ASIO-driven architecture, plus provisioner scale-down with graceful draining, OIDC authentication, scheduler improvements, and dashboard UI for provisioner control. ### Async Horde Agent Rewrite - Replace synchronous `HordeAgent` (one thread per agent, blocking I/O) with `AsyncHordeAgent` — an ASIO state machine running on a shared `io_context` thread pool - Replace `TcpComputeTransport`/`AesComputeTransport` with `AsyncTcpComputeTransport`/`AsyncAesComputeTransport` - Replace `AgentMessageChannel` with `AsyncAgentMessageChannel` using frame queuing and ASIO timers - Delete `ComputeBuffer` and `ComputeChannel` ring-buffer classes (no longer needed) ### Provisioner Drain / Scale-Down - `HordeProvisioner` can now drain agents when target core count is lowered: queries each agent's `/compute/session/status` for workload, selects candidates by largest-fit/lowest-workload, and sends `/compute/session/drain` - Configurable `--horde-drain-grace-period` (default 300s) before force-kill - Implement `IProvisionerStateProvider` interface to expose provisioner state to the orchestrator HTTP layer - Forward `--coordinator-session`, `--provision-clean`, and `--provision-tracehost` through both Horde and Nomad provisioners to spawned workers ### OIDC Authentication - `HordeClient` accepts an `AccessTokenProvider` (refreshable token function) as alternative to static `--horde-token` - Wire up `OidcToken.exe` auto-discovery via `httpclientauth::CreateFromOidcTokenExecutable` with `--HordeUrl` mode - New `--horde-oidctoken-exe-path` CLI option for explicit path override ### Orchestrator & Scheduler - 
Orchestrator generates a session ID at startup; workers include `coordinator_session` in announcements so the orchestrator can reject stale-session workers - New `Rejected` action state — when a remote runner declines at capacity, the action is rescheduled without retry count increment - Reduce scheduler lock contention: snapshot pending actions under shared lock, sort/trim outside the lock - Parallelize remote action submission across runners via `WorkerThreadPool` with slow-submit warnings - New action field `FailureReason` populated by all runner types (exit codes, sandbox failures, exceptions) - New endpoints: `session/drain`, `session/status`, `session/sunset`, `provisioner/status`, `provisioner/target` ### Remote Execution - Eager-attach mode for `RemoteHttpRunner` — bundles all attachments upfront in a `CbPackage` for single-roundtrip submits - Track in-flight submissions to prevent over-queuing - Show remote runner hostname in `GetDisplayName()` - `--announce-url` to override the endpoint announced to the coordinator (e.g. 
relay-visible address) ### Frontend Dashboard - Delete standalone `compute.html` (925 lines) and `orchestrator.html` (669 lines), consolidated into JS page modules - Add provisioner panel to orchestrator dashboard: target/active/estimated core counts, draining agent count - Editable target-cores input with debounced POST to `/orch/provisioner/target` - Per-agent provisioning status badges (active / draining / deallocated) in the agents table - Active vs total CPU counts in agents summary row ### CLI - New `zen compute record-start` / `record-stop` subcommands - `zen exec` progress bar with submit and completion phases, atomic work counters, `--progress` mode (Pretty/Plain/Quiet) ### Other - `DataDir` supports environment variable expansion - Worker manifest validation checks for `worker.zcb` marker to detect incomplete cached directories - Linux/Mac runners `nice(5)` child processes to avoid starving the main server - `ComputeService::SetShutdownCallback` wired to `RequestExit` via `session/sunset` - Curl HTTP client logs effective URL on failure - `MachineInfo` carries `Pool` and `Mode` from Horde response - Horde bundle creation includes `.pdb` on Windows --- src/zen/cmds/compute_cmd.cpp | 96 +++ src/zen/cmds/compute_cmd.h | 53 ++ src/zen/cmds/exec_cmd.cpp | 80 +- src/zen/zen.cpp | 5 +- src/zencompute/CLAUDE.md | 7 + src/zencompute/computeservice.cpp | 164 ++-- src/zencompute/httpcomputeservice.cpp | 95 ++- src/zencompute/httporchestrator.cpp | 135 ++- src/zencompute/include/zencompute/computeservice.h | 7 +- .../include/zencompute/httpcomputeservice.h | 4 + .../include/zencompute/httporchestrator.h | 17 + .../include/zencompute/orchestratorservice.h | 12 +- .../include/zencompute/provisionerstate.h | 38 + src/zencompute/orchestratorservice.cpp | 29 +- src/zencompute/runners/functionrunner.cpp | 120 ++- src/zencompute/runners/functionrunner.h | 27 + src/zencompute/runners/linuxrunner.cpp | 6 +- src/zencompute/runners/localrunner.cpp | 19 +- 
src/zencompute/runners/macrunner.cpp | 6 +- src/zencompute/runners/managedrunner.cpp | 2 +- src/zencompute/runners/remotehttprunner.cpp | 360 +++++--- src/zencompute/runners/remotehttprunner.h | 12 +- src/zencompute/runners/windowsrunner.cpp | 4 +- src/zencompute/runners/winerunner.cpp | 4 +- src/zenhorde/README.md | 17 + src/zenhorde/hordeagent.cpp | 551 +++++++----- src/zenhorde/hordeagent.h | 128 ++- src/zenhorde/hordeagentmessage.cpp | 502 ++++++----- src/zenhorde/hordeagentmessage.h | 123 ++- src/zenhorde/hordebundle.cpp | 2 +- src/zenhorde/hordeclient.cpp | 65 +- src/zenhorde/hordecomputebuffer.cpp | 454 ---------- src/zenhorde/hordecomputebuffer.h | 136 --- src/zenhorde/hordecomputechannel.cpp | 37 - src/zenhorde/hordecomputechannel.h | 32 - src/zenhorde/hordecomputesocket.cpp | 410 +++++---- src/zenhorde/hordecomputesocket.h | 104 ++- src/zenhorde/hordeconfig.cpp | 16 +- src/zenhorde/hordeprovisioner.cpp | 664 ++++++++++----- src/zenhorde/hordetransport.cpp | 153 ++-- src/zenhorde/hordetransport.h | 67 +- src/zenhorde/hordetransportaes.cpp | 609 +++++++------- src/zenhorde/hordetransportaes.h | 50 +- src/zenhorde/include/zenhorde/hordeclient.h | 32 +- src/zenhorde/include/zenhorde/hordeconfig.h | 37 +- src/zenhorde/include/zenhorde/hordeprovisioner.h | 80 +- src/zenhttp/clients/httpclientcurl.cpp | 10 +- src/zenhttp/clients/httpclientcurl.h | 1 + src/zenhttp/httpclientauth.cpp | 18 +- src/zenhttp/include/zenhttp/httpclientauth.h | 3 +- src/zennomad/include/zennomad/nomadclient.h | 6 +- src/zennomad/include/zennomad/nomadprovisioner.h | 9 +- src/zennomad/nomadclient.cpp | 38 +- src/zennomad/nomadprovisioner.cpp | 11 +- src/zenserver/compute/computeserver.cpp | 108 ++- src/zenserver/compute/computeserver.h | 7 + src/zenserver/config/config.cpp | 23 +- src/zenserver/frontend/html/compute/compute.html | 925 --------------------- src/zenserver/frontend/html/compute/index.html | 2 +- .../frontend/html/compute/orchestrator.html | 669 --------------- 
src/zenserver/frontend/html/pages/orchestrator.js | 210 ++++- src/zenserver/main.cpp | 8 +- 62 files changed, 3649 insertions(+), 3970 deletions(-) create mode 100644 src/zen/cmds/compute_cmd.cpp create mode 100644 src/zen/cmds/compute_cmd.h create mode 100644 src/zencompute/include/zencompute/provisionerstate.h create mode 100644 src/zenhorde/README.md delete mode 100644 src/zenhorde/hordecomputebuffer.cpp delete mode 100644 src/zenhorde/hordecomputebuffer.h delete mode 100644 src/zenhorde/hordecomputechannel.cpp delete mode 100644 src/zenhorde/hordecomputechannel.h delete mode 100644 src/zenserver/frontend/html/compute/compute.html delete mode 100644 src/zenserver/frontend/html/compute/orchestrator.html (limited to 'src') diff --git a/src/zen/cmds/compute_cmd.cpp b/src/zen/cmds/compute_cmd.cpp new file mode 100644 index 000000000..01166cb0e --- /dev/null +++ b/src/zen/cmds/compute_cmd.cpp @@ -0,0 +1,96 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "compute_cmd.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include + +using namespace std::literals; + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// ComputeRecordStartSubCmd + +ComputeRecordStartSubCmd::ComputeRecordStartSubCmd() : ZenSubCmdBase("record-start", "Start recording compute actions") +{ + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), ""); +} + +void +ComputeRecordStartSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", SubOptions().help()); + } + + HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName); + if (HttpClient::Response Response = Http.Post("/compute/record/start"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{})) + { + CbObject Obj = 
Response.AsObject(); + std::string_view Path = Obj["path"sv].AsString(); + ZEN_CONSOLE("recording started: " ZEN_BRIGHT_GREEN("{}"), Path); + } + else + { + Response.ThrowError("Failed to start recording"); + } +} + +////////////////////////////////////////////////////////////////////////// +// ComputeRecordStopSubCmd + +ComputeRecordStopSubCmd::ComputeRecordStopSubCmd() : ZenSubCmdBase("record-stop", "Stop recording compute actions") +{ + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), ""); +} + +void +ComputeRecordStopSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", SubOptions().help()); + } + + HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName); + if (HttpClient::Response Response = Http.Post("/compute/record/stop"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{})) + { + CbObject Obj = Response.AsObject(); + std::string_view Path = Obj["path"sv].AsString(); + ZEN_CONSOLE("recording stopped: " ZEN_BRIGHT_GREEN("{}"), Path); + } + else + { + Response.ThrowError("Failed to stop recording"); + } +} + +////////////////////////////////////////////////////////////////////////// +// ComputeCommand + +ComputeCommand::ComputeCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value(m_SubCommand)->default_value(""), ""); + m_Options.parse_positional({"subcommand"}); + + AddSubCommand(m_RecordStartSubCmd); + AddSubCommand(m_RecordStopSubCmd); +} + +ComputeCommand::~ComputeCommand() = default; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/compute_cmd.h b/src/zen/cmds/compute_cmd.h new file mode 100644 index 000000000..b26f639c4 --- /dev/null +++ b/src/zen/cmds/compute_cmd.h @@ -0,0 
+1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +#include + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen { + +class ComputeRecordStartSubCmd : public ZenSubCmdBase +{ +public: + ComputeRecordStartSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + std::string m_HostName; +}; + +class ComputeRecordStopSubCmd : public ZenSubCmdBase +{ +public: + ComputeRecordStopSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + std::string m_HostName; +}; + +class ComputeCommand : public ZenCmdWithSubCommands +{ +public: + static constexpr char Name[] = "compute"; + static constexpr char Description[] = "Compute service operations"; + + ComputeCommand(); + ~ComputeCommand(); + + cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{Name, Description}; + std::string m_SubCommand; + ComputeRecordStartSubCmd m_RecordStartSubCmd; + ComputeRecordStopSubCmd m_RecordStopSubCmd; +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 9719fce77..89bbf3638 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -23,6 +23,8 @@ #include #include +#include "../progressbar.h" + #include #include #include @@ -124,13 +126,14 @@ struct ExecSessionConfig std::vector& FunctionList; // mutable for EmitFunctionListOnce std::string_view OrchestratorUrl; const std::filesystem::path& OutputPath; - int Offset = 0; - int Stride = 1; - int Limit = 0; - bool Verbose = false; - bool Quiet = false; - bool DumpActions = false; - bool Binary = false; + int Offset = 0; + int Stride = 1; + int Limit = 0; + bool Verbose = false; + bool Quiet = false; + bool DumpActions = false; + bool Binary = false; + ProgressBar::Mode ProgressMode = ProgressBar::Mode::PrettyScroll; }; ////////////////////////////////////////////////////////////////////////// @@ -345,8 +348,6 
@@ ExecSessionRunner::DrainCompletedJobs() } m_PendingJobs.Remove(CompleteLsn); - - ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize()); } } } @@ -897,17 +898,20 @@ ExecSessionRunner::Run() // Then submit work items - int FailedWorkCounter = 0; - size_t RemainingWorkItems = m_Config.RecordingReader.GetActionCount(); - int SubmittedWorkItems = 0; + std::atomic FailedWorkCounter{0}; + std::atomic RemainingWorkItems{m_Config.RecordingReader.GetActionCount()}; + std::atomic SubmittedWorkItems{0}; + size_t TotalWorkItems = RemainingWorkItems.load(); - ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + ProgressBar SubmitProgress(m_Config.ProgressMode, "Submit"); + SubmitProgress.UpdateState({.Task = "Submitting work items", .TotalCount = TotalWorkItems, .RemainingCount = RemainingWorkItems.load()}, + false); int OffsetCounter = m_Config.Offset; int StrideCounter = m_Config.Stride; auto ShouldSchedule = [&]() -> bool { - if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit) + if (m_Config.Limit && SubmittedWorkItems.load() >= m_Config.Limit) { // Limit reached, ignore @@ -1005,17 +1009,14 @@ ExecSessionRunner::Run() { const int32_t LsnField = EnqueueResult.Lsn; - --RemainingWorkItems; - ++SubmittedWorkItems; + size_t Remaining = --RemainingWorkItems; + int Submitted = ++SubmittedWorkItems; - if (!m_Config.Quiet) - { - ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. 
{} remaining", - SubmittedWorkItems, - LsnField, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - RemainingWorkItems); - } + SubmitProgress.UpdateState({.Task = "Submitting work items", + .Details = fmt::format("#{} LSN {}", Submitted, LsnField), + .TotalCount = TotalWorkItems, + .RemainingCount = Remaining}, + false); if (!m_Config.OutputPath.empty()) { @@ -1055,22 +1056,36 @@ ExecSessionRunner::Run() }, TargetParallelism); + SubmitProgress.Finish(); + // Wait until all pending work is complete + size_t TotalPendingJobs = m_PendingJobs.GetSize(); + + ProgressBar CompletionProgress(m_Config.ProgressMode, "Execute"); + while (!m_PendingJobs.IsEmpty()) { - // TODO: improve this logic - zen::Sleep(500); + size_t PendingCount = m_PendingJobs.GetSize(); + CompletionProgress.UpdateState({.Task = "Executing work items", + .Details = fmt::format("{} completed, {} remaining", TotalPendingJobs - PendingCount, PendingCount), + .TotalCount = TotalPendingJobs, + .RemainingCount = PendingCount}, + false); + + zen::Sleep(GetUpdateDelayMS(m_Config.ProgressMode)); DrainCompletedJobs(); SendOrchestratorHeartbeat(); } + CompletionProgress.Finish(); + // Write summary files WriteSummaryFiles(); - if (FailedWorkCounter) + if (FailedWorkCounter.load()) { return 1; } @@ -1423,6 +1438,16 @@ ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) int ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl) { + ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; + if (m_VerboseLogging) + { + ProgressMode = ProgressBar::Mode::Plain; + } + else if (m_QuietLogging) + { + ProgressMode = ProgressBar::Mode::Quiet; + } + ExecSessionConfig Config{ .Resolver = *m_ChunkResolver, .RecordingReader = *m_RecordingReader, @@ -1437,6 +1462,7 @@ ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std .Quiet = m_QuietLogging, .DumpActions = m_DumpActions, .Binary = m_Binary, + .ProgressMode = 
ProgressMode, }; ExecSessionRunner Runner(ComputeSession, Config); diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 3277eb856..bbf6b4f8a 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -9,6 +9,7 @@ #include "cmds/bench_cmd.h" #include "cmds/builds_cmd.h" #include "cmds/cache_cmd.h" +#include "cmds/compute_cmd.h" #include "cmds/copy_cmd.h" #include "cmds/dedup_cmd.h" #include "cmds/exec_cmd.h" @@ -588,7 +589,8 @@ main(int argc, char** argv) DropCommand DropCmd; DropProjectCommand ProjectDropCmd; #if ZEN_WITH_COMPUTE_SERVICES - ExecCommand ExecCmd; + ComputeCommand ComputeCmd; + ExecCommand ExecCmd; #endif // ZEN_WITH_COMPUTE_SERVICES ExportOplogCommand ExportOplogCmd; FlushCommand FlushCmd; @@ -649,6 +651,7 @@ main(int argc, char** argv) {DownCommand::Name, &DownCmd, DownCommand::Description}, {DropCommand::Name, &DropCmd, DropCommand::Description}, #if ZEN_WITH_COMPUTE_SERVICES + {ComputeCommand::Name, &ComputeCmd, ComputeCommand::Description}, {ExecCommand::Name, &ExecCmd, ExecCommand::Description}, #endif {GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description}, diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index 750879d5a..bb574edc2 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -218,6 +218,13 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han **Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock. +**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks: +1. `m_ActionMapLock` (session action maps) +2. 
`QueueEntry::m_Lock` (per-queue state) +3. `m_ActionHistoryLock` (action history ring) + +Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`). + **Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. **Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index aaf34cbe2..852e93fa0 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -121,6 +121,8 @@ struct ComputeServiceSession::Impl , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) { + m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool); + // Create a non-expiring, non-deletable implicit queue for legacy endpoints auto Result = CreateQueue("implicit"sv, {}, {}); m_ImplicitQueueId = Result.QueueId; @@ -240,8 +242,9 @@ struct ComputeServiceSession::Impl // Recording - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; std::unique_ptr m_Recorder; @@ -615,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() m_KnownWorkerUris.insert(UriStr); auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + NewRunner->SetRemoteHostname(Hostname); SyncWorkersToRunner(*NewRunner); m_RemoteRunnerGroup.AddRunner(NewRunner); } @@ -716,24 +720,44 @@ ComputeServiceSession::Impl::ShutdownRunners() m_RemoteRunnerGroup.Shutdown(); } -void +bool ComputeServiceSession::Impl::StartRecording(ChunkResolver& 
InCidStore, const std::filesystem::path& RecordingPath) { + if (m_Recorder) + { + ZEN_WARN("recording is already active"); + return false; + } + ZEN_INFO("starting recording to '{}'", RecordingPath); m_Recorder = std::make_unique(InCidStore, RecordingPath); ZEN_INFO("started recording to '{}'", RecordingPath); + return true; } -void +bool ComputeServiceSession::Impl::StopRecording() { + if (!m_Recorder) + { + ZEN_WARN("no recording is active"); + return false; + } + ZEN_INFO("stopping recording"); m_Recorder = nullptr; ZEN_INFO("stopped recording"); + return true; +} + +bool +ComputeServiceSession::Impl::IsRecording() const +{ + return m_Recorder != nullptr; } std::vector @@ -1128,6 +1152,10 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) Cbo.BeginObject(); Cbo << "lsn"sv << Lsn; Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + if (!Action->FailureReason.empty()) + { + Cbo << "reason"sv << Action->FailureReason; + } Cbo.EndObject(); } }); @@ -1416,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) if (Queue) { - Queue->m_Lock.WithSharedLock([&] { - m_ActionMapLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { + Queue->m_Lock.WithSharedLock([&] { for (int Lsn : Queue->FinishedLsns) { if (m_ResultsMap.contains(Lsn)) @@ -1530,12 +1558,12 @@ ComputeServiceSession::Impl::SchedulePendingActions() static Stopwatch DumpRunningTimer; auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); + ZEN_DEBUG("scheduled {} pending actions. 
{} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); if (DumpRunningTimer.GetElapsedTimeMs() > 30000) { @@ -1584,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions() // Also note that the m_PendingActions list is not maintained // here, that's done periodically in SchedulePendingActions() - m_ActionMapLock.WithExclusiveLock([&] { - if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) - { - return; - } + // Extract pending actions under a shared lock — we only need to read + // the map and take Ref copies. ActionState() is atomic so this is safe. + // Sorting and capacity trimming happen outside the lock to avoid + // blocking HTTP handlers on O(N log N) work with large pending queues. - if (m_PendingActions.empty()) + m_ActionMapLock.WithSharedLock([&] { + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) { return; } @@ -1610,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() case RunnerAction::State::Completed: case RunnerAction::State::Failed: case RunnerAction::State::Abandoned: + case RunnerAction::State::Rejected: case RunnerAction::State::Cancelled: break; @@ -1620,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - // Sort by priority descending, then by LSN ascending (FIFO within same priority) - std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref& A, const Ref& B) { - if (A->Priority != B->Priority) - { - return A->Priority > B->Priority; - } - return A->ActionLsn < B->ActionLsn; - }); + PendingCount = m_PendingActions.size(); + }); - if (ActionsToSchedule.size() > Capacity) + // Sort by priority descending, then by LSN ascending (FIFO within same priority) + std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref& A, const Ref& B) { + if (A->Priority != B->Priority) { - ActionsToSchedule.resize(Capacity); + return 
A->Priority > B->Priority; } - - PendingCount = m_PendingActions.size(); + return A->ActionLsn < B->ActionLsn; }); + if (ActionsToSchedule.size() > Capacity) + { + ActionsToSchedule.resize(Capacity); + } + if (ActionsToSchedule.empty()) { _.Dismiss(); return; } - ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size()); Stopwatch SubmitTimer; std::vector SubmitResults = SubmitActions(ActionsToSchedule); @@ -1663,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - ZEN_INFO("scheduled {} pending actions in {} ({} rejected)", - ScheduledActionCount, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - NotAcceptedCount); + ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)", + ScheduledActionCount, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); ScheduledCount += ScheduledActionCount; PendingCount -= ScheduledActionCount; @@ -1975,6 +2004,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() break; } + // Rejected — runner was at capacity, reschedule without retry cost + case RunnerAction::State::Rejected: + { + Action->ResetActionStateToPending(); + ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn); + break; + } + // Terminal states — move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: @@ -2009,6 +2046,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() MaxRetries); break; } + else + { + ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling", + Action->ActionId, + ActionLsn, + RunnerAction::ToString(TerminalState), + Action->RetryCount.load(std::memory_order_relaxed)); + } } m_ActionMapLock.WithExclusiveLock([&] { @@ -2101,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector>& ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); std::vector Results(Actions.size()); 
- // First try submitting the batch to local runners in parallel + // First try submitting the batch to local runners std::vector LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); - std::vector RemoteIndices; std::vector> RemoteActions; for (size_t i = 0; i < Actions.size(); ++i) @@ -2115,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector>& } else { - RemoteIndices.push_back(i); RemoteActions.push_back(Actions[i]); + Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"}; } } - // Submit remaining actions to remote runners in parallel + // Dispatch remaining actions to remote runners asynchronously. + // Mark actions as Submitting so the scheduler won't re-pick them. + // The remote runner will transition them to Running on success, or + // we mark them Failed on rejection so HandleActionUpdates retries. if (!RemoteActions.empty()) { - std::vector RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); - - for (size_t j = 0; j < RemoteIndices.size(); ++j) + for (const Ref& Action : RemoteActions) { - Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + Action->SetActionState(RunnerAction::State::Submitting); } + + m_RemoteSubmitPool.ScheduleWork( + [this, RemoteActions = std::move(RemoteActions)]() { + std::vector RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteResults.size(); ++j) + { + if (!RemoteResults[j].IsAccepted) + { + ZEN_DEBUG("remote submission rejected for action {} ({}): {}", + RemoteActions[j]->ActionId, + RemoteActions[j]->ActionLsn, + RemoteResults[j].Reason); + + RemoteActions[j]->SetActionState(RunnerAction::State::Rejected); + } + } + }, + WorkerThreadPool::EMode::EnableBacklog); } return Results; @@ -2194,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged() m_Impl->NotifyOrchestratorChanged(); } -void +bool ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& 
RecordingPath) { - m_Impl->StartRecording(InResolver, RecordingPath); + return m_Impl->StartRecording(InResolver, RecordingPath); } -void +bool ComputeServiceSession::StopRecording() { - m_Impl->StopRecording(); + return m_Impl->StopRecording(); +} + +bool +ComputeServiceSession::IsRecording() const +{ + return m_Impl->IsRecording(); } ComputeServiceSession::ActionCounts diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index 6cb975dd3..8cbb25afd 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -62,6 +62,8 @@ struct HttpComputeService::Impl RwLock m_WsConnectionsLock; std::vector> m_WsConnections; + std::function m_ShutdownCallback; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -189,6 +191,65 @@ HttpComputeService::Impl::RegisterRoutes() }, HttpVerb::kPost); + m_Router.RegisterRoute( + "session/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Draining from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "session/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + auto Counts = m_ComputeService.GetActionCounts(); + Cbo << "actions_pending"sv << Counts.Pending; + Cbo << "actions_running"sv << Counts.Running; + Cbo << "actions_completed"sv << Counts.Completed; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + 
"session/sunset", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + + if (m_ShutdownCallback) + { + m_ShutdownCallback(); + } + return; + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Sunset from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + m_Router.RegisterRoute( "workers", [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, @@ -506,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StartRecording(m_CombinedResolver, m_BaseDir / "recording"); + std::filesystem::path RecordingPath = m_BaseDir / "recording"; - return HttpReq.WriteResponse(HttpResponseCode::OK); + if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath)) + { + CbObjectWriter Cbo; + Cbo << "error" + << "recording is already active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -522,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StopRecording(); + std::filesystem::path RecordingPath = m_BaseDir / "recording"; + + if (!m_ComputeService.StopRecording()) + { + CbObjectWriter Cbo; + Cbo << "error" + << "no recording is active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } - return HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return 
HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -1066,6 +1147,12 @@ HttpComputeService::GetActionCounts() return m_Impl->m_ComputeService.GetActionCounts(); } +void +HttpComputeService::SetShutdownCallback(std::function Callback) +{ + m_Impl->m_ShutdownCallback = std::move(Callback); +} + const char* HttpComputeService::BaseUri() const { diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index d92af8716..1f51e560e 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -7,6 +7,7 @@ # include # include # include +# include # include # include @@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn return Ann.Id; } +static OrchestratorService::WorkerAnnotator +MakeWorkerAnnotator(IProvisionerStateProvider* Prov) +{ + if (!Prov) + { + return {}; + } + return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) { + AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId); + if (Status != AgentProvisioningStatus::Unknown) + { + const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? 
"draining" : "active"; + Cbo << "provisioner_status" << std::string_view(StatusStr); + } + }; +} + +bool +HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId) +{ + std::string_view SessionStr = Data["coordinator_session"].AsString(""); + if (SessionStr.empty()) + { + return true; // backwards compatibility: accept announcements without a session + } + Oid Session = Oid::TryFromHexString(SessionStr); + if (Session == m_SessionId) + { + return true; + } + ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString()); + return false; +} + HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) : m_Service(std::make_unique(std::move(DataDir), EnableWorkerWebSocket)) , m_Hostname(GetMachineName()) { + m_SessionId = zen::GetSessionId(); + ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString()); + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); @@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, [this](HttpRouterRequest& Req) { CbObjectWriter Cbo; Cbo << "hostname" << std::string_view(m_Hostname); + Cbo << "session_id" << m_SessionId.ToString(); Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kGet); m_Router.RegisterRoute( "provision", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kPost); m_Router.RegisterRoute( @@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path 
DataDir, "characters and uri must start with http:// or https://"); } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session"); + } + m_Service->AnnounceWorker(Ann); HttpReq.WriteResponse(HttpResponseCode::OK); @@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, m_Router.RegisterRoute( "agents", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kGet); m_Router.RegisterRoute( @@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, }, HttpVerb::kGet); + // Provisioner endpoints + + m_Router.RegisterRoute( + "provisioner/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + Cbo << "name" << Prov->GetName(); + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount(); + Cbo << "active_cores" << Prov->GetActiveCoreCount(); + Cbo << "agents" << Prov->GetAgentCount(); + Cbo << "agents_draining" << Prov->GetDrainingAgentCount(); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provisioner/target", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + int32_t Cores = Data["target_cores"].AsInt32(-1); + + ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? 
true : false); + + if (Cores < 0) + { + ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field"); + } + + IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire); + if (!Prov) + { + ZEN_WARN("provisioner/target: no provisioner configured"); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured"); + } + + ZEN_INFO("provisioner/target: setting target to {} cores", Cores); + Prov->SetTargetCoreCount(static_cast(Cores)); + + CbObjectWriter Cbo; + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + // Client tracking endpoints m_Router.RegisterRoute( @@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +void +HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); + m_Service->SetProvisionerStateProvider(Provider); +} + ////////////////////////////////////////////////////////////////////////// // // IWebSocketHandler @@ -488,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms return {}; } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return {}; + } + m_Service->AnnounceWorker(Ann); return std::string(WorkerId); } @@ -563,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction() } // Build combined JSON with worker list, provisioning history, clients, and client history - CbObject WorkerList = m_Service->GetWorkerList(); + CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))); CbObject History = m_Service->GetProvisioningHistory(50); CbObject ClientList = m_Service->GetClientList(); CbObject ClientHistory = 
m_Service->GetClientHistory(50); @@ -615,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction() JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); } + // Emit provisioner stats if available + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + JsonBuilder.Append( + fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}" + ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}", + Prov->GetName(), + Prov->GetTargetCoreCount(), + Prov->GetEstimatedCoreCount(), + Prov->GetActiveCoreCount(), + Prov->GetAgentCount(), + Prov->GetDrainingAgentCount())); + } + JsonBuilder.Append("}"); std::string_view Json = JsonBuilder.ToView(); diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index ad556f546..97de4321a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -279,7 +279,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[9] = {}; + uint64_t Timestamps[10] = {}; }; [[nodiscard]] std::vector GetActionHistory(int Limit = 100); @@ -305,8 +305,9 @@ public: // Recording - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; private: void PostUpdate(RunnerAction* Action); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index db3fce3c2..32f54f293 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -35,6 
+35,10 @@ public: void Shutdown(); + /** Set a callback to be invoked when the session/sunset endpoint is hit. + * Typically wired to HttpServer::RequestExit() to shut down the process. */ + void SetShutdownCallback(std::function Callback); + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); const char* BaseUri() const override; diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index 58b2c9152..ef0a1269a 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,10 +2,12 @@ #pragma once +#include #include #include #include +#include #include #include @@ -65,6 +67,16 @@ public: */ void Shutdown(); + /** Return the session ID generated at construction time. Provisioners + * pass this to spawned workers so the orchestrator can reject stale + * announcements from previous sessions. */ + Oid GetSessionId() const { return m_SessionId; } + + /** Register a provisioner whose target core count can be read and changed + * via the orchestrator HTTP API and dashboard. Caller retains ownership; + * the provider must outlive this service. 
*/ + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; @@ -81,6 +93,11 @@ private: std::unique_ptr m_Service; std::string m_Hostname; + Oid m_SessionId; + bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId); + + std::atomic m_Provisioner{nullptr}; + // WebSocket push #if ZEN_WITH_WEBSOCKETS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h index 549ff8e3c..2c49e22df 100644 --- a/src/zencompute/include/zencompute/orchestratorservice.h +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -6,6 +6,7 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include # include # include # include @@ -90,9 +91,16 @@ public: std::string Hostname; }; - CbObject GetWorkerList(); + /** Per-worker callback invoked during GetWorkerList serialization. + * The callback receives the worker ID and a CbObjectWriter positioned + * inside the worker's object, allowing the caller to append extra fields. 
*/ + using WorkerAnnotator = std::function; + + CbObject GetWorkerList(const WorkerAnnotator& Annotate = {}); void AnnounceWorker(const WorkerAnnouncement& Announcement); + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + bool IsWorkerWebSocketEnabled() const; void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); @@ -171,6 +179,8 @@ private: LoggerRef m_Log{"compute.orchestrator"}; bool m_EnableWorkerWebSocket = false; + std::atomic m_Provisioner{nullptr}; + std::thread m_ProbeThread; std::atomic m_ProbeThreadEnabled{true}; Event m_ProbeThreadEvent; diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h new file mode 100644 index 000000000..e9af8a635 --- /dev/null +++ b/src/zencompute/include/zencompute/provisionerstate.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +namespace zen::compute { + +/** Per-agent provisioning status as seen by the provisioner. */ +enum class AgentProvisioningStatus +{ + Unknown, ///< Not known to the provisioner + Active, ///< Running and allocated + Draining, ///< Being gracefully deprovisioned +}; + +/** Abstract interface for querying and controlling a provisioner from the HTTP layer. + * This decouples the orchestrator service from specific provisioner implementations. */ +class IProvisionerStateProvider +{ +public: + virtual ~IProvisionerStateProvider() = default; + + virtual std::string_view GetName() const = 0; ///< e.g. "horde", "nomad" + virtual uint32_t GetTargetCoreCount() const = 0; + virtual uint32_t GetEstimatedCoreCount() const = 0; + virtual uint32_t GetActiveCoreCount() const = 0; + virtual uint32_t GetAgentCount() const = 0; + virtual uint32_t GetDrainingAgentCount() const { return 0; } + virtual void SetTargetCoreCount(uint32_t Count) = 0; + + /** Return the provisioning status for a worker by its orchestrator ID + * (e.g. "horde-{LeaseId}"). 
Returns Unknown if the ID is not recognized. */ + virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; } +}; + +} // namespace zen::compute diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp index 9ea695305..aee8fa63a 100644 --- a/src/zencompute/orchestratorservice.cpp +++ b/src/zencompute/orchestratorservice.cpp @@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService() } CbObject -OrchestratorService::GetWorkerList() +OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate) { ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); CbObjectWriter Cbo; @@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList() Cbo << "ws_connected" << true; } Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + if (Annotate) + { + Annotate(WorkerId, Cbo); + } Cbo.EndObject(); } }); @@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) } } +void +OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); +} + bool OrchestratorService::IsWorkerWebSocketEnabled() const { @@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction() continue; } + // Check if the provisioner knows this worker is draining — if so, + // unreachability is expected and should not be logged as a warning. 
+ bool IsDraining = false; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining; + } + ReachableState NewState = ReachableState::Unreachable; try @@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction() } catch (const std::exception& Ex) { - ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + if (!IsDraining) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } } ReachableState PrevState = ReachableState::Unknown; @@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction() { ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); } + else if (IsDraining) + { + ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri); + } else if (PrevState == ReachableState::Reachable) { ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 67e12b84e..ab22c6363 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -6,9 +6,15 @@ # include # include +# include +# include +# include +# include # include +# include # include +# include # include namespace zen::compute { @@ -118,23 +124,34 @@ std::vector BaseRunnerGroup::SubmitActions(const std::vector>& Actions) { ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); - RwLock::SharedLockScope _(m_RunnersLock); - const int RunnerCount = gsl::narrow(m_Runners.size()); + // Snapshot runners and query capacity under the lock, then release + // before submitting — HTTP submissions to remote runners can take + // hundreds of milliseconds and we must not hold m_RunnersLock during I/O. 
- if (RunnerCount == 0) - { - return std::vector(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); - } + std::vector> Runners; + std::vector Capacities; + std::vector>> PerRunnerActions; + size_t TotalCapacity = 0; - // Query capacity per runner and compute total - std::vector Capacities(RunnerCount); - size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + const int RunnerCount = gsl::narrow(m_Runners.size()); + Runners.assign(m_Runners.begin(), m_Runners.end()); + Capacities.resize(RunnerCount); + PerRunnerActions.resize(RunnerCount); - for (int i = 0; i < RunnerCount; ++i) + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + }); + + const int RunnerCount = gsl::narrow(Runners.size()); + + if (RunnerCount == 0) { - Capacities[i] = m_Runners[i]->QueryCapacity(); - TotalCapacity += Capacities[i]; + return std::vector(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); } if (TotalCapacity == 0) @@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector>& Actions) } // Distribute actions across runners proportionally to their available capacity - std::vector>> PerRunnerActions(RunnerCount); - std::vector ActionRunnerIndex(Actions.size()); - size_t ActionIdx = 0; + std::vector ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; for (int i = 0; i < RunnerCount; ++i) { @@ -176,14 +192,74 @@ BaseRunnerGroup::SubmitActions(const std::vector>& Actions) } } - // Submit batches per runner + // Submit batches per runner — in parallel when a worker pool is available + std::vector> PerRunnerResults(RunnerCount); + int ActiveRunnerCount = 0; for (int i = 0; i < RunnerCount; ++i) { if (!PerRunnerActions[i].empty()) { - PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + ++ActiveRunnerCount; + } + } + + static constexpr uint64_t SubmitWarnThresholdMs = 500; + + auto SubmitToRunner = 
[&](int RunnerIndex) { + auto& Runner = Runners[RunnerIndex]; + Runner->m_LastSubmitStats.Reset(); + + Stopwatch Timer; + + PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]); + + uint64_t ElapsedMs = Timer.GetElapsedTimeMs(); + if (ElapsedMs >= SubmitWarnThresholdMs) + { + size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed); + uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed); + + ZEN_WARN("submit of {} actions ({} attachments, {}) to '{}' took {}ms", + PerRunnerActions[RunnerIndex].size(), + Attachments, + NiceBytes(AttachmentBytes), + Runner->GetDisplayName(), + ElapsedMs); + } + }; + + if (m_WorkerPool && ActiveRunnerCount > 1) + { + std::vector> Futures(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + std::packaged_task Task([&SubmitToRunner, i]() { SubmitToRunner(i); }); + + Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog); + } + } + + for (int i = 0; i < RunnerCount; ++i) + { + if (Futures[i].valid()) + { + Futures[i].get(); + } + } + } + else + { + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + SubmitToRunner(i); + } } } @@ -309,10 +385,11 @@ RunnerAction::RetractAction() bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed, Abandoned, or Retracted states + // Only allow reset from Failed, Abandoned, Rejected, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected && + CurrentState != State::Retracted) { return false; } @@ -333,11 +410,12 @@ RunnerAction::ResetActionStateToPending() // Clear execution fields ExecutionLocation.clear(); + 
FailureReason.clear(); CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count (skip for Retracted — nothing failed) - if (CurrentState != State::Retracted) + // Increment retry count (skip for Retracted/Rejected — nothing failed) + if (CurrentState != State::Retracted && CurrentState != State::Rejected) { RetryCount.fetch_add(1, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index 56c3f3af0..449f0e228 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -10,6 +10,10 @@ # include # include +namespace zen { +class WorkerThreadPool; +} + namespace zen::compute { struct SubmitResult @@ -37,6 +41,22 @@ public: [[nodiscard]] virtual bool IsHealthy() = 0; [[nodiscard]] virtual size_t QueryCapacity(); [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions); + [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; } + + // Accumulated stats from the most recent SubmitActions call. + // Reset before each call, populated by the runner implementation. + struct SubmitStats + { + std::atomic TotalAttachments{0}; + std::atomic TotalAttachmentBytes{0}; + + void Reset() + { + TotalAttachments.store(0, std::memory_order_relaxed); + TotalAttachmentBytes.store(0, std::memory_order_relaxed); + } + }; + SubmitStats m_LastSubmitStats; // Best-effort cancellation of a specific in-flight action. Returns true if the // cancellation signal was successfully sent. 
The action will transition to Cancelled @@ -68,6 +88,8 @@ public: bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); + void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; } + size_t GetRunnerCount() { return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); @@ -79,6 +101,7 @@ protected: RwLock m_RunnersLock; std::vector> m_Runners; std::atomic m_NextSubmitIndex{0}; + WorkerThreadPool* m_WorkerPool = nullptr; }; /** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. @@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted CbObject ActionObj; int Priority = 0; std::string ExecutionLocation; // "local" or remote hostname + std::string FailureReason; // human-readable reason when action fails (empty on success) // CPU usage and total CPU time of the running process, sampled periodically by the local runner. // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. @@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted Completed, // Finished successfully with results available Failed, // Execution failed (transient error, eligible for retry) Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon) + Rejected, // Runner declined (e.g. 
at capacity) — rescheduled without retry cost Cancelled, // Intentional user cancellation (never retried) Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count @@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted return "Failed"; case State::Abandoned: return "Abandoned"; + case State::Rejected: + return "Rejected"; case State::Cancelled: return "Cancelled"; case State::Retracted: diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp index 9055005d9..ce5bbdcc8 100644 --- a/src/zencompute/runners/linuxrunner.cpp +++ b/src/zencompute/runners/linuxrunner.cpp @@ -430,7 +430,8 @@ LinuxProcessRunner::SubmitAction(Ref Action) if (ChildPid == 0) { - // Child process + // Child process — lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { @@ -481,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 1b748c0e5..96cbdc134 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); - if (!std::filesystem::exists(WorkerDir)) + // worker.zcb is written as the last step of ManifestWorker, so its presence + // indicates a complete manifest. 
If the directory exists but the marker is + // missing, a previous manifest was interrupted and we need to start over. + bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb"); + + if (NeedsManifest) { _.ReleaseNow(); RwLock::ExclusiveLockScope $(m_WorkerLock); - if (!std::filesystem::exists(WorkerDir)) + if (!std::filesystem::exists(WorkerDir / "worker.zcb")) { + std::error_code Ec; + std::filesystem::remove_all(WorkerDir, Ec); ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); } } @@ -673,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector>& Com } catch (std::exception& Ex) { - ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what()); + ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); } } + else + { + Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode); + ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); + } // Failed - clean up the sandbox in the background. 
diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp index c2ccca9a6..13c01d988 100644 --- a/src/zencompute/runners/macrunner.cpp +++ b/src/zencompute/runners/macrunner.cpp @@ -211,7 +211,8 @@ MacProcessRunner::SubmitAction(Ref Action) if (ChildPid == 0) { - // Child process + // Child process — lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { @@ -281,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp index e4a7ba388..a4f586852 100644 --- a/src/zencompute/runners/managedrunner.cpp +++ b/src/zencompute/runners/managedrunner.cpp @@ -128,7 +128,7 @@ ManagedProcessRunner::SubmitAction(Ref Action) CreateProcOptions Options; Options.WorkingDirectory = &Prepared->SandboxPath; - Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority; Options.Environment = std::move(EnvPairs); const int32_t ActionLsn = Prepared->ActionLsn; diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index ce6a81173..55f78fdd6 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -20,6 +20,7 @@ # include # include +# include ////////////////////////////////////////////////////////////////////////// @@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& 
InChunkResolver, , m_ChunkResolver{InChunkResolver} , m_WorkerPool{InWorkerPool} , m_HostName{HostName} +, m_DisplayName{HostName} , m_BaseUrl{fmt::format("{}/compute", HostName)} , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) @@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } +void +RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname) +{ + if (!Hostname.empty()) + { + m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname); + } +} + RemoteHttpRunner::~RemoteHttpRunner() { Shutdown(); @@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown() for (auto& [RemoteLsn, HttpAction] : Remaining) { ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->FailureReason = "remote runner shutdown"; HttpAction.Action->SetActionState(RunnerAction::State::Failed); } } @@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity() return 0; } - // Estimate how much more work we're ready to accept + // Estimate how much more work we're ready to accept. + // Include actions currently being submitted over HTTP so we don't + // keep queueing new submissions while previous ones are still in flight. 
RwLock::SharedLockScope _{m_RunningLock}; - size_t RunningCount = m_RemoteRunningMap.size(); + size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed); if (RunningCount >= size_t(m_MaxRunningActions)) { @@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector>& Actions) { ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed); + auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); }); + if (Actions.size() <= 1) { std::vector Results; @@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref Action) } } - // Enqueue job. If the remote returns FailedDependency (424), it means it - // cannot resolve the worker/function — re-register the worker and retry once. + // Submit the action to the remote. In eager-attach mode we build a + // CbPackage with all referenced attachments upfront to avoid the 404 + // round-trip. In the default mode we POST the bare object first and + // only upload missing attachments if the remote requests them. + // + // In both modes, FailedDependency (424) triggers a worker re-register + // and a single retry. 
CbObject Result; HttpClient::Response WorkResponse; HttpResponseCode WorkResponseCode{}; - for (int Attempt = 0; Attempt < 2; ++Attempt) - { - WorkResponse = m_Http.Post(SubmitUrl, ActionObj); - WorkResponseCode = WorkResponse.StatusCode; - - if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) - { - ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", - m_Http.GetBaseUri(), - ActionId); - - (void)RegisterWorker(Action->Worker.Descriptor); - } - else - { - break; - } - } - - if (WorkResponseCode == HttpResponseCode::OK) - { - Result = WorkResponse.AsObject(); - } - else if (WorkResponseCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Not all attachments are present - - // Build response package including all required attachments - CbPackage Pkg; Pkg.SetObject(ActionObj); - CbObject Response = WorkResponse.AsObject(); + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash AttachHash = Field.AsHash(); - for (auto& Item : Response["need"sv]) - { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + }); + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, Pkg); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — 
re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + (void)RegisterWorker(Action->Worker.Descriptor); } else { - // No such attachment - - return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + break; } } + } + else + { + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; - // Post resulting package + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + (void)RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } - if (!PayloadResponse) + if (WorkResponseCode == HttpResponseCode::NotFound) { - ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + // Remote needs attachments — resolve them and retry with a CbPackage - // TODO: include more information about the failure in the response + CbPackage Pkg; + Pkg.SetObject(ActionObj); - return {.IsAccepted = false, .Reason = "HTTP request failed"}; - } - else if (PayloadResponse.StatusCode == HttpResponseCode::OK) - { - Result = PayloadResponse.AsObject(); - } - else - { - // Unexpected response - - const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - - ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", - ActionId, - m_Http.GetBaseUri(), - SubmitUrl, - ResponseStatusCode, - ToString(ResponseStatusCode)); - - return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}{}", - ResponseStatusCode, - ToString(ResponseStatusCode), - m_Http.GetBaseUri(), - SubmitUrl)}; + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : 
Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + + WorkResponse = std::move(PayloadResponse); + WorkResponseCode = WorkResponse.StatusCode; } } + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (!WorkResponse) + { + ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (!IsHttpSuccessCode(WorkResponseCode)) + { + const int Code = static_cast(WorkResponseCode); + ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code)); + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)}; + } + if (Result) { if (const int32_t LsnField = Result["lsn"].AsInt32(0)) @@ -512,82 +562,110 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec CbObjectWriter Body; Body.BeginArray("actions"sv); + std::unordered_set 
AttachmentsSeen; + for (const Ref& Action : Actions) { Action->ExecutionLocation = m_HostName; MaybeDumpAction(Action->ActionLsn, Action->ActionObj); Body.AddObject(Action->ActionObj); + + if (m_EagerAttach) + { + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); }); + } } Body.EndArray(); - // POST the batch - - HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); - - if (Response.StatusCode == HttpResponseCode::OK) - { - return ParseBatchResponse(Response, Actions); - } + // In eager-attach mode, build a CbPackage with all referenced attachments + // so the remote can accept in a single round-trip. Otherwise POST a bare + // CbObject and handle the 404 need-list flow. - if (Response.StatusCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Server needs attachments — resolve them and retry with a CbPackage - - CbObject NeedObj = Response.AsObject(); - CbPackage Pkg; Pkg.SetObject(Body.Save()); - for (auto& Item : NeedObj["need"sv]) + for (const IoHash& AttachHash : AttachmentsSeen) { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); - - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); - } - else - { - ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); - return FallbackToIndividualSubmit(Actions); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); } } - HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + 
HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + } + else + { + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); - if (RetryResponse.StatusCode == HttpResponseCode::OK) + if (Response.StatusCode == HttpResponseCode::OK) { - return ParseBatchResponse(RetryResponse, Actions); + return ParseBatchResponse(Response, Actions); } - ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", - (int)RetryResponse.StatusCode, - ToString(RetryResponse.StatusCode)); - return FallbackToIndividualSubmit(Actions); + if (Response.StatusCode == HttpResponseCode::NotFound) + { + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return 
FallbackToIndividualSubmit(Actions); + } } // Unexpected status or connection error — fall back to individual submission - if (Response) - { - ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", - m_Http.GetBaseUri(), - SubmitUrl, - (int)Response.StatusCode, - ToString(Response.StatusCode)); - } - else - { - ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); - } + ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); return FallbackToIndividualSubmit(Actions); } @@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions() { for (auto& FieldIt : Completed["completed"sv]) { - CbObjectView EntryObj = FieldIt.AsObjectView(); - const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); - std::string_view StateName = EntryObj["state"sv].AsString(); + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); + std::string_view FailureReason = EntryObj["reason"sv].AsString(); RunnerAction::State RemoteState = RunnerAction::FromString(StateName); @@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions() { HttpRunningAction CompletedAction = std::move(CompleteIt->second); CompletedAction.RemoteState = RemoteState; + CompletedAction.FailureReason = std::string(FailureReason); if (RemoteState == RunnerAction::State::Completed && ResponseJob) { @@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions() { const int ActionLsn = HttpAction.Action->ActionLsn; - ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", - HttpAction.Action->ActionId, - ActionLsn, - HttpAction.RemoteActionLsn, - RunnerAction::ToString(HttpAction.RemoteState)); - if (HttpAction.RemoteState == RunnerAction::State::Completed) { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}", + HttpAction.Action->ActionId, + ActionLsn, + 
HttpAction.RemoteActionLsn, + m_HostName); HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); } + else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned) + { + HttpAction.Action->FailureReason = HttpAction.FailureReason; + if (HttpAction.FailureReason.empty()) + { + ZEN_WARN("action {} ({}) {} on remote {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName); + } + else + { + ZEN_WARN("action {} ({}) {} on remote {}: {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName, + HttpAction.FailureReason); + } + } + else + { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); + } HttpAction.Action->SetActionState(HttpAction.RemoteState); } diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index c17d0cf2a..fdf113c77 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -54,8 +54,10 @@ public: [[nodiscard]] virtual size_t QueryCapacity() override; [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions) override; virtual void CancelRemoteQueue(int QueueId) override; + [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; } std::string_view GetHostName() const { return m_HostName; } + void SetRemoteHostname(std::string_view Hostname); protected: LoggerRef Log() { return m_Log; } @@ -65,12 +67,15 @@ private: ChunkResolver& m_ChunkResolver; WorkerThreadPool& m_WorkerPool; std::string m_HostName; + std::string m_DisplayName; std::string m_BaseUrl; HttpClient m_Http; - std::atomic m_AcceptNewActions{true}; - int32_t m_MaxRunningActions = 256; // arbitrary limit for testing - int32_t m_MaxBatchSize = 
50; + std::atomic m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; + bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry + std::atomic m_InFlightSubmissions{0}; // actions currently being submitted over HTTP struct HttpRunningAction { @@ -78,6 +83,7 @@ private: int RemoteActionLsn = 0; // Remote LSN RunnerAction::State RemoteState = RunnerAction::State::Failed; CbPackage ActionResults; + std::string FailureReason; }; RwLock m_RunningLock; diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp index 92ee65c2d..e643c9ce8 100644 --- a/src/zencompute/runners/windowsrunner.cpp +++ b/src/zencompute/runners/windowsrunner.cpp @@ -48,7 +48,9 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, if (m_JobObject) { JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{}; - ExtLimits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; + ExtLimits.BasicLimitInformation.LimitFlags = + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS; + ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS; SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits)); JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp index b4fafb467..593b19e55 100644 --- a/src/zencompute/runners/winerunner.cpp +++ b/src/zencompute/runners/winerunner.cpp @@ -96,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref Action) if (ChildPid == 0) { - // Child process + // Child process — lower priority so workers don't starve the main server + nice(5); + if (chdir(SandboxPathStr.c_str()) != 0) { _exit(127); diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md new 
file mode 100644 index 000000000..13beaa968 --- /dev/null +++ b/src/zenhorde/README.md @@ -0,0 +1,17 @@ +# Horde Compute integration + +Zen compute can use Horde to provision runner nodes. + +## Launch a coordinator instance + +Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface which allows zenserver instances to discover endpoints which they can submit actions to. + +```bash +zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio +``` + +## Use a coordinator + +```bash +zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558 +``` diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp index 819b2d0cb..275f5bd4c 100644 --- a/src/zenhorde/hordeagent.cpp +++ b/src/zenhorde/hordeagent.cpp @@ -8,290 +8,457 @@ #include #include +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + #include -#include namespace zen::horde { -HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) -{ - ZEN_TRACE_CPU("HordeAgent::Connect"); +// --- AsyncHordeAgent --- - auto Transport = std::make_unique(Info); - if (!Transport->IsValid()) +static const char* +GetStateName(AsyncHordeAgent::State S) +{ + switch (S) { - ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); - return; + case AsyncHordeAgent::State::Idle: + return "idle"; + case AsyncHordeAgent::State::Connecting: + return "connect"; + case AsyncHordeAgent::State::WaitAgentAttach: + return "agent-attach"; + case AsyncHordeAgent::State::SentFork: + return "fork"; + case AsyncHordeAgent::State::WaitChildAttach: + return "child-attach"; + case AsyncHordeAgent::State::Uploading: + return "upload"; + case AsyncHordeAgent::State::Executing: + return "execute"; + case AsyncHordeAgent::State::Polling: + return "poll"; + 
case AsyncHordeAgent::State::Done: + return "done"; + default: + return "unknown"; } +} - // The 64-byte nonce is always sent unencrypted as the first thing on the wire. - // The Horde agent uses this to identify which lease this connection belongs to. - Transport->Send(Info.Nonce, sizeof(Info.Nonce)); +AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async")) +{ +} - std::unique_ptr FinalTransport = std::move(Transport); - if (Info.EncryptionMode == Encryption::AES) +AsyncHordeAgent::~AsyncHordeAgent() +{ + Cancel(); +} + +void +AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone) +{ + m_Config = std::move(Config); + m_OnDone = std::move(OnDone); + m_State = State::Connecting; + DoConnect(); +} + +void +AsyncHordeAgent::Cancel() +{ + m_Cancelled = true; + if (m_Socket) { - FinalTransport = std::make_unique(Info.Key, std::move(FinalTransport)); - if (!FinalTransport->IsValid()) - { - ZEN_WARN("failed to create AES transport"); - return; - } + m_Socket->Close(); + } + else if (m_Transport) + { + m_Transport->Close(); } +} - // Create multiplexed socket and channels - m_Socket = std::make_unique(std::move(FinalTransport)); +void +AsyncHordeAgent::DoConnect() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect"); - // Channel 0 is the agent control channel (handles Attach/Fork handshake). - // Channel 100 is the child I/O channel (handles file upload and remote execution). 
- Ref AgentComputeChannel = m_Socket->CreateChannel(0); - Ref ChildComputeChannel = m_Socket->CreateChannel(100); + m_TcpTransport = std::make_unique(m_IoContext); + + auto Self = shared_from_this(); + m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); }); +} - if (!AgentComputeChannel || !ChildComputeChannel) +void +AsyncHordeAgent::OnConnected(const std::error_code& Ec) +{ + if (Ec || m_Cancelled) { - ZEN_WARN("failed to create compute channels"); + if (Ec) + { + ZEN_WARN("connect failed: {}", Ec.message()); + } + Finish(false); return; } - m_AgentChannel = std::make_unique(std::move(AgentComputeChannel)); - m_ChildChannel = std::make_unique(std::move(ChildComputeChannel)); + // Optionally wrap with AES encryption + std::unique_ptr FinalTransport = std::move(m_TcpTransport); + if (m_Config.Machine.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext); + } + m_Transport = std::move(FinalTransport); + + // Create the multiplexed socket and register channels + m_Socket = std::make_shared(std::move(m_Transport), m_IoContext); + + m_AgentChannel = std::make_unique(m_Socket, 0, m_IoContext); + m_ChildChannel = std::make_unique(m_Socket, 100, m_IoContext); + + m_Socket->RegisterChannel( + 0, + [this](std::vector Data) { m_AgentChannel->OnFrame(std::move(Data)); }, + [this]() { m_AgentChannel->OnDetach(); }); - m_IsValid = true; + m_Socket->RegisterChannel( + 100, + [this](std::vector Data) { m_ChildChannel->OnFrame(std::move(Data)); }, + [this]() { m_ChildChannel->OnDetach(); }); + + m_Socket->StartRecvPump(); + + m_State = State::WaitAgentAttach; + DoWaitAgentAttach(); } -HordeAgent::~HordeAgent() +void +AsyncHordeAgent::DoWaitAgentAttach() { - CloseConnection(); + auto Self = shared_from_this(); + m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnAgentResponse(Type, 
Data, Size); + }); } -bool -HordeAgent::BeginCommunication() +void +AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) { - ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); - - if (!m_IsValid) + if (m_Cancelled) { - return false; + Finish(false); + return; } - // Start the send/recv pump threads - m_Socket->StartCommunication(); - - // Wait for Attach on agent channel - AgentMessageType Type = m_AgentChannel->ReadResponse(5000); if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on agent channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast(Type)); - return false; + Finish(false); + return; } - // Fork tells the remote agent to create child channel 100 with a 4MB buffer. - // After this, the agent will send an Attach on the child channel. + m_State = State::SentFork; + DoSendFork(); +} + +void +AsyncHordeAgent::DoSendFork() +{ m_AgentChannel->Fork(100, 4 * 1024 * 1024); - // Wait for Attach on child channel - Type = m_ChildChannel->ReadResponse(5000); + m_State = State::WaitChildAttach; + DoWaitChildAttach(); +} + +void +AsyncHordeAgent::DoWaitChildAttach() +{ + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnChildAttachResponse(Type, Data, Size); + }); +} + +void +AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) +{ + if (m_Cancelled) + { + Finish(false); + return; + } + if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on child channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast(Type)); - return false; + Finish(false); + return; } - return true; + m_State = State::Uploading; + 
m_CurrentBundleIndex = 0; + DoUploadNext(); } -bool -HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +void +AsyncHordeAgent::DoUploadNext() { - ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); - - m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + if (m_Cancelled) + { + Finish(false); + return; + } - std::unordered_map> BlobFiles; + if (m_CurrentBundleIndex >= m_Config.Bundles.size()) + { + // All bundles uploaded — proceed to execute + m_State = State::Executing; + DoExecute(); + return; + } - auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { - std::string Key(Locator); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + m_ChildChannel->UploadFiles("", Locator.c_str()); - if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) - { - return It->second.get(); - } + // Enter the ReadBlob/Blob upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const std::filesystem::path Path = BundleDir / (Key + ".blob"); - std::error_code Ec; - auto File = std::make_unique(); - File->Open(Path, BasicFile::Mode::kRead, Ec); +void +AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) +{ + if (m_Cancelled) + { + Finish(false); + return; + } - if (Ec) + if (Type == AgentMessageType::None) + { + if (m_ChildChannel->IsDetached()) { - ZEN_ERROR("cannot read blob file: '{}'", Path); - return nullptr; + ZEN_WARN("connection lost during upload"); + Finish(false); + return; } + // Timeout — retry read + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); + return; + } - BasicFile* Ptr = File.get(); - BlobFiles.emplace(std::move(Key), std::move(File)); - return Ptr; 
- }; + if (Type == AgentMessageType::WriteFilesResponse) + { + // This bundle upload is done — move to next + ++m_CurrentBundleIndex; + DoUploadNext(); + return; + } - // The upload protocol is request-driven: we send WriteFiles, then the remote agent - // sends ReadBlob requests for each blob it needs. We respond with Blob data until - // the agent sends WriteFilesResponse indicating the upload is complete. - constexpr int32_t ReadResponseTimeoutMs = 1000; + if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + AsyncAgentMessageChannel::ReadException(Data, Size, Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + return; + } - for (;;) + if (Type != AgentMessageType::ReadBlob) { - bool TimedOut = false; + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast(Type)); + Finish(false); + return; + } - if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) - { - if (TimedOut) - { - continue; - } - // End of stream - check if it was a successful upload - if (Type == AgentMessageType::WriteFilesResponse) - { - return true; - } - else if (Type == AgentMessageType::Exception) - { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); - } - else - { - ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast(Type)); - } - return false; - } + // Handle ReadBlob request + BlobRequest Req; + AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req); - BlobRequest Req; - m_ChildChannel->ReadBlobRequest(Req); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob"); - BasicFile* File = FindOrOpenBlob(Req.Locator); - if (!File) - { - return false; - } + std::error_code FsEc; + BasicFile File; + File.Open(BlobPath, BasicFile::Mode::kRead, 
FsEc); - // Read from offset to end of file - const uint64_t TotalSize = File->FileSize(); - const uint64_t Offset = static_cast(Req.Offset); - if (Offset >= TotalSize) - { - ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); - m_ChildChannel->Blob(nullptr, 0); - continue; - } + if (FsEc) + { + ZEN_ERROR("cannot read blob file: '{}'", BlobPath); + Finish(false); + return; + } + + const uint64_t TotalSize = File.FileSize(); + const uint64_t Offset = static_cast(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + } + else + { + const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast(FileData.GetData()), FileData.GetSize()); + } + + // Continue the upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); - m_ChildChannel->Blob(static_cast(Data.GetData()), Data.GetSize()); +void +AsyncHordeAgent::DoExecute() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute"); + + std::vector ArgPtrs; + ArgPtrs.reserve(m_Config.Args.size()); + for (const std::string& Arg : m_Config.Args) + { + ArgPtrs.push_back(Arg.c_str()); } + + m_ChildChannel->Execute(m_Config.Executable.c_str(), + ArgPtrs.data(), + ArgPtrs.size(), + nullptr, + nullptr, + 0, + m_Config.UseWine ? 
ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + m_Config.Machine.GetConnectionAddress(), + m_Config.Machine.GetConnectionPort(), + m_Config.Machine.LeaseId); + + m_State = State::Polling; + DoPoll(); } void -HordeAgent::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - bool UseWine) +AsyncHordeAgent::DoPoll() { - ZEN_TRACE_CPU("HordeAgent::Execute"); - m_ChildChannel - ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + if (m_Cancelled) + { + Finish(false); + return; + } + + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnPollResponse(Type, Data, Size); + }); } -bool -HordeAgent::Poll(bool LogOutput) +void +AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) { - constexpr int32_t ReadResponseTimeoutMs = 100; - AgentMessageType Type; + if (m_Cancelled) + { + Finish(false); + return; + } - while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + switch (Type) { - switch (Type) - { - case AgentMessageType::ExecuteOutput: - { - if (LogOutput && m_ChildChannel->GetResponseSize() > 0) - { - const char* ResponseData = static_cast(m_ChildChannel->GetResponseData()); - size_t ResponseSize = m_ChildChannel->GetResponseSize(); - - // Trim trailing newlines - while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) - { - --ResponseSize; - } - - if (ResponseSize > 0) - { - const std::string_view Output(ResponseData, ResponseSize); - ZEN_INFO("[remote] {}", Output); - } - } - break; - } + case AgentMessageType::None: + if (m_ChildChannel->IsDetached()) + { + ZEN_WARN("connection lost during 
execution"); + Finish(false); + } + else + { + // Timeout — poll again + DoPoll(); + } + break; - case AgentMessageType::ExecuteResult: - { - if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) - { - int32_t ExitCode; - memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); - ZEN_INFO("remote process exited with code {}", ExitCode); - } - m_IsValid = false; - return false; - } + case AgentMessageType::ExecuteOutput: + // Silently consume remote stdout (matching LogOutput=false in provisioner) + DoPoll(); + break; - case AgentMessageType::Exception: + case AgentMessageType::ExecuteResult: + { + int32_t ExitCode = -1; + if (Size == sizeof(int32_t)) { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); - m_HasErrors = true; - break; + memcpy(&ExitCode, Data, sizeof(int32_t)); } + ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId); + Finish(ExitCode == 0, ExitCode); + } + break; - default: - break; - } - } + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + AsyncAgentMessageChannel::ReadException(Data, Size, Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + } + break; - return m_IsValid && !m_HasErrors; + default: + DoPoll(); + break; + } } void -HordeAgent::CloseConnection() +AsyncHordeAgent::Finish(bool Success, int32_t ExitCode) { - if (m_ChildChannel) + if (m_State == State::Done) { - m_ChildChannel->Close(); + return; // Already finished } - if (m_AgentChannel) + + if (!Success) { - m_AgentChannel->Close(); + ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId); } -} -bool -HordeAgent::IsValid() const -{ - return m_IsValid && !m_HasErrors; + m_State = State::Done; + + if (m_Socket) + { + m_Socket->Close(); + } + + if (m_OnDone) + { + AsyncAgentResult Result; + Result.Success = Success; + Result.ExitCode = ExitCode; + Result.CoreCount = 
m_Config.Machine.LogicalCores; + + auto Handler = std::move(m_OnDone); + m_OnDone = nullptr; + Handler(Result); + } } } // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h index e0ae89ead..b581a8da1 100644 --- a/src/zenhorde/hordeagent.h +++ b/src/zenhorde/hordeagent.h @@ -10,68 +10,108 @@ #include #include +#include #include #include +#include + +namespace asio { +class io_context; +} namespace zen::horde { -/** Manages the lifecycle of a single Horde compute agent. +class AsyncComputeTransport; + +/** Result passed to the completion handler when an async agent finishes. */ +struct AsyncAgentResult +{ + bool Success = false; + int32_t ExitCode = -1; + uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine +}; + +/** Completion handler for async agent lifecycle. */ +using AsyncAgentCompletionHandler = std::function; + +/** Configuration for launching a remote zenserver instance via an async agent. */ +struct AsyncAgentConfig +{ + MachineInfo Machine; + std::vector> Bundles; ///< (locator, bundleDir) pairs + std::string Executable; + std::vector Args; + bool UseWine = false; +}; + +/** Async agent that manages the full lifecycle of a single Horde compute connection. * - * Handles the full connection sequence for one provisioned machine: - * 1. Connect via TCP transport (with optional AES encryption wrapping) - * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) - * 3. Perform the Attach/Fork handshake to establish the child channel - * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol - * 5. Execute zenserver remotely via ExecuteV2 - * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + * Driven by a state machine using callbacks on a shared io_context — no dedicated + * threads. Call Start() to begin the connection/handshake/upload/execute/poll + * sequence. The completion handler is invoked when the remote process exits or + * an error occurs. 
*/ -class HordeAgent +class AsyncHordeAgent : public std::enable_shared_from_this { public: - explicit HordeAgent(const MachineInfo& Info); - ~HordeAgent(); + AsyncHordeAgent(asio::io_context& IoContext); + ~AsyncHordeAgent(); - HordeAgent(const HordeAgent&) = delete; - HordeAgent& operator=(const HordeAgent&) = delete; + AsyncHordeAgent(const AsyncHordeAgent&) = delete; + AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete; - /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). - * Returns false if the handshake times out or receives an unexpected message. */ - bool BeginCommunication(); + /** Start the full agent lifecycle. The completion handler is called exactly once. */ + void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone); - /** Upload binary files to the remote agent. - * @param BundleDir Directory containing .blob files. - * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ - bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + /** Cancel in-flight operations. The completion handler is still called (with Success=false). */ + void Cancel(); - /** Execute a command on the remote machine. */ - void Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir = nullptr, - const char* const* EnvVars = nullptr, - size_t NumEnvVars = 0, - bool UseWine = false); + const MachineInfo& GetMachineInfo() const { return m_Config.Machine; } - /** Poll for output and results. Returns true if the agent is still running. - * When LogOutput is true, remote stdout is logged via ZEN_INFO. 
*/ - bool Poll(bool LogOutput = true); - - void CloseConnection(); - bool IsValid() const; - - const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + enum class State + { + Idle, + Connecting, + WaitAgentAttach, + SentFork, + WaitChildAttach, + Uploading, + Executing, + Polling, + Done + }; private: LoggerRef Log() { return m_Log; } - std::unique_ptr m_Socket; - std::unique_ptr m_AgentChannel; ///< Channel 0: agent control - std::unique_ptr m_ChildChannel; ///< Channel 100: child I/O - - LoggerRef m_Log; - bool m_IsValid = false; - bool m_HasErrors = false; - MachineInfo m_MachineInfo; + void DoConnect(); + void OnConnected(const std::error_code& Ec); + void DoWaitAgentAttach(); + void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoSendFork(); + void DoWaitChildAttach(); + void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoUploadNext(); + void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoExecute(); + void DoPoll(); + void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void Finish(bool Success, int32_t ExitCode = -1); + + asio::io_context& m_IoContext; + LoggerRef m_Log; + State m_State = State::Idle; + bool m_Cancelled = false; + + AsyncAgentConfig m_Config; + AsyncAgentCompletionHandler m_OnDone; + size_t m_CurrentBundleIndex = 0; + + std::unique_ptr m_TcpTransport; + std::unique_ptr m_Transport; + std::shared_ptr m_Socket; + std::unique_ptr m_AgentChannel; + std::unique_ptr m_ChildChannel; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp index 998134a96..31498972f 100644 --- a/src/zenhorde/hordeagentmessage.cpp +++ b/src/zenhorde/hordeagentmessage.cpp @@ -4,337 +4,403 @@ #include +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + #include #include namespace zen::horde { -AgentMessageChannel::AgentMessageChannel(Ref 
Channel) : m_Channel(std::move(Channel)) +// --- AsyncAgentMessageChannel --- + +AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr Socket, int ChannelId, asio::io_context& IoContext) +: m_Socket(std::move(Socket)) +, m_ChannelId(ChannelId) +, m_IoContext(IoContext) +, m_TimeoutTimer(std::make_unique(IoContext)) { } -AgentMessageChannel::~AgentMessageChannel() = default; - -void -AgentMessageChannel::Close() +AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { - CreateMessage(AgentMessageType::None, 0); - FlushMessage(); + if (m_TimeoutTimer) + { + m_TimeoutTimer->cancel(); + } } -void -AgentMessageChannel::Ping() +// --- Message building helpers --- + +std::vector +AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { - CreateMessage(AgentMessageType::Ping, 0); - FlushMessage(); + std::vector Buf; + Buf.reserve(MessageHeaderLength + ReservePayload); + Buf.push_back(static_cast(Type)); + Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder + return Buf; } void -AgentMessageChannel::Fork(int ChannelId, int BufferSize) +AsyncAgentMessageChannel::FinalizeAndSend(std::vector Msg) { - CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); - WriteInt32(ChannelId); - WriteInt32(BufferSize); - FlushMessage(); + const uint32_t PayloadSize = static_cast(Msg.size() - MessageHeaderLength); + memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); + m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void -AgentMessageChannel::Attach() +AsyncAgentMessageChannel::WriteInt32(std::vector& Buf, int Value) { - CreateMessage(AgentMessageType::Attach, 0); - FlushMessage(); + const uint8_t* Ptr = reinterpret_cast(&Value); + Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } -void -AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +int +AsyncAgentMessageChannel::ReadInt32(const uint8_t** Pos) { - CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); - 
WriteString(Path); - WriteString(Locator); - FlushMessage(); + int Value; + memcpy(&Value, *Pos, sizeof(int)); + *Pos += sizeof(int); + return Value; } void -AgentMessageChannel::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - ExecuteProcessFlags Flags) +AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector& Buf, const uint8_t* Data, size_t Length) { - size_t RequiredSize = 50 + strlen(Exe); - for (size_t i = 0; i < NumArgs; ++i) - { - RequiredSize += strlen(Args[i]) + 10; - } - if (WorkingDir) - { - RequiredSize += strlen(WorkingDir) + 10; - } - for (size_t i = 0; i < NumEnvVars; ++i) - { - RequiredSize += strlen(EnvVars[i]) + 20; - } - - CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); - WriteString(Exe); - - WriteUnsignedVarInt(NumArgs); - for (size_t i = 0; i < NumArgs; ++i) - { - WriteString(Args[i]); - } - - WriteOptionalString(WorkingDir); - - // ExecuteV2 protocol requires env vars as separate key/value pairs. - // Callers pass "KEY=VALUE" strings; we split on the first '=' here. - WriteUnsignedVarInt(NumEnvVars); - for (size_t i = 0; i < NumEnvVars; ++i) - { - const char* Eq = strchr(EnvVars[i], '='); - assert(Eq != nullptr); - - WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); - if (*(Eq + 1) == '\0') - { - WriteOptionalString(nullptr); - } - else - { - WriteOptionalString(Eq + 1); - } - } + Buf.insert(Buf.end(), Data, Data + Length); +} - WriteInt32(static_cast(Flags)); - FlushMessage(); +const uint8_t* +AsyncAgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +{ + const uint8_t* Data = *Pos; + *Pos += Length; + return Data; } -void -AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +size_t +AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) { - // Blob responses are chunked to fit within the compute buffer's chunk size. 
- // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). - const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; - for (size_t ChunkOffset = 0; ChunkOffset < Length;) + if (Value == 0) { - const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); - - CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); - WriteInt32(static_cast(ChunkOffset)); - WriteInt32(static_cast(Length)); - WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); - FlushMessage(); - - ChunkOffset += ChunkLength; + return 1; } + return (FloorLog2_64(static_cast(Value)) / 7) + 1; } -AgentMessageType -AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +void +AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector& Buf, size_t Value) { - // Deferred advance: the previous response's buffer is only released when the next - // ReadResponse is called. This allows callers to read response data between calls - // without copying, since the pointer comes directly from the ring buffer. 
- if (m_ResponseData) - { - m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); - m_ResponseData = nullptr; - m_ResponseLength = 0; - } + const size_t ByteCount = MeasureUnsignedVarInt(Value); + const size_t StartPos = Buf.size(); + Buf.resize(StartPos + ByteCount); - const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); - if (!Header) + uint8_t* Output = Buf.data() + StartPos; + for (size_t i = 1; i < ByteCount; ++i) { - return AgentMessageType::None; + Output[ByteCount - i] = static_cast(Value); + Value >>= 8; } + Output[0] = static_cast((0xFF << (9 - static_cast(ByteCount))) | static_cast(Value)); +} - uint32_t Length; - memcpy(&Length, Header + 1, sizeof(uint32_t)); +size_t +AsyncAgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +{ + const uint8_t* Data = *Pos; + const uint8_t FirstByte = Data[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast(FirstByte))) + 1 - 24; - Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); - if (!Header) + size_t Value = static_cast(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) { - return AgentMessageType::None; + Value <<= 8; + Value |= Data[i]; } - m_ResponseType = static_cast(Header[0]); - m_ResponseData = Header + MessageHeaderLength; - m_ResponseLength = Length; - - return m_ResponseType; + *Pos += NumBytes; + return Value; } void -AgentMessageChannel::ReadException(ExceptionInfo& Ex) +AsyncAgentMessageChannel::WriteString(std::vector& Buf, const char* Text) { - assert(m_ResponseType == AgentMessageType::Exception); - const uint8_t* Pos = m_ResponseData; - Ex.Message = ReadString(&Pos); - Ex.Description = ReadString(&Pos); + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length); + WriteFixedLengthBytes(Buf, reinterpret_cast(Text), Length); } -int -AgentMessageChannel::ReadExecuteResult() +void +AsyncAgentMessageChannel::WriteString(std::vector& 
Buf, std::string_view Text) { - assert(m_ResponseType == AgentMessageType::ExecuteResult); - const uint8_t* Pos = m_ResponseData; - return ReadInt32(&Pos); + WriteUnsignedVarInt(Buf, Text.size()); + WriteFixedLengthBytes(Buf, reinterpret_cast(Text.data()), Text.size()); } -void -AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +std::string_view +AsyncAgentMessageChannel::ReadString(const uint8_t** Pos) { - assert(m_ResponseType == AgentMessageType::ReadBlob); - const uint8_t* Pos = m_ResponseData; - Req.Locator = ReadString(&Pos); - Req.Offset = ReadUnsignedVarInt(&Pos); - Req.Length = ReadUnsignedVarInt(&Pos); + const size_t Length = ReadUnsignedVarInt(Pos); + const char* Start = reinterpret_cast(ReadFixedLengthBytes(Pos, Length)); + return std::string_view(Start, Length); } void -AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +AsyncAgentMessageChannel::WriteOptionalString(std::vector& Buf, const char* Text) { - m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); - m_RequestData[0] = static_cast(Type); - m_MaxRequestSize = MaxLength; - m_RequestSize = 0; + if (!Text) + { + WriteUnsignedVarInt(Buf, 0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length + 1); + WriteFixedLengthBytes(Buf, reinterpret_cast(Text), Length); + } } +// --- Send methods --- + void -AgentMessageChannel::FlushMessage() +AsyncAgentMessageChannel::Close() { - const uint32_t Size = static_cast(m_RequestSize); - memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); - m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); - m_RequestSize = 0; - m_MaxRequestSize = 0; - m_RequestData = nullptr; + auto Msg = BeginMessage(AgentMessageType::None, 0); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteInt32(int Value) +AsyncAgentMessageChannel::Ping() { - WriteFixedLengthBytes(reinterpret_cast(&Value), sizeof(int)); + auto Msg = BeginMessage(AgentMessageType::Ping, 
0); + FinalizeAndSend(std::move(Msg)); } -int -AgentMessageChannel::ReadInt32(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { - int Value; - memcpy(&Value, *Pos, sizeof(int)); - *Pos += sizeof(int); - return Value; + auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(Msg, ChannelId); + WriteInt32(Msg, BufferSize); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::Attach() { - assert(m_RequestSize + Length <= m_MaxRequestSize); - memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); - m_RequestSize += Length; + auto Msg = BeginMessage(AgentMessageType::Attach, 0); + FinalizeAndSend(std::move(Msg)); } -const uint8_t* -AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +void +AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { - const uint8_t* Data = *Pos; - *Pos += Length; - return Data; + auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Msg, Path); + WriteString(Msg, Locator); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +void +AsyncAgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) { - if (Value == 0) + size_t ReserveSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) { - return 1; + ReserveSize += strlen(Args[i]) + 10; } - return (FloorLog2_64(static_cast(Value)) / 7) + 1; + if (WorkingDir) + { + ReserveSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + ReserveSize += strlen(EnvVars[i]) + 20; + } + + auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); + WriteString(Msg, Exe); + + 
WriteUnsignedVarInt(Msg, NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Msg, Args[i]); + } + + WriteOptionalString(Msg, WorkingDir); + + WriteUnsignedVarInt(Msg, NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(Msg, nullptr); + } + else + { + WriteOptionalString(Msg, Eq + 1); + } + } + + WriteInt32(Msg, static_cast(Flags)); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteUnsignedVarInt(size_t Value) +AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { - const size_t ByteCount = MeasureUnsignedVarInt(Value); - assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + static constexpr size_t MaxBlobChunkSize = 512 * 1024; - uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; - for (size_t i = 1; i < ByteCount; ++i) + for (size_t ChunkOffset = 0; ChunkOffset < Length;) { - Output[ByteCount - i] = static_cast(Value); - Value >>= 8; - } - Output[0] = static_cast((0xFF << (9 - static_cast(ByteCount))) | static_cast(Value)); + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); - m_RequestSize += ByteCount; + auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(Msg, static_cast(ChunkOffset)); + WriteInt32(Msg, static_cast(Length)); + WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); + FinalizeAndSend(std::move(Msg)); + + ChunkOffset += ChunkLength; + } } -size_t -AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +// --- Async response reading --- + +void +AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { - const uint8_t* Data = *Pos; - const uint8_t FirstByte = Data[0]; - const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast(FirstByte))) + 1 - 24; + // If frames 
are already queued, dispatch immediately + if (!m_IncomingFrames.empty()) + { + std::vector Frame = std::move(m_IncomingFrames.front()); + m_IncomingFrames.pop_front(); - size_t Value = static_cast(FirstByte & (0xFF >> NumBytes)); - for (size_t i = 1; i < NumBytes; ++i) + if (Frame.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast(Frame[0]); + const uint8_t* Data = Frame.data() + MessageHeaderLength; + size_t Size = Frame.size() - MessageHeaderLength; + asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { + // The Frame is captured to keep Data pointer valid + Handler(Type, Data, Size); + }); + } + else + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + } + return; + } + + if (m_Detached) { - Value <<= 8; - Value |= Data[i]; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + return; } - *Pos += NumBytes; - return Value; + // No frames queued — store pending handler and arm timeout + m_PendingHandler = std::move(Handler); + + if (TimeoutMs >= 0) + { + m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); + m_TimeoutTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; // Cancelled — frame arrived before timeout + } + + if (m_PendingHandler) + { + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } + }); + } } -size_t -AgentMessageChannel::MeasureString(const char* Text) const +void +AsyncAgentMessageChannel::OnFrame(std::vector Data) { - const size_t Length = strlen(Text); - return MeasureUnsignedVarInt(Length) + Length; + if (m_PendingHandler) + { + // Cancel the timeout timer + m_TimeoutTimer->cancel(); + + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + + if (Data.size() >= MessageHeaderLength) 
+ { + AgentMessageType Type = static_cast(Data[0]); + const uint8_t* Payload = Data.data() + MessageHeaderLength; + size_t PayloadSize = Data.size() - MessageHeaderLength; + Handler(Type, Payload, PayloadSize); + } + else + { + Handler(AgentMessageType::None, nullptr, 0); + } + } + else + { + m_IncomingFrames.push_back(std::move(Data)); + } } void -AgentMessageChannel::WriteString(const char* Text) +AsyncAgentMessageChannel::OnDetach() { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length); - WriteFixedLengthBytes(reinterpret_cast(Text), Length); + m_Detached = true; + + if (m_PendingHandler) + { + m_TimeoutTimer->cancel(); + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } } +// --- Response parsing helpers --- + void -AgentMessageChannel::WriteString(std::string_view Text) +AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t /*Size*/, ExceptionInfo& Ex) { - WriteUnsignedVarInt(Text.size()); - WriteFixedLengthBytes(reinterpret_cast(Text.data()), Text.size()); + const uint8_t* Pos = Data; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); } -std::string_view -AgentMessageChannel::ReadString(const uint8_t** Pos) +int +AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t /*Size*/) { - const size_t Length = ReadUnsignedVarInt(Pos); - const char* Start = reinterpret_cast(ReadFixedLengthBytes(Pos, Length)); - return std::string_view(Start, Length); + const uint8_t* Pos = Data; + return ReadInt32(&Pos); } void -AgentMessageChannel::WriteOptionalString(const char* Text) +AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t /*Size*/, BlobRequest& Req) { - // Optional strings use length+1 encoding: 0 means null/absent, - // N>0 means a string of length N-1 follows. This matches the UE - // FAgentMessageChannel serialization convention. 
- if (!Text) - { - WriteUnsignedVarInt(0); - } - else - { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length + 1); - WriteFixedLengthBytes(reinterpret_cast(Text), Length); - } + const uint8_t* Pos = Data; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); } } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h index 38c4375fd..0068fb468 100644 --- a/src/zenhorde/hordeagentmessage.h +++ b/src/zenhorde/hordeagentmessage.h @@ -4,14 +4,22 @@ #include -#include "hordecomputechannel.h" +#include "hordecomputesocket.h" #include #include +#include +#include +#include #include #include +#include #include +namespace asio { +class io_context; +} // namespace asio + namespace zen::horde { /** Agent message types matching the UE EAgentMessageType byte values. @@ -55,45 +63,34 @@ struct BlobRequest size_t Length = 0; }; -/** Channel for sending and receiving agent messages over a ComputeChannel. +/** Handler for async response reads. Receives the message type and a view of the payload data. + * The payload vector is valid until the next AsyncReadResponse call. */ +using AsyncResponseHandler = std::function; + +/** Async channel for sending and receiving agent messages over an AsyncComputeSocket. * - * Implements the Horde agent message protocol, matching the UE - * FAgentMessageChannel serialization format exactly. Messages are framed as - * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8; - * integers use variable-length encoding. + * Send methods build messages into vectors and submit them via AsyncComputeSocket. + * Receives are delivered via the socket's FrameHandler callback and queued internally. + * AsyncReadResponse checks the queue and invokes the handler, with optional timeout. 
* - * The protocol has two directions: - * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob - * - Responses (remote -> initiator): ReadResponse returns the type, then call the - * appropriate Read* method to parse the payload. + * All operations must be externally serialized (e.g. via the socket's strand). */ -class AgentMessageChannel +class AsyncAgentMessageChannel { public: - explicit AgentMessageChannel(Ref Channel); - ~AgentMessageChannel(); + AsyncAgentMessageChannel(std::shared_ptr Socket, int ChannelId, asio::io_context& IoContext); + ~AsyncAgentMessageChannel(); - AgentMessageChannel(const AgentMessageChannel&) = delete; - AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete; + AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete; - // --- Requests (Initiator -> Remote) --- + // --- Requests (fire-and-forget sends) --- - /** Close the channel. */ void Close(); - - /** Send a keepalive ping. */ void Ping(); - - /** Fork communication to a new channel with the given ID and buffer size. */ void Fork(int ChannelId, int BufferSize); - - /** Send an attach request (used during channel setup handshake). */ void Attach(); - - /** Request the remote agent to write files from the given bundle locator. */ void UploadFiles(const char* Path, const char* Locator); - - /** Execute a process on the remote machine. */ void Execute(const char* Exe, const char* const* Args, size_t NumArgs, @@ -101,61 +98,61 @@ public: const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags = ExecuteProcessFlags::None); - - /** Send blob data in response to a ReadBlob request. */ void Blob(const uint8_t* Data, size_t Length); - // --- Responses (Remote -> Initiator) --- - - /** Read the next response message. Returns the message type, or None on timeout. 
- * After this returns, use GetResponseData()/GetResponseSize() or the typed - * Read* methods to access the payload. */ - AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + // --- Async response reading --- - const void* GetResponseData() const { return m_ResponseData; } - size_t GetResponseSize() const { return m_ResponseLength; } + /** Read the next response. If a frame is already queued, the handler is posted immediately. + * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler + * with AgentMessageType::None. */ + void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler); - /** Parse an Exception response payload. */ - void ReadException(ExceptionInfo& Ex); + /** Called by the socket's FrameHandler when a frame arrives for this channel. */ + void OnFrame(std::vector Data); - /** Parse an ExecuteResult response payload. Returns the exit code. */ - int ReadExecuteResult(); + /** Called by the socket's DetachHandler. */ + void OnDetach(); - /** Parse a ReadBlob response payload into a BlobRequest. */ - void ReadBlobRequest(BlobRequest& Req); - -private: - static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + /** Returns true if the channel has been detached (connection lost). 
*/ + bool IsDetached() const { return m_Detached; } - Ref m_Channel; + // --- Response parsing helpers --- - uint8_t* m_RequestData = nullptr; - size_t m_RequestSize = 0; - size_t m_MaxRequestSize = 0; + static void ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex); + static int ReadExecuteResult(const uint8_t* Data, size_t Size); + static void ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req); - AgentMessageType m_ResponseType = AgentMessageType::None; - const uint8_t* m_ResponseData = nullptr; - size_t m_ResponseLength = 0; +private: + static constexpr size_t MessageHeaderLength = 5; - void CreateMessage(AgentMessageType Type, size_t MaxLength); - void FlushMessage(); + // Message building helpers + std::vector BeginMessage(AgentMessageType Type, size_t ReservePayload); + void FinalizeAndSend(std::vector Msg); - void WriteInt32(int Value); - static int ReadInt32(const uint8_t** Pos); + static void WriteInt32(std::vector& Buf, int Value); + static int ReadInt32(const uint8_t** Pos); - void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); + static void WriteFixedLengthBytes(std::vector& Buf, const uint8_t* Data, size_t Length); static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); static size_t MeasureUnsignedVarInt(size_t Value); - void WriteUnsignedVarInt(size_t Value); + static void WriteUnsignedVarInt(std::vector& Buf, size_t Value); static size_t ReadUnsignedVarInt(const uint8_t** Pos); - size_t MeasureString(const char* Text) const; - void WriteString(const char* Text); - void WriteString(std::string_view Text); + static void WriteString(std::vector& Buf, const char* Text); + static void WriteString(std::vector& Buf, std::string_view Text); static std::string_view ReadString(const uint8_t** Pos); - void WriteOptionalString(const char* Text); + static void WriteOptionalString(std::vector& Buf, const char* Text); + + std::shared_ptr m_Socket; + int m_ChannelId; + asio::io_context& 
m_IoContext; + + std::deque> m_IncomingFrames; + AsyncResponseHandler m_PendingHandler; + std::unique_ptr m_TimeoutTimer; + bool m_Detached = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp index d3974bc28..af6b97e59 100644 --- a/src/zenhorde/hordebundle.cpp +++ b/src/zenhorde/hordebundle.cpp @@ -57,7 +57,7 @@ MeasureVarInt(size_t Value) { return 1; } - return (FloorLog2(static_cast(Value)) / 7) + 1; + return (FloorLog2_64(static_cast(Value)) / 7) + 1; } static void diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp index 0eefc57c6..618a85e0e 100644 --- a/src/zenhorde/hordeclient.cpp +++ b/src/zenhorde/hordeclient.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client")) +HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client") { } @@ -32,7 +33,11 @@ HordeClient::Initialize() Settings.RetryCount = 1; Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests}; - if (!m_Config.AuthToken.empty()) + if (m_Config.AccessTokenProvider) + { + Settings.AccessTokenProvider = m_Config.AccessTokenProvider; + } + else if (!m_Config.AuthToken.empty()) { Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken { return HttpClientAccessToken(token, HttpClientAccessToken::Clock::now() + std::chrono::hours{24}); @@ -41,7 +46,7 @@ HordeClient::Initialize() m_Http = std::make_unique(m_Config.ServerUrl, Settings); - if (!m_Config.AuthToken.empty()) + if (Settings.AccessTokenProvider) { if (!m_Http->Authenticate()) { @@ -63,24 +68,21 @@ HordeClient::BuildRequestBody() const Requirements["pool"] = m_Config.Pool; } - std::string Condition; -#if ZEN_PLATFORM_WINDOWS 
ExtendableStringBuilder<256> CondBuf; +#if ZEN_PLATFORM_WINDOWS CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')"; - Condition = std::string(CondBuf); #elif ZEN_PLATFORM_MAC - Condition = "OSFamily == 'MacOS'"; + CondBuf << "OSFamily == 'MacOS'"; #else - Condition = "OSFamily == 'Linux'"; + CondBuf << "OSFamily == 'Linux'"; #endif if (!m_Config.Condition.empty()) { - Condition += " "; - Condition += m_Config.Condition; + CondBuf << " " << m_Config.Condition; } - Requirements["condition"] = Condition; + Requirements["condition"] = std::string(CondBuf); Requirements["exclusive"] = true; json11::Json::object Connection; @@ -157,37 +159,8 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus } OutCluster.ClusterId = ClusterIdVal.string_value(); - return true; -} - -bool -HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize) -{ - if (Hex.size() != OutSize * 2) - { - return false; - } - for (size_t i = 0; i < OutSize; ++i) - { - auto HexToByte = [](char c) -> int { - if (c >= '0' && c <= '9') - return c - '0'; - if (c >= 'a' && c <= 'f') - return c - 'a' + 10; - if (c >= 'A' && c <= 'F') - return c - 'A' + 10; - return -1; - }; - - const int Hi = HexToByte(Hex[i * 2]); - const int Lo = HexToByte(Hex[i * 2 + 1]); - if (Hi < 0 || Lo < 0) - { - return false; - } - Out[i] = static_cast((Hi << 4) | Lo); - } + ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId); return true; } @@ -197,8 +170,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C { ZEN_TRACE_CPU("HordeClient::RequestMachine"); - ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str()); - ExtendableStringBuilder<128> ResourcePath; ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? 
"default" : ClusterId.c_str()); @@ -324,6 +295,10 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C { PhysicalCores = static_cast(std::atoi(Prop.c_str() + 14)); } + else if (Prop.starts_with("Pool=")) + { + OutMachine.Pool = Prop.substr(5); + } } } @@ -367,10 +342,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C OutMachine.LeaseId = LeaseIdVal.string_value(); } - ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}", + ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}", OutMachine.GetConnectionAddress(), OutMachine.GetConnectionPort(), + ToString(OutMachine.Mode), OutMachine.LogicalCores, + OutMachine.Pool, OutMachine.LeaseId); return true; diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp deleted file mode 100644 index 0d032b5d5..000000000 --- a/src/zenhorde/hordecomputebuffer.cpp +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "hordecomputebuffer.h" - -#include -#include -#include -#include -#include - -namespace zen::horde { - -// Simplified ring buffer implementation for in-process use only. -// Uses a single contiguous buffer with write/read cursors and -// mutex+condvar for synchronization. This is simpler than the UE version -// which uses lock-free atomics and shared memory, but sufficient for our -// use case where we're the initiator side of the compute protocol. 
- -struct ComputeBuffer::Detail : TRefCounted -{ - std::vector Data; - size_t NumChunks = 0; - size_t ChunkLength = 0; - - // Current write state - size_t WriteChunkIdx = 0; - size_t WriteOffset = 0; - bool WriteComplete = false; - - // Current read state - size_t ReadChunkIdx = 0; - size_t ReadOffset = 0; - bool Detached = false; - - // Per-chunk written length - std::vector ChunkWrittenLength; - std::vector ChunkFinished; // Writer moved to next chunk - - std::mutex Mutex; - std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes - std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space - - bool HasWriter = false; - bool HasReader = false; - - uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } - const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } -}; - -// ComputeBuffer - -ComputeBuffer::ComputeBuffer() -{ -} -ComputeBuffer::~ComputeBuffer() -{ -} - -bool -ComputeBuffer::CreateNew(const Params& InParams) -{ - auto* NewDetail = new Detail(); - NewDetail->NumChunks = InParams.NumChunks; - NewDetail->ChunkLength = InParams.ChunkLength; - NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); - NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); - NewDetail->ChunkFinished.resize(InParams.NumChunks, false); - - m_Detail = NewDetail; - return true; -} - -void -ComputeBuffer::Close() -{ - m_Detail = nullptr; -} - -bool -ComputeBuffer::IsValid() const -{ - return static_cast(m_Detail); -} - -ComputeBufferReader -ComputeBuffer::CreateReader() -{ - assert(m_Detail); - m_Detail->HasReader = true; - return ComputeBufferReader(m_Detail); -} - -ComputeBufferWriter -ComputeBuffer::CreateWriter() -{ - assert(m_Detail); - m_Detail->HasWriter = true; - return ComputeBufferWriter(m_Detail); -} - -// ComputeBufferReader - -ComputeBufferReader::ComputeBufferReader() -{ -} 
-ComputeBufferReader::~ComputeBufferReader() -{ -} - -ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; -ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; -ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; -ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; - -ComputeBufferReader::ComputeBufferReader(Ref InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferReader::Close() -{ - m_Detail = nullptr; -} - -void -ComputeBufferReader::Detach() -{ - if (m_Detail) - { - std::lock_guard Lock(m_Detail->Mutex); - m_Detail->Detached = true; - m_Detail->ReadCV.notify_all(); - } -} - -bool -ComputeBufferReader::IsValid() const -{ - return static_cast(m_Detail); -} - -bool -ComputeBufferReader::IsComplete() const -{ - if (!m_Detail) - { - return true; - } - std::lock_guard Lock(m_Detail->Mutex); - if (m_Detail->Detached) - { - return true; - } - return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && - m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; -} - -void -ComputeBufferReader::AdvanceReadPosition(size_t Size) -{ - if (!m_Detail) - { - return; - } - - std::lock_guard Lock(m_Detail->Mutex); - - m_Detail->ReadOffset += Size; - - // Check if we need to move to next chunk - const size_t ReadChunk = m_Detail->ReadChunkIdx; - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - } - - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferReader::GetMaxReadSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard Lock(m_Detail->Mutex); - const size_t ReadChunk = m_Detail->ReadChunkIdx; - return 
m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; -} - -const uint8_t* -ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock Lock(m_Detail->Mutex); - - auto Predicate = [&]() -> bool { - if (m_Detail->Detached) - { - return true; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available >= MinSize) - { - return true; - } - - // If chunk is finished and we've read everything, try to move to next - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - // Move to next chunk - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - return false; // Re-check with new chunk - } - - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - - return false; - }; - - if (TimeoutMs < 0) - { - m_Detail->ReadCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - if (OutTimedOut) - { - *OutTimedOut = true; - } - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available < MinSize) - { - return nullptr; // End of stream - } - - return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; -} - -size_t -ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) -{ - const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); - if (!Data) - { - return 0; - } - - const size_t Available = GetMaxReadSize(); - const size_t ToCopy = 
std::min(Available, MaxSize); - memcpy(Buffer, Data, ToCopy); - AdvanceReadPosition(ToCopy); - return ToCopy; -} - -// ComputeBufferWriter - -ComputeBufferWriter::ComputeBufferWriter() = default; -ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; -ComputeBufferWriter::~ComputeBufferWriter() = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; - -ComputeBufferWriter::ComputeBufferWriter(Ref InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferWriter::Close() -{ - if (m_Detail) - { - { - std::lock_guard Lock(m_Detail->Mutex); - if (!m_Detail->WriteComplete) - { - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } - } - m_Detail = nullptr; - } -} - -bool -ComputeBufferWriter::IsValid() const -{ - return static_cast(m_Detail); -} - -void -ComputeBufferWriter::MarkComplete() -{ - if (m_Detail) - { - std::lock_guard Lock(m_Detail->Mutex); - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } -} - -void -ComputeBufferWriter::AdvanceWritePosition(size_t Size) -{ - if (!m_Detail || Size == 0) - { - return; - } - - std::lock_guard Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - m_Detail->ChunkWrittenLength[WriteChunk] += Size; - m_Detail->WriteOffset += Size; - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferWriter::GetMaxWriteSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; -} - -size_t -ComputeBufferWriter::GetChunkMaxLength() const -{ - if (!m_Detail) - { - return 0; - } - return m_Detail->ChunkLength; -} - -size_t 
-ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) -{ - uint8_t* Dest = WaitToWrite(1, TimeoutMs); - if (!Dest) - { - return 0; - } - - const size_t Available = GetMaxWriteSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Dest, Buffer, ToCopy); - AdvanceWritePosition(ToCopy); - return ToCopy; -} - -uint8_t* -ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock Lock(m_Detail->Mutex); - - if (m_Detail->WriteComplete) - { - return nullptr; - } - - const size_t WriteChunk = m_Detail->WriteChunkIdx; - const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; - - // If current chunk has enough space, return pointer - if (Available >= MinSize) - { - return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; - } - - // Current chunk is full - mark it as finished and move to next. - // The writer cannot advance until the reader has fully consumed the next chunk, - // preventing the writer from overwriting data the reader hasn't processed yet. 
- m_Detail->ChunkFinished[WriteChunk] = true; - m_Detail->ReadCV.notify_all(); - - const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; - - // Wait until reader has consumed the next chunk - auto Predicate = [&]() -> bool { - // Check if read has moved past this chunk - return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; - }; - - if (TimeoutMs < 0) - { - m_Detail->WriteCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - // Reset next chunk - m_Detail->ChunkWrittenLength[NextChunk] = 0; - m_Detail->ChunkFinished[NextChunk] = false; - m_Detail->WriteChunkIdx = NextChunk; - m_Detail->WriteOffset = 0; - - return m_Detail->ChunkPtr(NextChunk); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h deleted file mode 100644 index 64ef91b7a..000000000 --- a/src/zenhorde/hordecomputebuffer.h +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#include -#include -#include -#include - -namespace zen::horde { - -class ComputeBufferReader; -class ComputeBufferWriter; - -/** Simplified in-process ring buffer for the Horde compute protocol. - * - * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files, - * this implementation uses plain heap-allocated memory since we only need in-process - * communication between channel and transport threads. The buffer is divided into - * fixed-size chunks; readers and writers block when no space is available. 
- */ -class ComputeBuffer -{ -public: - struct Params - { - size_t NumChunks = 2; - size_t ChunkLength = 512 * 1024; - }; - - ComputeBuffer(); - ~ComputeBuffer(); - - ComputeBuffer(const ComputeBuffer&) = delete; - ComputeBuffer& operator=(const ComputeBuffer&) = delete; - - bool CreateNew(const Params& InParams); - void Close(); - - bool IsValid() const; - - ComputeBufferReader CreateReader(); - ComputeBufferWriter CreateWriter(); - -private: - struct Detail; - Ref m_Detail; - - friend class ComputeBufferReader; - friend class ComputeBufferWriter; -}; - -/** Read endpoint for a ComputeBuffer. - * - * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceReadPosition() after consuming the data. - */ -class ComputeBufferReader -{ -public: - ComputeBufferReader(); - ComputeBufferReader(const ComputeBufferReader&); - ComputeBufferReader(ComputeBufferReader&&) noexcept; - ~ComputeBufferReader(); - - ComputeBufferReader& operator=(const ComputeBufferReader&); - ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; - - void Close(); - void Detach(); - bool IsValid() const; - bool IsComplete() const; - - void AdvanceReadPosition(size_t Size); - size_t GetMaxReadSize() const; - - /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ - size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - - /** Wait until at least MinSize bytes are available and return a direct pointer. - * Returns nullptr on timeout or if the writer has completed. */ - const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - -private: - friend class ComputeBuffer; - explicit ComputeBufferReader(Ref InDetail); - - Ref m_Detail; -}; - -/** Write endpoint for a ComputeBuffer. - * - * Provides blocking writes into the ring buffer. 
WaitToWrite() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal - * that no more data will be written. - */ -class ComputeBufferWriter -{ -public: - ComputeBufferWriter(); - ComputeBufferWriter(const ComputeBufferWriter&); - ComputeBufferWriter(ComputeBufferWriter&&) noexcept; - ~ComputeBufferWriter(); - - ComputeBufferWriter& operator=(const ComputeBufferWriter&); - ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept; - - void Close(); - bool IsValid() const; - - /** Signal that no more data will be written. Unblocks any waiting readers. */ - void MarkComplete(); - - void AdvanceWritePosition(size_t Size); - size_t GetMaxWriteSize() const; - size_t GetChunkMaxLength() const; - - /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */ - size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1); - - /** Wait until at least MinSize bytes of write space are available and return a direct pointer. - * Returns nullptr on timeout. */ - uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1); - -private: - friend class ComputeBuffer; - explicit ComputeBufferWriter(Ref InDetail); - - Ref m_Detail; -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp deleted file mode 100644 index ee2a6f327..000000000 --- a/src/zenhorde/hordecomputechannel.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include "hordecomputechannel.h" - -namespace zen::horde { - -ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) -: Reader(std::move(InReader)) -, Writer(std::move(InWriter)) -{ -} - -bool -ComputeChannel::IsValid() const -{ - return Reader.IsValid() && Writer.IsValid(); -} - -size_t -ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) -{ - return Writer.Write(Data, Size, TimeoutMs); -} - -size_t -ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) -{ - return Reader.Read(Data, Size, TimeoutMs); -} - -void -ComputeChannel::MarkComplete() -{ - Writer.MarkComplete(); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h deleted file mode 100644 index c1dff20e4..000000000 --- a/src/zenhorde/hordecomputechannel.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "hordecomputebuffer.h" - -namespace zen::horde { - -/** Bidirectional communication channel using a pair of compute buffers. - * - * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter - * (for sending data). Used by ComputeSocket to represent one logical channel - * within a multiplexed connection. - */ -class ComputeChannel : public TRefCounted -{ -public: - ComputeBufferReader Reader; - ComputeBufferWriter Writer; - - ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); - - bool IsValid() const; - - size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); - size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); - - /** Signal that no more data will be sent on this channel. 
*/ - void MarkComplete(); -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp index 6ef67760c..8a6fc40a9 100644 --- a/src/zenhorde/hordecomputesocket.cpp +++ b/src/zenhorde/hordecomputesocket.cpp @@ -6,198 +6,326 @@ namespace zen::horde { -ComputeSocket::ComputeSocket(std::unique_ptr Transport) -: m_Log(zen::logging::Get("horde.socket")) +AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr Transport, asio::io_context& IoContext) +: m_Log(zen::logging::Get("horde.socket.async")) , m_Transport(std::move(Transport)) +, m_Strand(asio::make_strand(IoContext)) +, m_PingTimer(m_Strand) { } -ComputeSocket::~ComputeSocket() +AsyncComputeSocket::~AsyncComputeSocket() { - // Shutdown order matters: first stop the ping thread, then unblock send threads - // by detaching readers, then join send threads, and finally close the transport - // to unblock the recv thread (which is blocked on RecvMessage). - { - std::lock_guard Lock(m_PingMutex); - m_PingShouldStop = true; - m_PingCV.notify_all(); - } - - for (auto& Reader : m_Readers) - { - Reader.Detach(); - } - - for (auto& [Id, Thread] : m_SendThreads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } + Close(); +} - m_Transport->Close(); +void +AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach) +{ + m_FrameHandlers[ChannelId] = std::move(OnFrame); + m_DetachHandlers[ChannelId] = std::move(OnDetach); +} - if (m_RecvThread.joinable()) - { - m_RecvThread.join(); - } - if (m_PingThread.joinable()) - { - m_PingThread.join(); - } +void +AsyncComputeSocket::StartRecvPump() +{ + StartPingTimer(); + DoRecvHeader(); } -Ref -ComputeSocket::CreateChannel(int ChannelId) +void +AsyncComputeSocket::DoRecvHeader() { - ComputeBuffer::Params Params; + auto Self = shared_from_this(); + m_Transport->AsyncRead(&m_RecvHeader, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t 
/*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv header error: {}", Ec.message()); + HandleError(); + } + return; + } - ComputeBuffer RecvBuffer; - if (!RecvBuffer.CreateNew(Params)) - { - return {}; - } + if (m_Closed) + { + return; + } - ComputeBuffer SendBuffer; - if (!SendBuffer.CreateNew(Params)) - { - return {}; - } + if (m_RecvHeader.Size >= 0) + { + DoRecvPayload(m_RecvHeader); + } + else if (m_RecvHeader.Size == ControlDetach) + { + if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second) + { + It->second(); + } + DoRecvHeader(); + } + else if (m_RecvHeader.Size == ControlPing) + { + DoRecvHeader(); + } + else + { + ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size); + } + })); +} - Ref Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); +void +AsyncComputeSocket::DoRecvPayload(FrameHeader Header) +{ + auto PayloadBuf = std::make_shared>(static_cast(Header.Size)); + auto Self = shared_from_this(); - // Attach recv buffer writer (transport recv thread writes into this) - { - std::lock_guard Lock(m_WritersMutex); - m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); - } + m_Transport->AsyncRead(PayloadBuf->data(), + PayloadBuf->size(), + asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message()); + HandleError(); + } + return; + } - // Attach send buffer reader (send thread reads from this) - { - ComputeBufferReader Reader = SendBuffer.CreateReader(); - m_Readers.push_back(Reader); - m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); - } + if (m_Closed) + { + return; + } + + if (auto It = m_FrameHandlers.find(Header.Channel); 
It != m_FrameHandlers.end() && It->second) + { + It->second(std::move(*PayloadBuf)); + } + else + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + } - return Channel; + DoRecvHeader(); + })); } void -ComputeSocket::StartCommunication() +AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector Data, SendHandler Handler) { - m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); - m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable { + if (m_Closed) + { + if (Handler) + { + Handler(asio::error::make_error_code(asio::error::operation_aborted)); + } + return; + } + + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = static_cast(Data.size()); + Write.Data = std::move(Data); + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::PingThreadProc() +AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler) { - while (true) - { + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable { + if (m_Closed) { - std::unique_lock Lock(m_PingMutex); - if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + if (Handler) { - break; + Handler(asio::error::make_error_code(asio::error::operation_aborted)); } + return; } - std::lock_guard Lock(m_SendMutex); - FrameHeader Header; - Header.Channel = 0; - Header.Size = ControlPing; - m_Transport->SendMessage(&Header, sizeof(Header)); - } + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = ControlDetach; + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void 
-ComputeSocket::RecvThreadProc() +AsyncComputeSocket::FlushNextSend() { - // Writers are cached locally to avoid taking m_WritersMutex on every frame. - // The shared m_Writers map is only accessed when a channel is seen for the first time. - std::unordered_map CachedWriters; + if (m_SendQueue.empty() || m_Closed) + { + return; + } - FrameHeader Header; - while (m_Transport->RecvMessage(&Header, sizeof(Header))) + PendingWrite& Front = m_SendQueue.front(); + + if (Front.Data.empty()) { - if (Header.Size >= 0) - { - // Data frame - auto It = CachedWriters.find(Header.Channel); - if (It == CachedWriters.end()) - { - std::lock_guard Lock(m_WritersMutex); - auto WIt = m_Writers.find(Header.Channel); - if (WIt == m_Writers.end()) - { - ZEN_WARN("recv frame for unknown channel {}", Header.Channel); - // Skip the data - std::vector Discard(Header.Size); - m_Transport->RecvMessage(Discard.data(), Header.Size); - continue; - } - It = CachedWriters.emplace(Header.Channel, WIt->second).first; - } + // Control frame — header only + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); - ComputeBufferWriter& Writer = It->second; - uint8_t* Dest = Writer.WaitToWrite(Header.Size); - if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) - { - ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); - return; - } - Writer.AdvanceWritePosition(Header.Size); - } - else if (Header.Size == ControlDetach) - { - // Detach the recv buffer for this channel - CachedWriters.erase(Header.Channel); + if (Handler) + { + Handler(Ec); + } - std::lock_guard Lock(m_WritersMutex); - auto It = m_Writers.find(Header.Channel); - if (It != m_Writers.end()) - { - It->second.MarkComplete(); - m_Writers.erase(It); - } - } - else if (Header.Size == 
ControlPing) + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + } + else + { + // Data frame — write header first, then payload + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + if (Handler) + { + Handler(Ec); + } + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send header error: {}", Ec.message()); + HandleError(); + } + return; + } + + PendingWrite& Payload = m_SendQueue.front(); + m_Transport->AsyncWrite( + Payload.Data.data(), + Payload.Data.size(), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + + if (Handler) + { + Handler(Ec); + } + + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send payload error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + })); + } +} + +void +AsyncComputeSocket::StartPingTimer() +{ + if (m_Closed) + { + return; + } + + m_PingTimer.expires_after(std::chrono::seconds(2)); + + auto Self = shared_from_this(); + m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) { + if (Ec || m_Closed) { - // Ping response - ignore + return; } - else + + // Enqueue a ping control frame + PendingWrite Write; + Write.Header.Channel = 0; + Write.Header.Size = ControlPing; + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) { - ZEN_WARN("invalid frame header size: {}", Header.Size); - return; + FlushNextSend(); } - } + + StartPingTimer(); + })); } void -ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) 
+AsyncComputeSocket::HandleError() { - // Each channel has its own send thread. All send threads share m_SendMutex - // to serialize writes to the transport, since TCP requires atomic frame writes. - FrameHeader Header; - Header.Channel = Channel; + if (m_Closed) + { + return; + } + + Close(); - const uint8_t* Data; - while ((Data = Reader.WaitToRead(1)) != nullptr) + // Notify all channels that the connection is gone so agents can clean up + for (auto& [ChannelId, Handler] : m_DetachHandlers) { - std::lock_guard Lock(m_SendMutex); + if (Handler) + { + Handler(); + } + } +} - Header.Size = static_cast(Reader.GetMaxReadSize()); - m_Transport->SendMessage(&Header, sizeof(Header)); - m_Transport->SendMessage(Data, Header.Size); - Reader.AdvanceReadPosition(Header.Size); +void +AsyncComputeSocket::Close() +{ + if (m_Closed) + { + return; } - if (Reader.IsComplete()) + m_Closed = true; + m_PingTimer.cancel(); + + if (m_Transport) { - std::lock_guard Lock(m_SendMutex); - Header.Size = ControlDetach; - m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->Close(); } } diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h index 0c3cb4195..45b3418b7 100644 --- a/src/zenhorde/hordecomputesocket.h +++ b/src/zenhorde/hordecomputesocket.h @@ -2,45 +2,69 @@ #pragma once -#include "hordecomputebuffer.h" -#include "hordecomputechannel.h" #include "hordetransport.h" #include -#include +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +#include +#include #include -#include -#include +#include #include #include namespace zen::horde { -/** Multiplexed socket that routes data between multiple channels over a single transport. +class AsyncComputeTransport; + +/** Handler called when a data frame arrives for a channel. */ +using FrameHandler = std::function Data)>; + +/** Handler called when a channel is detached by the remote peer. 
*/ +using DetachHandler = std::function; + +/** Handler for async send completion. */ +using SendHandler = std::function; + +/** Async multiplexed socket that routes data between channels over a single transport. * - * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. - * A recv thread demultiplexes incoming frames to channel-specific buffers, while - * per-channel send threads multiplex outgoing data onto the shared transport. + * Uses an async recv pump, a serialized send queue, and a periodic ping timer — + * all running on a shared io_context. * - * Wire format per frame: [channelId (4B)][size (4B)][data] - * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. + * Wire format per frame: [channelId(4B)][size(4B)][data]. + * Control messages use negative sizes: -2 = detach, -3 = ping. */ -class ComputeSocket +class AsyncComputeSocket : public std::enable_shared_from_this { public: - explicit ComputeSocket(std::unique_ptr Transport); - ~ComputeSocket(); + AsyncComputeSocket(std::unique_ptr Transport, asio::io_context& IoContext); + ~AsyncComputeSocket(); + + AsyncComputeSocket(const AsyncComputeSocket&) = delete; + AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete; - ComputeSocket(const ComputeSocket&) = delete; - ComputeSocket& operator=(const ComputeSocket&) = delete; + /** Register callbacks for a channel. Must be called before StartRecvPump(). */ + void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach); - /** Create a channel with the given ID. - * Allocates anonymous in-process buffers and spawns a send thread for the channel. */ - Ref CreateChannel(int ChannelId); + /** Begin the async recv pump and ping timer. */ + void StartRecvPump(); - /** Start the recv pump and ping threads. Must be called after all channels are created. */ - void StartCommunication(); + /** Enqueue a data frame for async transmission. 
*/ + void AsyncSendFrame(int ChannelId, std::vector Data, SendHandler Handler = {}); + + /** Send a control frame (detach) for a channel. */ + void AsyncSendDetach(int ChannelId, SendHandler Handler = {}); + + /** Close the transport and cancel all pending operations. */ + void Close(); private: struct FrameHeader @@ -49,31 +73,35 @@ private: int32_t Size = 0; }; + struct PendingWrite + { + FrameHeader Header; + std::vector Data; + SendHandler Handler; + }; + static constexpr int32_t ControlDetach = -2; static constexpr int32_t ControlPing = -3; LoggerRef Log() { return m_Log; } - void RecvThreadProc(); - void SendThreadProc(int Channel, ComputeBufferReader Reader); - void PingThreadProc(); - - LoggerRef m_Log; - std::unique_ptr m_Transport; - std::mutex m_SendMutex; ///< Serializes writes to the transport - - std::mutex m_WritersMutex; - std::unordered_map m_Writers; ///< Recv-side: writers keyed by channel ID + void DoRecvHeader(); + void DoRecvPayload(FrameHeader Header); + void FlushNextSend(); + void StartPingTimer(); + void HandleError(); - std::vector m_Readers; ///< Send-side: readers for join on destruction - std::unordered_map m_SendThreads; ///< One send thread per channel + LoggerRef m_Log; + std::unique_ptr m_Transport; + asio::strand m_Strand; + asio::steady_timer m_PingTimer; - std::thread m_RecvThread; - std::thread m_PingThread; + std::unordered_map m_FrameHandlers; + std::unordered_map m_DetachHandlers; - bool m_PingShouldStop = false; - std::mutex m_PingMutex; - std::condition_variable m_PingCV; + FrameHeader m_RecvHeader; + std::deque m_SendQueue; + bool m_Closed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp index 2dca228d9..9f6125c64 100644 --- a/src/zenhorde/hordeconfig.cpp +++ b/src/zenhorde/hordeconfig.cpp @@ -1,5 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
+#include +#include #include namespace zen::horde { @@ -9,12 +11,14 @@ HordeConfig::Validate() const { if (ServerUrl.empty()) { + ZEN_WARN("Horde server URL is not configured"); return false; } // Relay mode implies AES encryption if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) { + ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode)); return false; } @@ -52,37 +56,39 @@ ToString(Encryption Enc) bool FromString(ConnectionMode& OutMode, std::string_view Str) { - if (Str == "direct") + if (StrCaseCompare(Str, "direct") == 0) { OutMode = ConnectionMode::Direct; return true; } - if (Str == "tunnel") + if (StrCaseCompare(Str, "tunnel") == 0) { OutMode = ConnectionMode::Tunnel; return true; } - if (Str == "relay") + if (StrCaseCompare(Str, "relay") == 0) { OutMode = ConnectionMode::Relay; return true; } + ZEN_WARN("unrecognized Horde connection mode: '{}'", Str); return false; } bool FromString(Encryption& OutEnc, std::string_view Str) { - if (Str == "none") + if (StrCaseCompare(Str, "none") == 0) { OutEnc = Encryption::None; return true; } - if (Str == "aes") + if (StrCaseCompare(Str, "aes") == 0) { OutEnc = Encryption::AES; return true; } + ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str); return false; } diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp index f88c95da2..b08544d1a 100644 --- a/src/zenhorde/hordeprovisioner.cpp +++ b/src/zenhorde/hordeprovisioner.cpp @@ -6,49 +6,82 @@ #include "hordeagent.h" #include "hordebundle.h" +#include #include #include #include #include #include +#include +#include +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include #include #include namespace zen::horde { -struct HordeProvisioner::AgentWrapper -{ - std::thread Thread; - std::atomic ShouldExit{false}; -}; - HordeProvisioner::HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const 
std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint) + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_BinariesPath(BinariesPath) , m_WorkingDir(WorkingDir) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_Log(zen::logging::Get("horde.provisioner")) { + m_IoContext = std::make_unique(); + + auto Work = asio::make_work_guard(*m_IoContext); + for (int i = 0; i < IoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i, Work] { + zen::SetCurrentThreadName(fmt::format("horde_io_{}", i)); + m_IoContext->run(); + }); + } } HordeProvisioner::~HordeProvisioner() { - std::lock_guard Lock(m_AgentsLock); - for (auto& Agent : m_Agents) + m_AskForAgents.store(false); + + // Shut down async agents and io_context { - Agent->ShouldExit.store(true); + std::lock_guard Lock(m_AsyncAgentsLock); + for (auto& Entry : m_AsyncAgents) + { + Entry.Agent->Cancel(); + } + m_AsyncAgents.clear(); } - for (auto& Agent : m_Agents) + + m_IoContext->stop(); + + for (auto& Thread : m_IoThreads) { - if (Agent->Thread.joinable()) + if (Thread.joinable()) { - Agent->Thread.join(); + Thread.join(); } } + + // Wait for all pool work items to finish before destroying members they reference + if (m_PendingWorkItems.load() > 0) + { + m_AllWorkDone.Wait(); + } } void @@ -56,9 +89,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) { ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); - m_TargetCoreCount.store(std::min(Count, static_cast(m_Config.MaxCores))); + const uint32_t ClampedCount = std::min(Count, static_cast(m_Config.MaxCores)); + const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount); + + if (ClampedCount != PreviousTarget) + { + ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})", + PreviousTarget, + ClampedCount, + 
m_ActiveCoreCount.load(), + m_EstimatedCoreCount.load()); + } - while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + // Only provision if the gap is at least one agent-sized chunk. Without + // this, draining a 32-core agent to cover a 28-core excess would leave a + // 4-core gap that triggers a 32-core provision, which triggers another + // drain, ad infinitum. + while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load()) { if (!m_AskForAgents.load()) { @@ -67,21 +114,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) RequestAgent(); } - // Clean up finished agent threads - std::lock_guard Lock(m_AgentsLock); - for (auto It = m_Agents.begin(); It != m_Agents.end();) + // Scale down async agents { - if ((*It)->ShouldExit.load()) + std::lock_guard AsyncLock(m_AsyncAgentsLock); + + uint32_t AsyncActive = m_ActiveCoreCount.load(); + uint32_t AsyncTarget = m_TargetCoreCount.load(); + + uint32_t AlreadyDrainingCores = 0; + for (const auto& Entry : m_AsyncAgents) { - if ((*It)->Thread.joinable()) + if (Entry.Draining) { - (*It)->Thread.join(); + AlreadyDrainingCores += Entry.CoreCount; } - It = m_Agents.erase(It); } - else + + uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? 
AsyncActive - AlreadyDrainingCores : 0; + + if (EffectiveAsync > AsyncTarget) { - ++It; + struct Candidate + { + AsyncAgentEntry* Entry; + int Workload; + }; + std::vector Candidates; + + for (auto& Entry : m_AsyncAgents) + { + if (Entry.Draining || Entry.RemoteEndpoint.empty()) + { + continue; + } + + int Workload = 0; + bool Reachable = false; + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{2000}; + Settings.Timeout = std::chrono::milliseconds{3000}; + try + { + HttpClient Client(Entry.RemoteEndpoint, Settings); + HttpClient::Response Resp = Client.Get("/compute/session/status"); + if (Resp.IsSuccess()) + { + CbObject Status = Resp.AsObject(); + Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0); + Reachable = true; + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what()); + } + + if (Reachable) + { + Candidates.push_back({&Entry, Workload}); + } + } + + const uint32_t ExcessCores = EffectiveAsync - AsyncTarget; + uint32_t CoresDrained = 0; + + while (CoresDrained < ExcessCores && !Candidates.empty()) + { + const uint32_t Remaining = ExcessCores - CoresDrained; + + Candidates.erase(std::remove_if(Candidates.begin(), + Candidates.end(), + [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }), + Candidates.end()); + + if (Candidates.empty()) + { + break; + } + + Candidate* Best = &Candidates[0]; + for (auto& C : Candidates) + { + if (C.Entry->CoreCount > Best->Entry->CoreCount || + (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload)) + { + Best = &C; + } + } + + ZEN_INFO("draining async agent lease={} ({} cores, workload={})", + Best->Entry->LeaseId, + Best->Entry->CoreCount, + Best->Workload); + + DrainAsyncAgent(*Best->Entry); + CoresDrained += Best->Entry->CoreCount; + + AsyncAgentEntry* Drained = Best->Entry; + 
Candidates.erase( + std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }), + Candidates.end()); + } } } } @@ -101,266 +235,380 @@ HordeProvisioner::GetStats() const uint32_t HordeProvisioner::GetAgentCount() const { - std::lock_guard Lock(m_AgentsLock); - return static_cast(m_Agents.size()); + std::lock_guard Lock(m_AsyncAgentsLock); + return static_cast(m_AsyncAgents.size()); } -void -HordeProvisioner::RequestAgent() +compute::AgentProvisioningStatus +HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const { - m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + // Worker IDs are "horde-{LeaseId}" — strip the prefix to match lease ID + constexpr std::string_view Prefix = "horde-"; + if (!WorkerId.starts_with(Prefix)) + { + return compute::AgentProvisioningStatus::Unknown; + } + std::string_view LeaseId = WorkerId.substr(Prefix.size()); - std::lock_guard Lock(m_AgentsLock); + std::lock_guard AsyncLock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.LeaseId == LeaseId) + { + if (Entry.Draining) + { + return compute::AgentProvisioningStatus::Draining; + } + return compute::AgentProvisioningStatus::Active; + } + } - auto Wrapper = std::make_unique(); - AgentWrapper& Ref = *Wrapper; - Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + // Check recently-drained agents that have already been cleaned up + std::string WorkerIdStr(WorkerId); + if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0) + { + return compute::AgentProvisioningStatus::Draining; + } - m_Agents.push_back(std::move(Wrapper)); + return compute::AgentProvisioningStatus::Unknown; } -void -HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +std::vector +HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const { - ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + std::vector Args; + Args.emplace_back("compute"); + Args.emplace_back("--http=asio"); + 
Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen"); - static std::atomic ThreadIndex{0}; - const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + if (m_CleanStart) + { + Args.emplace_back("--clean"); + } - zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + Args.emplace_back(CoordArg.ToView()); + } - std::unique_ptr Agent; - uint32_t MachineCoreCount = 0; + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + Args.emplace_back(IdArg.ToView()); + } - auto _ = MakeGuard([&] { - if (Agent) - { - Agent->CloseConnection(); - } - Wrapper.ShouldExit.store(true); - }); + if (!m_CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << m_CoordinatorSession; + Args.emplace_back(SessionArg.ToView()); + } + if (!m_TraceHost.empty()) { - // EstimatedCoreCount is incremented speculatively when the agent is requested - // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. - auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << m_TraceHost; + Args.emplace_back(TraceArg.ToView()); + } + // In relay mode, the remote zenserver's local address is not reachable from the + // orchestrator. Pass the relay-visible endpoint so it announces the correct URL. 
+ if (Machine.Mode == ConnectionMode::Relay) + { + const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (Addr.find(':') != std::string::npos) + { + Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port)); + } + else { - ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port)); + } + } - std::lock_guard BundleLock(m_BundleLock); + return Args; +} - if (!m_BundlesCreated) - { - const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; +bool +HordeProvisioner::InitializeHordeClient() +{ + ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient"); + + std::lock_guard BundleLock(m_BundleLock); - std::vector Files; + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector Files; #if ZEN_PLATFORM_WINDOWS - Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.pdb", true); #elif ZEN_PLATFORM_LINUX - Files.emplace_back(m_BinariesPath / "zenserver", false); - Files.emplace_back(m_BinariesPath / "zenserver.debug", true); + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); #elif ZEN_PLATFORM_MAC - Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver", false); #endif - BundleResult Result; - if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) - { - ZEN_WARN("failed to create bundle, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - - m_Bundles.emplace_back(Result.Locator, Result.BundleDir); - m_BundlesCreated = true; - } - - if (!m_HordeClient) - { - m_HordeClient = std::make_unique(m_Config); - if (!m_HordeClient->Initialize()) - { - ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); 
- m_AskForAgents.store(false); - return; - } - } + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + return false; } - if (!m_AskForAgents.load()) + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique(m_Config); + if (!m_HordeClient->Initialize()) { - return; + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + return false; } + } - m_AgentsRequesting.fetch_add(1); - auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + return true; +} - // Simple backoff: if the last machine request failed, wait up to 5 seconds - // before trying again. - // - // Note however that it's possible that multiple threads enter this code at - // the same time if multiple agents are requested at once, and they will all - // see the same last failure time and back off accordingly. We might want to - // use a semaphore or similar to limit the number of concurrent requests. 
+void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); - if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) - { - auto Now = static_cast(std::chrono::steady_clock::now().time_since_epoch().count()); - const uint64_t ElapsedNs = Now - LastFail; - const uint64_t ElapsedMs = ElapsedNs / 1'000'000; - if (ElapsedMs < 5000) - { - const uint64_t WaitMs = 5000 - ElapsedMs; - for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } + if (m_PendingWorkItems.fetch_add(1) == 0) + { + m_AllWorkDone.Reset(); + } - if (Wrapper.ShouldExit.load()) + GetSmallWorkerPool(EWorkloadType::Background) + .ScheduleWork( + [this] { + ProvisionAgent(); + if (m_PendingWorkItems.fetch_sub(1) == 1) { - return; + m_AllWorkDone.Set(); } - } - } - - if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) - { - return; - } - - std::string RequestBody = m_HordeClient->BuildRequestBody(); + }, + WorkerThreadPool::EMode::EnableBacklog); +} - // Resolve cluster if needed - std::string ClusterId = m_Config.Cluster; - if (ClusterId == HordeConfig::ClusterAuto) - { - ClusterInfo Cluster; - if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) - { - ZEN_WARN("failed to resolve cluster"); - m_LastRequestFailTime.store(static_cast(std::chrono::steady_clock::now().time_since_epoch().count())); - return; - } - ClusterId = Cluster.ClusterId; - } +void +HordeProvisioner::ProvisionAgent() +{ + ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent"); - MachineInfo Machine; - if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) - { - m_LastRequestFailTime.store(static_cast(std::chrono::steady_clock::now().time_since_epoch().count())); - return; - } + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. 
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); - m_LastRequestFailTime.store(0); + if (!InitializeHordeClient()) + { + return; + } - if (Wrapper.ShouldExit.load()) - { - return; - } + if (!m_AskForAgents.load()) + { + return; + } - // Connect to agent and perform handshake - Agent = std::make_unique(Machine); - if (!Agent->IsValid()) - { - ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); - return; - } + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); - if (!Agent->BeginCommunication()) - { - ZEN_WARN("BeginCommunication failed"); - return; - } + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. 
- for (auto& [Locator, BundleDir] : m_Bundles) + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) { - if (Wrapper.ShouldExit.load()) + const uint64_t WaitMs = 5000 - ElapsedMs; + for (uint64_t Waited = 0; Waited < WaitMs && !!m_AskForAgents.load(); Waited += 100) { - return; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } - if (!Agent->UploadBinaries(BundleDir, Locator)) + if (!m_AskForAgents.load()) { - ZEN_WARN("UploadBinaries failed"); return; } } + } - if (Wrapper.ShouldExit.load()) + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) { + ZEN_WARN("failed to resolve cluster"); + m_LastRequestFailTime.store(static_cast(std::chrono::steady_clock::now().time_since_epoch().count())); return; } + ClusterId = Cluster.ClusterId; + } - // Build command line for remote zenserver - std::vector ArgStrings; - ArgStrings.push_back("compute"); - ArgStrings.push_back("--http=asio"); + ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})", + ClusterId.empty() ? 
"default" : ClusterId.c_str(), + m_ActiveCoreCount.load(), + m_TargetCoreCount.load()); - // TEMP HACK - these should be made fully dynamic - // these are currently here to allow spawning the compute agent locally - // for debugging purposes (i.e with a local Horde Server+Agent setup) - ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); - ArgStrings.push_back("--data-dir=c:\\temp\\123"); + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } - if (!m_OrchestratorEndpoint.empty()) - { - ExtendableStringBuilder<256> CoordArg; - CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; - ArgStrings.emplace_back(CoordArg.ToView()); - } + m_LastRequestFailTime.store(0); - { - ExtendableStringBuilder<128> IdArg; - IdArg << "--instance-id=horde-" << Machine.LeaseId; - ArgStrings.emplace_back(IdArg.ToView()); - } + if (!m_AskForAgents.load()) + { + return; + } - std::vector Args; - Args.reserve(ArgStrings.size()); - for (const std::string& Arg : ArgStrings) - { - Args.push_back(Arg.c_str()); - } + AsyncAgentConfig AgentConfig; + AgentConfig.Machine = Machine; + AgentConfig.Bundles = m_Bundles; + AgentConfig.Args = BuildAgentArgs(Machine); #if ZEN_PLATFORM_WINDOWS - const bool UseWine = !Machine.IsWindows; - const char* AppName = "zenserver.exe"; + AgentConfig.UseWine = !Machine.IsWindows; + AgentConfig.Executable = "zenserver.exe"; #else - const bool UseWine = false; - const char* AppName = "zenserver"; + AgentConfig.UseWine = false; + AgentConfig.Executable = "zenserver"; #endif - Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); + auto AsyncAgent = std::make_shared(*m_IoContext); + + AsyncAgentEntry Entry; + Entry.Agent = AsyncAgent; + Entry.LeaseId = Machine.LeaseId; + Entry.CoreCount = Machine.LogicalCores; - 
ZEN_INFO("remote execution started on [{}:{}] lease={}", - Machine.GetConnectionAddress(), - Machine.GetConnectionPort(), - Machine.LeaseId); + const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (EndpointAddr.find(':') != std::string::npos) + { + Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort); + } + else + { + Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort); + } - MachineCoreCount = Machine.LogicalCores; - m_EstimatedCoreCount.fetch_add(MachineCoreCount); - m_ActiveCoreCount.fetch_add(MachineCoreCount); - m_AgentsActive.fetch_add(1); + { + std::lock_guard Lock(m_AsyncAgentsLock); + m_AsyncAgents.push_back(std::move(Entry)); } - // Agent poll loop + AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) { + if (Result.CoreCount > 0) + { + // Only subtract estimated cores if not already subtracted by DrainAsyncAgent + bool WasDraining = false; + { + std::lock_guard Lock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.Agent == AsyncAgent) + { + WasDraining = Entry.Draining; + break; + } + } + } - auto ActiveGuard = MakeGuard([&]() { - m_EstimatedCoreCount.fetch_sub(MachineCoreCount); - m_ActiveCoreCount.fetch_sub(MachineCoreCount); - m_AgentsActive.fetch_sub(1); + if (!WasDraining) + { + m_EstimatedCoreCount.fetch_sub(Result.CoreCount); + } + m_ActiveCoreCount.fetch_sub(Result.CoreCount); + m_AgentsActive.fetch_sub(1); + } + OnAsyncAgentDone(AsyncAgent); }); - while (Agent->IsValid() && !Wrapper.ShouldExit.load()) + // Track active cores (estimated was already added by RequestAgent) + m_EstimatedCoreCount.fetch_add(Machine.LogicalCores); + m_ActiveCoreCount.fetch_add(Machine.LogicalCores); + m_AgentsActive.fetch_add(1); +} + +void +HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry) +{ + Entry.Draining = true; + m_EstimatedCoreCount.fetch_sub(Entry.CoreCount); + 
m_AgentsDraining.fetch_add(1); + + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{5000}; + Settings.Timeout = std::chrono::milliseconds{10000}; + + try { - const bool LogOutput = false; - if (!Agent->Poll(LogOutput)) + HttpClient Client(Entry.RemoteEndpoint, Settings); + + HttpClient::Response Response = Client.Post("/compute/session/drain"); + if (!Response.IsSuccess()) { + ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast(Response.StatusCode)); + return; + } + + ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId); + (void)Client.Post("/compute/session/sunset"); + } + catch (const std::exception& Ex) + { + ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what()); + } +} + +void +HordeProvisioner::OnAsyncAgentDone(std::shared_ptr Agent) +{ + std::lock_guard Lock(m_AsyncAgentsLock); + for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It) + { + if (It->Agent == Agent) + { + if (It->Draining) + { + m_AgentsDraining.fetch_sub(1); + m_RecentlyDrainedWorkerIds.insert("horde-" + It->LeaseId); + } + m_AsyncAgents.erase(It); break; } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp index 69766e73e..65eaea477 100644 --- a/src/zenhorde/hordetransport.cpp +++ b/src/zenhorde/hordetransport.cpp @@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif - namespace zen::horde { -// ComputeTransport base +// --- AsyncTcpComputeTransport --- -bool -ComputeTransport::SendMessage(const void* Data, size_t Size) +struct AsyncTcpComputeTransport::Impl { - const uint8_t* Ptr = static_cast(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Sent = Send(Ptr, Remaining); - if (Sent == 0) - { - return false; - } - Ptr += Sent; - 
Remaining -= Sent; - } + asio::io_context& IoContext; + asio::ip::tcp::socket Socket; - return true; -} + explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {} +}; -bool -ComputeTransport::RecvMessage(void* Data, size_t Size) +AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext) +: m_Impl(std::make_unique(IoContext)) +, m_Log(zen::logging::Get("horde.transport.async")) { - uint8_t* Ptr = static_cast(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Received = Recv(Ptr, Remaining); - if (Received == 0) - { - return false; - } - Ptr += Received; - Remaining -= Received; - } - - return true; } -// TcpComputeTransport - ASIO pimpl - -struct TcpComputeTransport::Impl +AsyncTcpComputeTransport::~AsyncTcpComputeTransport() { - asio::io_context IoContext; - asio::ip::tcp::socket Socket; - - Impl() : Socket(IoContext) {} -}; + Close(); +} -// Uses ASIO in synchronous mode only — no async operations or io_context::run(). -// The io_context is only needed because ASIO sockets require one to be constructed. 
-TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) -: m_Impl(std::make_unique()) -, m_Log(zen::logging::Get("horde.transport")) +void +AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler) { - ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect"); asio::error_code Ec; @@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) { ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); m_HasErrors = true; + asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); }); return; } const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); - m_Impl->Socket.connect(Endpoint, Ec); - if (Ec) - { - ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); - m_HasErrors = true; - return; - } + // Copy the nonce so it survives past this scope into the async callback + auto NonceBuf = std::make_shared>(Info.Nonce, Info.Nonce + NonceSize); - // Disable Nagle's algorithm for lower latency - m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); -} + m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("async connect failed: {}", Ec.message()); + m_HasErrors = true; + Handler(Ec); + return; + } -TcpComputeTransport::~TcpComputeTransport() -{ - Close(); + asio::error_code SetOptEc; + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc); + + // Send the 64-byte nonce as the first thing on the wire + asio::async_write(m_Impl->Socket, + asio::buffer(*NonceBuf), + [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) { + if (Ec) + { + ZEN_WARN("nonce write failed: {}", Ec.message()); + m_HasErrors = true; + } + Handler(Ec); + }); + }); } bool 
-TcpComputeTransport::IsValid() const +AsyncTcpComputeTransport::IsValid() const { return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; } -size_t -TcpComputeTransport::Send(const void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - m_HasErrors = true; - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Sent; + asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } -size_t -TcpComputeTransport::Recv(void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Received; -} - -void -TcpComputeTransport::MarkComplete() -{ + asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } void -TcpComputeTransport::Close() +AsyncTcpComputeTransport::Close() { if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) { diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h index 1b178dc0f..b5e841d7a 100644 --- a/src/zenhorde/hordetransport.h +++ b/src/zenhorde/hordetransport.h @@ -8,55 +8,60 @@ #include #include +#include #include +#include -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif +namespace asio { +class io_context; +} namespace zen::horde { -/** Abstract base interface for compute transports. +/** Handler types for async transport operations. 
*/ +using AsyncConnectHandler = std::function; +using AsyncIoHandler = std::function; + +/** Abstract base for asynchronous compute transports. * - * Matches the UE FComputeTransport pattern. Concrete implementations handle - * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides - * blocking message helpers on top. + * All callbacks are invoked on the io_context that was provided at construction. + * Callers are responsible for strand serialization if needed. */ -class ComputeTransport +class AsyncComputeTransport { public: - virtual ~ComputeTransport() = default; + virtual ~AsyncComputeTransport() = default; + + virtual bool IsValid() const = 0; - virtual bool IsValid() const = 0; - virtual size_t Send(const void* Data, size_t Size) = 0; - virtual size_t Recv(void* Data, size_t Size) = 0; - virtual void MarkComplete() = 0; - virtual void Close() = 0; + /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */ + virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking send that loops until all bytes are transferred. Returns false on error. */ - bool SendMessage(const void* Data, size_t Size); + /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */ + virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking receive that loops until all bytes are transferred. Returns false on error. */ - bool RecvMessage(void* Data, size_t Size); + virtual void Close() = 0; }; -/** TCP socket transport using ASIO. +/** Async TCP transport using ASIO. * - * Connects to the Horde compute endpoint specified by MachineInfo and provides - * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the - * header clean. + * Connects to the Horde compute endpoint and provides async send/receive. + * The socket is created on a caller-provided io_context (shared across agents). 
*/ -class TcpComputeTransport final : public ComputeTransport +class AsyncTcpComputeTransport final : public AsyncComputeTransport { public: - explicit TcpComputeTransport(const MachineInfo& Info); - ~TcpComputeTransport() override; - - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + /** Construct a transport on the given io_context. Does not connect yet. */ + explicit AsyncTcpComputeTransport(asio::io_context& IoContext); + ~AsyncTcpComputeTransport() override; + + /** Asynchronously connect to the endpoint and send the nonce. */ + void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler); + + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: LoggerRef Log() { return m_Log; } diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..c71866e8c 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -5,6 +5,10 @@ #include #include +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + #include #include #include @@ -22,274 +26,281 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -struct AesComputeTransport::CryptoContext -{ - uint8_t Key[KeySize] = {}; - uint8_t EncryptNonce[NonceBytes] = {}; - uint8_t DecryptNonce[NonceBytes] = {}; - bool HasErrors = false; +namespace { + + static constexpr size_t AesNonceBytes = 12; + static constexpr size_t AesTagBytes = 16; + + /** AES-256-GCM crypto context. Not exposed outside this translation unit. 
*/ + struct AesCryptoContext + { + static constexpr size_t NonceBytes = AesNonceBytes; + static constexpr size_t TagBytes = AesTagBytes; + + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; #if !ZEN_PLATFORM_WINDOWS - EVP_CIPHER_CTX* EncCtx = nullptr; - EVP_CIPHER_CTX* DecCtx = nullptr; + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; #endif - CryptoContext(const uint8_t (&InKey)[KeySize]) - { - memcpy(Key, InKey, KeySize); - - // The encrypt nonce is randomly initialized and then deterministically mutated - // per message via UpdateNonce(). The decrypt nonce is not used — it comes from - // the wire (each received message carries its own nonce in the header). - std::random_device Rd; - std::mt19937 Gen(Rd()); - std::uniform_int_distribution Dist(0, 255); - for (auto& Byte : EncryptNonce) + AesCryptoContext(const uint8_t (&InKey)[KeySize]) { - Byte = static_cast(Dist(Gen)); - } + memcpy(Key, InKey, KeySize); + + std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast(Dist(Gen)); + } #if !ZEN_PLATFORM_WINDOWS - // Drain any stale OpenSSL errors - while (ERR_get_error() != 0) - { - } + while (ERR_get_error() != 0) + { + } - EncCtx = EVP_CIPHER_CTX_new(); - EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); - DecCtx = EVP_CIPHER_CTX_new(); - EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); #endif - } + } - ~CryptoContext() - { + ~AesCryptoContext() + { #if ZEN_PLATFORM_WINDOWS - SecureZeroMemory(Key, sizeof(Key)); - SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); - SecureZeroMemory(DecryptNonce, 
sizeof(DecryptNonce)); + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); #else - OPENSSL_cleanse(Key, sizeof(Key)); - OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); - OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); - - if (EncCtx) - { - EVP_CIPHER_CTX_free(EncCtx); + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif } - if (DecCtx) + + void UpdateNonce() { - EVP_CIPHER_CTX_free(DecCtx); + uint32_t* N32 = reinterpret_cast(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; } -#endif - } - - void UpdateNonce() - { - uint32_t* N32 = reinterpret_cast(EncryptNonce); - N32[0]++; - N32[1]--; - N32[2] = N32[0] ^ N32[1]; - } - // Returns total encrypted message size, or 0 on failure - // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] - int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) - { - UpdateNonce(); + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); - // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than - // caching but has some overhead. For our use case (relatively large, infrequent messages) - // this is acceptable. 
#if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = EncryptNonce; - AuthInfo.cbNonce = NonceBytes; - uint8_t Tag[TagBytes] = {}; - AuthInfo.pbTag = Tag; - AuthInfo.cbTag = TagBytes; - - ULONG CipherLen = 0; - NTSTATUS Status = - BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = BCryptEncrypt(hKey, + (PUCHAR)In, + (ULONG)InLength, + &AuthInfo, + nullptr, + 0, + Out + 4 + NonceBytes, + (ULONG)InLength, + &CipherLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + BCryptDestroyKey(hKey); BCryptCloseAlgorithmProvider(hAlg, 0); - return 0; - } - - // Write header: length + 
nonce - memcpy(Out, &InLength, 4); - memcpy(Out + 4, EncryptNonce, NonceBytes); - // Write tag after ciphertext - memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - return 4 + NonceBytes + static_cast(CipherLen) + TagBytes; + return 4 + NonceBytes + static_cast(CipherLen) + TagBytes; #else - if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) - { - HasErrors = true; - return 0; - } - - int32_t Offset = 0; - // Write length - memcpy(Out + Offset, &InLength, 4); - Offset += 4; - // Write nonce - memcpy(Out + Offset, EncryptNonce, NonceBytes); - Offset += NonceBytes; - - // Encrypt - int OutLen = 0; - if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast(In), InLength) != 1) - { - HasErrors = true; - return 0; - } - Offset += OutLen; - - // Finalize - int FinalLen = 0; - if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) - { - HasErrors = true; - return 0; + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif } - Offset += FinalLen; - // Get tag - if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { - HasErrors = 
true; - return 0; - } - Offset += TagBytes; - - return Offset; -#endif - } - - // Decrypt a message. Returns decrypted data length, or 0 on failure. - // Input must be [ciphertext][tag], with nonce provided separately. - int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) - { #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = const_cast(Nonce); - AuthInfo.cbNonce = NonceBytes; - AuthInfo.pbTag = const_cast(CipherAndTag + DataLength); - AuthInfo.cbTag = TagBytes; - - ULONG PlainLen = 0; - NTSTATUS Status = BCryptDecrypt(hKey, - (PUCHAR)CipherAndTag, - (ULONG)DataLength, - &AuthInfo, - nullptr, - 0, - (PUCHAR)Out, - (ULONG)DataLength, - &PlainLen, - 0); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; - return 0; - } + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + 
nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); - return static_cast(PlainLen); -#else - if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) - { - HasErrors = true; - return 0; - } + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); - int OutLen = 0; - if (EVP_DecryptUpdate(DecCtx, static_cast(Out), &OutLen, CipherAndTag, DataLength) != 1) - { - HasErrors = true; - return 0; - } + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } - // Set the tag for verification - if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast(CipherAndTag + DataLength)) != 1) - { - HasErrors = true; - return 0; + return static_cast(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast(Out) + OutLen, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif } + }; - int FinalLen = 0; - if (EVP_DecryptFinal_ex(DecCtx, static_cast(Out) + OutLen, &FinalLen) != 1) - { - HasErrors = true; - return 0; - } +} // anonymous namespace - return OutLen + FinalLen; -#endif - } +struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext +{ + using AesCryptoContext::AesCryptoContext; }; -AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr InnerTransport) +// --- AsyncAesComputeTransport --- + +AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr InnerTransport, + asio::io_context& IoContext) : m_Crypto(std::make_unique(Key)) , 
m_Inner(std::move(InnerTransport)) +, m_IoContext(IoContext) { } -AesComputeTransport::~AesComputeTransport() +AsyncAesComputeTransport::~AsyncAesComputeTransport() { Close(); } bool -AesComputeTransport::IsValid() const +AsyncAesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } -size_t -AesComputeTransport::Send(const void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { - ZEN_TRACE_CPU("AesComputeTransport::Send"); - if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - std::lock_guard Lock(m_Lock); - const int32_t DataLength = static_cast(Size); - const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes; if (m_EncryptBuffer.size() < MessageLength) { @@ -299,38 +310,36 @@ AesComputeTransport::Send(const void* Data, size_t Size) const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); if (EncryptedLen == 0) { - return 0; + asio::post(m_IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); }); + return; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast(EncryptedLen))) - { - return 0; - } + auto EncBuf = std::make_shared>(m_EncryptBuffer.begin(), m_EncryptBuffer.begin() + EncryptedLen); - return Size; + m_Inner->AsyncWrite( + EncBuf->data(), + EncBuf->size(), + [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 
0 : Size); }); } -size_t -AesComputeTransport::Recv(void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes - // than the decrypted message contains. Excess bytes are buffered in m_RemainingData - // and returned on subsequent Recv calls without another decryption round-trip. - ZEN_TRACE_CPU("AesComputeTransport::Recv"); - - std::lock_guard Lock(m_Lock); + uint8_t* Dest = static_cast(Data); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); - memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) @@ -339,78 +348,96 @@ AesComputeTransport::Recv(void* Data, size_t Size) m_RemainingOffset = 0; } - return ToCopy; - } - - // Receive packet header: [length(4B)][nonce(12B)] - struct PacketHeader - { - int32_t DataLength = 0; - uint8_t Nonce[NonceBytes] = {}; - } Header; - - if (!m_Inner->RecvMessage(&Header, sizeof(Header))) - { - return 0; - } - - // Validate DataLength to prevent OOM from malicious/corrupt peers - static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB - - if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) - { - ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); - return 0; - } - - // Receive ciphertext + tag - const size_t MessageLength = static_cast(Header.DataLength) + TagBytes; - - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } - - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), 
MessageLength)) - { - return 0; - } - - // Decrypt - const size_t BytesToReturn = std::min(static_cast(Header.DataLength), Size); - - // We need a temporary buffer for decryption if we can't decrypt directly into output - std::vector DecryptedBuf(static_cast(Header.DataLength)); - - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); - if (Decrypted == 0) - { - return 0; - } - - memcpy(Data, DecryptedBuf.data(), BytesToReturn); + if (ToCopy == Size) + { + asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); }); + return; + } - // Store remaining data if we couldn't return everything - if (static_cast(Header.DataLength) > BytesToReturn) - { - m_RemainingOffset = 0; - m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler)); + return; } - return BytesToReturn; + DoRecvMessage(Dest, Size, std::move(Handler)); } void -AesComputeTransport::MarkComplete() +AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler) { - if (IsValid()) - { - m_Inner->MarkComplete(); - } + static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes; + auto HeaderBuf = std::make_shared>(); + + m_Inner->AsyncRead(HeaderBuf->data(), + HeaderSize, + [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + int32_t DataLength = 0; + memcpy(&DataLength, HeaderBuf->data(), 4); + + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; + if (DataLength <= 0 || DataLength > MaxDataLength) + { + Handler(asio::error::make_error_code(asio::error::invalid_argument), 0); + return; + } + + const size_t MessageLength = static_cast(DataLength) + CryptoContext::TagBytes; + if (m_DecryptBuffer.size() < MessageLength) + { + 
m_DecryptBuffer.resize(MessageLength); + } + + auto NonceBuf = std::make_shared>(); + memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes); + + m_Inner->AsyncRead( + m_DecryptBuffer.data(), + MessageLength, + [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec, + size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + std::vector PlaintextBuf(static_cast(DataLength)); + const int32_t Decrypted = + m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength); + if (Decrypted == 0) + { + Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); + return; + } + + const size_t BytesToReturn = std::min(static_cast(Decrypted), Size); + memcpy(Dest, PlaintextBuf.data(), BytesToReturn); + + if (static_cast(Decrypted) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted); + } + + if (BytesToReturn < Size) + { + DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler)); + } + else + { + Handler(std::error_code{}, Size); + } + }); + }); } void -AesComputeTransport::Close() +AsyncAesComputeTransport::Close() { if (!m_IsClosed) { diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h index efcad9835..a1800c684 100644 --- a/src/zenhorde/hordetransportaes.h +++ b/src/zenhorde/hordetransportaes.h @@ -6,47 +6,53 @@ #include #include -#include #include +namespace asio { +class io_context; +} + namespace zen::horde { -/** AES-256-GCM encrypted transport wrapper. +/** Async AES-256-GCM encrypted transport wrapper. * - * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting - * all incoming data using AES-256-GCM. The nonce is mutated per message using - * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. 
+ * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming + * data using AES-256-GCM. The nonce is mutated per message using the Horde + * nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. * * Wire format per encrypted message: * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] * * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). + * + * Thread safety: all operations must be serialized by the caller (e.g. via a strand). */ -class AesComputeTransport final : public ComputeTransport +class AsyncAesComputeTransport final : public AsyncComputeTransport { public: - AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr InnerTransport); - ~AesComputeTransport() override; + AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr InnerTransport, + asio::io_context& IoContext); + ~AsyncAesComputeTransport() override; - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: - static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size - static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler); struct CryptoContext; - std::unique_ptr m_Crypto; - std::unique_ptr m_Inner; - std::vector m_EncryptBuffer; - std::vector m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv - size_t m_RemainingOffset = 0; - std::mutex m_Lock; - bool m_IsClosed = false; + std::unique_ptr m_Crypto; + std::unique_ptr m_Inner; + asio::io_context& m_IoContext; + std::vector m_EncryptBuffer; + 
std::vector m_DecryptBuffer; + std::vector m_RemainingData; + size_t m_RemainingOffset = 0; + bool m_IsClosed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h index 201d68b83..87caec019 100644 --- a/src/zenhorde/include/zenhorde/hordeclient.h +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -45,14 +45,15 @@ struct MachineInfo uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) bool IsWindows = false; std::string LeaseId; + std::string Pool; std::map Ports; /** Return the address to connect to, accounting for connection mode. */ - const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } /** Return the port to connect to, accounting for connection mode and port mapping. */ - uint16_t GetConnectionPort() const + [[nodiscard]] uint16_t GetConnectionPort() const { if (Mode == ConnectionMode::Relay) { @@ -65,7 +66,20 @@ struct MachineInfo return Port; } - bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } + /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */ + [[nodiscard]] std::pair GetZenServiceEndpoint(uint16_t DefaultPort) const + { + if (Mode == ConnectionMode::Relay) + { + if (auto It = Ports.find("ZenPort"); It != Ports.end()) + { + return {ConnectionAddress, It->second.Port}; + } + } + return {Ip, DefaultPort}; + } + + [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } }; /** Result of cluster auto-resolution via the Horde API. 
*/ @@ -83,31 +97,29 @@ struct ClusterInfo class HordeClient { public: - explicit HordeClient(const HordeConfig& Config); + explicit HordeClient(HordeConfig Config); ~HordeClient(); HordeClient(const HordeClient&) = delete; HordeClient& operator=(const HordeClient&) = delete; /** Initialize the underlying HTTP client. Must be called before other methods. */ - bool Initialize(); + [[nodiscard]] bool Initialize(); /** Build the JSON request body for cluster resolution and machine requests. * Encodes pool, condition, connection mode, encryption, and port requirements. */ - std::string BuildRequestBody() const; + [[nodiscard]] std::string BuildRequestBody() const; /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ - bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. * On success, populates OutMachine with connection details and credentials. 
*/ - bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); LoggerRef Log() { return m_Log; } private: - bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); - HordeConfig m_Config; std::unique_ptr m_Http; LoggerRef m_Log; diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h index dd70f9832..3a4dfb386 100644 --- a/src/zenhorde/include/zenhorde/hordeconfig.h +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -4,6 +4,10 @@ #include +#include + +#include +#include #include namespace zen::horde { @@ -33,20 +37,25 @@ struct HordeConfig static constexpr const char* ClusterDefault = "default"; static constexpr const char* ClusterAuto = "_auto"; - bool Enabled = false; ///< Whether Horde provisioning is active - std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") - std::string AuthToken; ///< Authentication token for the Horde API - std::string Pool; ///< Pool name to request machines from - std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve - std::string Condition; ///< Agent filter expression for machine selection - std::string HostAddress; ///< Address that provisioned agents use to connect back to us - std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload - uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication - - int MaxCores = 2048; - bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents - ConnectionMode Mode = ConnectionMode::Direct; - Encryption EncryptionMode = Encryption::None; + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. 
"https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API (static fallback) + + /// Optional token provider with automatic refresh (e.g. from OidcToken executable). + /// When set, takes priority over the static AuthToken string. + std::optional> AccessTokenProvider; + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min) + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; /** Validate the configuration. Returns false if the configuration is invalid * (e.g. Relay mode without AES encryption). */ diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h index 4e2e63bbd..1dd12936b 100644 --- a/src/zenhorde/include/zenhorde/hordeprovisioner.h +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -2,9 +2,12 @@ #pragma once +#include #include +#include #include +#include #include #include @@ -12,11 +15,18 @@ #include #include #include +#include +#include #include +namespace asio { +class io_context; +} + namespace zen::horde { class HordeClient; +class AsyncHordeAgent; /** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). 
*/ struct ProvisioningStats @@ -35,13 +45,12 @@ struct ProvisioningStats * binary, and executing it remotely. Each provisioned machine runs zenserver * in compute mode, which announces itself back to the orchestrator. * - * Spawns one thread per agent. Each thread handles the full lifecycle: - * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> - * channel setup -> binary upload -> remote execution -> poll until exit. + * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread + * pool rather than spawning a dedicated thread per agent. * * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. */ -class HordeProvisioner +class HordeProvisioner : public compute::IProvisionerStateProvider { public: /** Construct a provisioner. @@ -52,38 +61,48 @@ public: HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint); + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); - /** Signals all agent threads to exit and joins them. */ - ~HordeProvisioner(); + /** Signals all agents to exit and waits for completion. */ + ~HordeProvisioner() override; HordeProvisioner(const HordeProvisioner&) = delete; HordeProvisioner& operator=(const HordeProvisioner&) = delete; /** Set the target number of cores to provision. - * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the - * estimated core count is below the target. Also joins any finished - * agent threads. */ - void SetTargetCoreCount(uint32_t Count); + * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the + * estimated core count is below the target. Also removes finished agents. */ + void SetTargetCoreCount(uint32_t Count) override; /** Return a snapshot of the current provisioning counters. 
*/ ProvisioningStats GetStats() const; - uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } - uint32_t GetAgentCount() const; + // IProvisionerStateProvider + std::string_view GetName() const override { return "horde"; } + uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); } + uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); } + uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const override; + uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); } + compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override; private: LoggerRef Log() { return m_Log; } - struct AgentWrapper; - void RequestAgent(); - void ThreadAgent(AgentWrapper& Wrapper); + void ProvisionAgent(); + bool InitializeHordeClient(); HordeConfig m_Config; std::filesystem::path m_BinariesPath; std::filesystem::path m_WorkingDir; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr m_HordeClient; @@ -91,20 +110,43 @@ private: std::vector> m_Bundles; ///< (locator, bundleDir) pairs bool m_BundlesCreated = false; - mutable std::mutex m_AgentsLock; - std::vector> m_Agents; - std::atomic m_LastRequestFailTime{0}; std::atomic m_TargetCoreCount{0}; std::atomic m_EstimatedCoreCount{0}; std::atomic m_ActiveCoreCount{0}; std::atomic m_AgentsActive{0}; + std::atomic m_AgentsDraining{0}; std::atomic m_AgentsRequesting{0}; std::atomic m_AskForAgents{true}; + std::atomic m_PendingWorkItems{0}; + Event m_AllWorkDone; LoggerRef m_Log; + // Async I/O + std::unique_ptr m_IoContext; + std::vector m_IoThreads; + + struct AsyncAgentEntry + { + std::shared_ptr Agent; + std::string RemoteEndpoint; + std::string LeaseId; + uint16_t CoreCount = 0; + bool Draining = false; + }; + + mutable std::mutex m_AsyncAgentsLock; + std::vector 
m_AsyncAgents; + mutable std::unordered_set m_RecentlyDrainedWorkerIds; ///< Worker IDs of agents that completed after draining + + void OnAsyncAgentDone(std::shared_ptr Agent); + void DrainAsyncAgent(AsyncAgentEntry& Entry); + + std::vector BuildAgentArgs(const MachineInfo& Machine) const; + static constexpr uint32_t EstimatedCoresPerAgent = 32; + static constexpr int IoThreadCount = 3; }; } // namespace zen::horde diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index 446dd80be..56b9c39c5 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -228,6 +228,13 @@ CurlHttpClient::Session::Perform() curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); Result.DownloadedBytes = static_cast(DownBytes); + char* EffectiveUrl = nullptr; + curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl); + if (EffectiveUrl) + { + Result.Url = EffectiveUrl; + } + return Result; } @@ -294,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'", SessionId, + Result.Url, static_cast(Result.ErrorCode), Result.ErrorMessage); } diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index bdeb46633..ea9193e65 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -73,6 +73,7 @@ private: int64_t DownloadedBytes = 0; CURLcode ErrorCode = CURLE_OK; std::string ErrorMessage; + std::string Url; }; struct Session diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index c42841922..0432e50ef 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -94,7 +94,8 @@ namespace 
zen { namespace httpclientauth { std::string_view CloudHost, bool Unattended, bool Quiet, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { Stopwatch Timer; @@ -117,8 +118,9 @@ namespace zen { namespace httpclientauth { } }); - const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}", + const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}", OidcExecutablePath, + IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl", CloudHost, AuthTokenPath, Unattended ? "true"sv : "false"sv); @@ -193,7 +195,7 @@ namespace zen { namespace httpclientauth { } else { - ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode); + ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode); } return HttpClientAccessToken{}; } @@ -202,9 +204,10 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { - HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); if (InitialToken.IsValid()) { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), @@ -212,12 +215,13 @@ namespace zen { namespace httpclientauth { Token = InitialToken, Quiet, Unattended, - Hidden]() mutable { + Hidden, + IsHordeUrl]() mutable { if (!Token.NeedsRefresh()) { return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); }; } return {}; diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index ce646ebd7..9220a50b6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ 
b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -33,7 +33,8 @@ namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden); + bool Hidden, + bool IsHordeUrl = false); } // namespace httpclientauth } // namespace zen diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h index 0a3411ace..cebf217e1 100644 --- a/src/zennomad/include/zennomad/nomadclient.h +++ b/src/zennomad/include/zennomad/nomadclient.h @@ -52,7 +52,11 @@ public: /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint. * The JSON structure varies based on the configured driver and distribution mode. */ - std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const; + std::string BuildJobJson(const std::string& JobId, + const std::string& OrchestratorEndpoint, + const std::string& CoordinatorSession = {}, + bool CleanStart = false, + const std::string& TraceHost = {}) const; /** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */ bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob); diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h index 750693b3f..a8368e3dc 100644 --- a/src/zennomad/include/zennomad/nomadprovisioner.h +++ b/src/zennomad/include/zennomad/nomadprovisioner.h @@ -47,7 +47,11 @@ public: /** Construct a provisioner. * @param Config Nomad connection and job configuration. * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */ - NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint); + NomadProvisioner(const NomadConfig& Config, + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); /** Signals the management thread to exit and stops all tracked jobs. 
*/ ~NomadProvisioner(); @@ -83,6 +87,9 @@ private: NomadConfig m_Config; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr m_Client; diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp index 9edcde125..4bb09a930 100644 --- a/src/zennomad/nomadclient.cpp +++ b/src/zennomad/nomadclient.cpp @@ -58,7 +58,11 @@ NomadClient::Initialize() } std::string -NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const +NomadClient::BuildJobJson(const std::string& JobId, + const std::string& OrchestratorEndpoint, + const std::string& CoordinatorSession, + bool CleanStart, + const std::string& TraceHost) const { ZEN_TRACE_CPU("NomadClient::BuildJobJson"); @@ -94,6 +98,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra IdArg << "--instance-id=nomad-" << JobId; Args.push_back(std::string(IdArg.ToView())); } + if (!CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << CoordinatorSession; + Args.push_back(std::string(SessionArg.ToView())); + } + if (CleanStart) + { + Args.push_back("--clean"); + } + if (!TraceHost.empty()) + { + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << TraceHost; + Args.push_back(std::string(TraceArg.ToView())); + } TaskConfig["args"] = Args; } else @@ -115,6 +135,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra IdArg << "--instance-id=nomad-" << JobId; Args.push_back(std::string(IdArg.ToView())); } + if (!CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << CoordinatorSession; + Args.push_back(std::string(SessionArg.ToView())); + } + if (CleanStart) + { + Args.push_back("--clean"); + } + if (!TraceHost.empty()) + { + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" 
<< TraceHost; + Args.push_back(std::string(TraceArg.ToView())); + } TaskConfig["args"] = Args; } diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp index 3fe9c0ac3..e07ce155e 100644 --- a/src/zennomad/nomadprovisioner.cpp +++ b/src/zennomad/nomadprovisioner.cpp @@ -14,9 +14,16 @@ namespace zen::nomad { -NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint) +NomadProvisioner::NomadProvisioner(const NomadConfig& Config, + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_ProcessId(static_cast(zen::GetCurrentProcessId())) , m_Log(zen::logging::Get("nomad.provisioner")) { @@ -154,7 +161,7 @@ NomadProvisioner::SubmitNewJobs() ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load()); - const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint); + const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint, m_CoordinatorSession, m_CleanStart, m_TraceHost); NomadJobInfo JobInfo; JobInfo.Id = JobId; diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 7296098e0..8cd8b4cfe 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -22,6 +22,8 @@ # if ZEN_WITH_HORDE # include # include +# include +# include # endif # if ZEN_WITH_NOMAD # include @@ -65,6 +67,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.CoordinatorEndpoint)->default_value(""), ""); + Options.add_option("compute", + "", + "coordinator-session", + "Session ID of the orchestrator (for stale-instance rejection)", + 
cxxopts::value(m_ServerOptions.CoordinatorSession)->default_value(""), + ""); + + Options.add_option("compute", + "", + "announce-url", + "Override URL announced to the coordinator (e.g. relay-visible endpoint)", + cxxopts::value(m_ServerOptions.AnnounceUrl)->default_value(""), + ""); + Options.add_option("compute", "", "idms", @@ -79,6 +95,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"), ""); + Options.add_option("compute", + "", + "provision-clean", + "Pass --clean to provisioned worker instances so they wipe state on startup", + cxxopts::value(m_ServerOptions.ProvisionClean)->default_value("false"), + ""); + + Options.add_option("compute", + "", + "provision-tracehost", + "Pass --tracehost to provisioned worker instances for remote trace collection", + cxxopts::value(m_ServerOptions.ProvisionTraceHost)->default_value(""), + ""); + # if ZEN_WITH_HORDE // Horde provisioning options Options.add_option("horde", @@ -137,6 +167,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HordeConfig.MaxCores)->default_value("2048"), ""); + Options.add_option("horde", + "", + "horde-drain-grace-period", + "Grace period in seconds for draining agents before force-kill", + cxxopts::value(m_ServerOptions.HordeConfig.DrainGracePeriodSeconds)->default_value("300"), + ""); + Options.add_option("horde", "", "horde-host", @@ -164,6 +201,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Port number for Zen service communication", cxxopts::value(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"), ""); + + Options.add_option("horde", + "", + "horde-oidctoken-exe-path", + "Path to OidcToken executable for automatic Horde authentication", + cxxopts::value(m_HordeOidcTokenExePath)->default_value(""), + ""); # endif # if ZEN_WITH_NOMAD @@ -313,6 +357,30 @@ 
ZenComputeServerConfigurator::ValidateOptions() # if ZEN_WITH_HORDE horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr); horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr); + + // Set up OidcToken-based authentication if no static token was provided + if (m_ServerOptions.HordeConfig.AuthToken.empty() && !m_ServerOptions.HordeConfig.ServerUrl.empty()) + { + std::filesystem::path OidcExePath = FindOidcTokenExePath(m_HordeOidcTokenExePath); + if (!OidcExePath.empty()) + { + ZEN_INFO("using OidcToken executable for Horde authentication: {}", OidcExePath); + auto Provider = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath, + m_ServerOptions.HordeConfig.ServerUrl, + /*Quiet=*/true, + /*Unattended=*/false, + /*Hidden=*/true, + /*IsHordeUrl=*/true); + if (Provider) + { + m_ServerOptions.HordeConfig.AccessTokenProvider = std::move(*Provider); + } + else + { + ZEN_WARN("OidcToken authentication failed; Horde requests will be unauthenticated"); + } + } + } # endif # if ZEN_WITH_NOMAD @@ -347,6 +415,8 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ } m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_CoordinatorSession = ServerConfig.CoordinatorSession; + m_AnnounceUrl = ServerConfig.AnnounceUrl; m_InstanceId = ServerConfig.InstanceId; m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; @@ -379,7 +449,14 @@ ZenComputeServer::Cleanup() m_AnnounceTimer.cancel(); # if ZEN_WITH_HORDE - // Shut down Horde provisioner first — this signals all agent threads + // Disconnect the provisioner state provider before destroying the + // provisioner so the orchestrator HTTP layer cannot call into it. + if (m_OrchestratorService) + { + m_OrchestratorService->SetProvisionerStateProvider(nullptr); + } + + // Shut down Horde provisioner — this signals all agent threads // to exit and joins them before we tear down HTTP services. 
m_HordeProvisioner.reset(); # endif @@ -482,6 +559,7 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) m_StatsService, ServerConfig.DataDir / "functions", ServerConfig.MaxConcurrentActions); + m_ComputeService->SetShutdownCallback([this] { RequestExit(0); }); m_FrontendService = std::make_unique(m_ContentRoot, m_StatsService, m_StatusService); @@ -506,7 +584,11 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) OrchestratorEndpoint << '/'; } - m_NomadProvisioner = std::make_unique(NomadCfg, OrchestratorEndpoint); + m_NomadProvisioner = std::make_unique(NomadCfg, + OrchestratorEndpoint, + m_OrchestratorService->GetSessionId().ToString(), + ServerConfig.ProvisionClean, + ServerConfig.ProvisionTraceHost); } } # endif @@ -537,7 +619,14 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) : std::filesystem::path(HordeConfig.BinariesPath); std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde"; - m_HordeProvisioner = std::make_unique(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint); + m_HordeProvisioner = std::make_unique(HordeConfig, + BinariesPath, + WorkingDir, + OrchestratorEndpoint, + m_OrchestratorService->GetSessionId().ToString(), + ServerConfig.ProvisionClean, + ServerConfig.ProvisionTraceHost); + m_OrchestratorService->SetProvisionerStateProvider(m_HordeProvisioner.get()); } } # endif @@ -565,6 +654,10 @@ ZenComputeServer::GetInstanceId() const std::string ZenComputeServer::GetAnnounceUrl() const { + if (!m_AnnounceUrl.empty()) + { + return m_AnnounceUrl; + } return m_Http->GetServiceUri(nullptr); } @@ -635,6 +728,11 @@ ZenComputeServer::BuildAnnounceBody() << "nomad"; } + if (!m_CoordinatorSession.empty()) + { + AnnounceBody << "coordinator_session" << m_CoordinatorSession; + } + ResolveCloudMetadata(); if (m_CloudMetadata) { @@ -781,8 +879,10 @@ ZenComputeServer::ProvisionerMaintenanceTick() # if ZEN_WITH_HORDE if (m_HordeProvisioner) { - 
m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + // Re-apply current target to spawn agent threads for any that have + // exited since the last tick, without overwriting a user-set target. auto Stats = m_HordeProvisioner->GetStats(); + m_HordeProvisioner->SetTargetCoreCount(Stats.TargetCoreCount); ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}", Stats.TargetCoreCount, Stats.EstimatedCoreCount, diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 38f93bc36..63db7e9b3 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -49,9 +49,13 @@ struct ZenComputeServerConfig : public ZenServerConfig std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications std::string CoordinatorEndpoint; + std::string CoordinatorSession; ///< Session ID for stale-instance rejection + std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint) std::string IdmsEndpoint; int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link + bool ProvisionClean = false; // Pass --clean to provisioned workers + std::string ProvisionTraceHost; // Pass --tracehost to provisioned workers # if ZEN_WITH_HORDE horde::HordeConfig HordeConfig; @@ -84,6 +88,7 @@ private: # if ZEN_WITH_HORDE std::string m_HordeModeStr = "direct"; std::string m_HordeEncryptionStr = "none"; + std::string m_HordeOidcTokenExePath; # endif # if ZEN_WITH_NOMAD @@ -147,6 +152,8 @@ private: # endif SystemMetricsTracker m_MetricsTracker; std::string m_CoordinatorEndpoint; + std::string m_CoordinatorSession; + std::string m_AnnounceUrl; std::string m_InstanceId; asio::steady_timer m_AnnounceTimer{m_IoContext}; diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index daad154bc..6449159fd 100644 --- a/src/zenserver/config/config.cpp +++ 
b/src/zenserver/config/config.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -478,15 +479,27 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir)); } - ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); - ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); - ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); - ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); - ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + SystemRootDir = ExpandEnvironmentVariables(SystemRootDir); + ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); + + DataDir = ExpandEnvironmentVariables(DataDir); + ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); + + ContentDir = ExpandEnvironmentVariables(ContentDir); + ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); + + ConfigFile = ExpandEnvironmentVariables(ConfigFile); + ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); + + BaseSnapshotDir = ExpandEnvironmentVariables(BaseSnapshotDir); + ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + + SecurityConfigPath = ExpandEnvironmentVariables(SecurityConfigPath); ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); if (!UnixSocketPath.empty()) { + UnixSocketPath = ExpandEnvironmentVariables(UnixSocketPath); ServerOptions.HttpConfig.UnixSocketPath = MakeSafeAbsolutePath(UnixSocketPath); } diff --git a/src/zenserver/frontend/html/compute/compute.html deleted file mode 100644 index c07bbb692..000000000 --- a/src/zenserver/frontend/html/compute/compute.html +++ /dev/null @@ -1,925 +0,0 @@ - - - - - - Zen Compute Dashboard - - - - - - - - - -
- - - Home - Node - Orchestrator - -
Last updated: Never
- -
- - -
Action Queue
-
-
-
Pending Actions
-
-
-
Waiting to be scheduled
-
-
-
Running Actions
-
-
-
Currently executing
-
-
-
Completed Actions
-
-
-
Results available
-
-
- - -
-
Action Queue History
-
- -
-
- - -
Performance Metrics
-
-
Completion Rate
-
-
-
-
-
1 min rate
-
-
-
-
-
5 min rate
-
-
-
-
-
15 min rate
-
-
-
-
- Total Retired - - -
-
- Mean Rate - - -
-
-
- - -
Workers
-
-
Worker Status
-
- Registered Workers - - -
- -
- - -
Queues
-
-
Queue Status
-
No queues.
- -
- - -
Recent Actions
-
-
Action History
-
No actions recorded yet.
- -
- - -
System Resources
-
-
-
CPU Usage
-
-
-
Percent
-
-
-
-
- -
-
-
- Packages - - -
-
- Physical Cores - - -
-
- Logical Processors - - -
-
-
-
-
Memory
-
- Used - - -
-
- Total - - -
-
-
-
-
-
-
Disk
-
- Used - - -
-
- Total - - -
-
-
-
-
-
-
- - - - diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html index 9597fd7f3..aaa09aec0 100644 --- a/src/zenserver/frontend/html/compute/index.html +++ b/src/zenserver/frontend/html/compute/index.html @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html deleted file mode 100644 index d1a2bb015..000000000 --- a/src/zenserver/frontend/html/compute/orchestrator.html +++ /dev/null @@ -1,669 +0,0 @@ - - - - - - - - - - - Zen Orchestrator Dashboard - - - -
- - - Home - Node - Orchestrator - -
-
-
Last updated: Never
-
-
- Agents: - - -
-
- -
- -
-
Compute Agents
-
No agents registered.
- - - - - - - - - - - - - - - - - - -
-
-
Connected Clients
-
No clients connected.
- - - - - - - - - - - - -
-
-
-
Event History
-
- - -
-
-
-
No provisioning events recorded.
- - - - - - - - - - - -
- -
-
- - - - diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js index a280fabdb..30f6a8122 100644 --- a/src/zenserver/frontend/html/pages/orchestrator.js +++ b/src/zenserver/frontend/html/pages/orchestrator.js @@ -14,6 +14,14 @@ export class Page extends ZenPage { this.set_title("orchestrator"); + // Provisioner section (hidden until data arrives) + this._prov_section = this._collapsible_section("Provisioner"); + this._prov_section._parent.inner().style.display = "none"; + this._prov_grid = null; + this._prov_target_dirty = false; + this._prov_commit_timer = null; + this._prov_last_target = null; + // Agents section const agents_section = this._collapsible_section("Compute Agents"); this._agents_host = agents_section; @@ -50,11 +58,12 @@ export class Page extends ZenPage { try { - const [agents, history, clients, client_history] = await Promise.all([ + const [agents, history, clients, client_history, prov] = await Promise.all([ new Fetcher().resource("/orch/agents").json(), new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null), new Fetcher().resource("/orch/clients").json().catch(() => null), new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null), + new Fetcher().resource("/orch/provisioner/status").json().catch(() => null), ]); this._render_agents(agents); @@ -70,6 +79,7 @@ export class Page extends ZenPage { this._render_client_history(client_history.client_events || []); } + this._render_provisioner(prov); } catch (e) { /* service unavailable */ } } @@ -109,6 +119,7 @@ export class Page extends ZenPage { this._render_client_history(data.client_events); } + this._render_provisioner(data.provisioner); } catch (e) { /* ignore parse errors */ } }; @@ -156,7 +167,7 @@ export class Page extends ZenPage return; } - let totalCpus = 0, totalWeightedCpu = 0; + let totalCpus = 0, activeCpus = 0, totalWeightedCpu = 0; let totalMemUsed = 0, 
totalMemTotal = 0; let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0; let totalRecv = 0, totalSent = 0; @@ -173,8 +184,14 @@ export class Page extends ZenPage const completed = w.actions_completed || 0; const recv = w.bytes_received || 0; const sent = w.bytes_sent || 0; + const provisioner = w.provisioner || ""; + const isProvisioned = provisioner !== ""; totalCpus += cpus; + if (w.provisioner_status === "active") + { + activeCpus += cpus; + } if (cpus > 0 && typeof cpuUsage === "number") { totalWeightedCpu += cpuUsage * cpus; @@ -209,12 +226,49 @@ export class Page extends ZenPage cell.inner().textContent = ""; cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank"); } + + // Visual treatment based on provisioner status + const provStatus = w.provisioner_status || ""; + if (!isProvisioned) + { + row.inner().style.opacity = "0.45"; + } + else + { + const hostCell = row.get_cell(0); + const el = hostCell.inner(); + const badge = document.createElement("span"); + const badgeBase = "display:inline-block;margin-left:6px;padding:1px 5px;border-radius:8px;" + + "font-size:9px;font-weight:600;color:#fff;vertical-align:middle;"; + + if (provStatus === "draining") + { + badge.textContent = "draining"; + badge.style.cssText = badgeBase + "background:var(--theme_warn);"; + row.inner().style.opacity = "0.6"; + } + else if (provStatus === "active") + { + badge.textContent = provisioner; + badge.style.cssText = badgeBase + "background:#8957e5;"; + } + else + { + badge.textContent = "deallocated"; + badge.style.cssText = badgeBase + "background:var(--theme_fail);"; + row.inner().style.opacity = "0.45"; + } + el.appendChild(badge); + } } - // Total row + // Total row — show active / total in CPUs column + const cpuLabel = activeCpus < totalCpus + ? 
Friendly.sep(activeCpus) + " / " + Friendly.sep(totalCpus) + : Friendly.sep(totalCpus); const total = this._agents_table.add_row( "TOTAL", - Friendly.sep(totalCpus), + cpuLabel, "", totalMemTotal > 0 ? Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-", Friendly.sep(totalQueues), @@ -305,6 +359,154 @@ export class Page extends ZenPage } } + _render_provisioner(prov) + { + const container = this._prov_section._parent.inner(); + + if (!prov || !prov.name) + { + container.style.display = "none"; + return; + } + container.style.display = ""; + + if (!this._prov_grid) + { + this._prov_grid = this._prov_section.tag().classify("grid").classify("stats-tiles"); + this._prov_tiles = {}; + + // Target cores tile with editable input + const target_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + target_tile.tag().classify("card-title").text("Target Cores"); + const target_body = target_tile.tag().classify("tile-metrics"); + const target_m = target_body.tag().classify("tile-metric").classify("tile-metric-hero"); + const input = document.createElement("input"); + input.type = "number"; + input.min = "0"; + input.style.cssText = "width:100px;padding:4px 8px;border:1px solid var(--theme_g2);border-radius:4px;" + + "background:var(--theme_g4);color:var(--theme_bright);font-size:20px;font-weight:600;text-align:right;"; + target_m.inner().appendChild(input); + target_m.tag().classify("metric-label").text("target"); + this._prov_tiles.target_input = input; + + input.addEventListener("focus", () => { this._prov_target_dirty = true; }); + input.addEventListener("input", () => { + this._prov_target_dirty = true; + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._prov_commit_timer = setTimeout(() => this._commit_provisioner_target(), 800); + }); + input.addEventListener("keydown", (e) => { + if (e.key === "Enter") + { + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + 
this._commit_provisioner_target(); + input.blur(); + } + }); + input.addEventListener("blur", () => { + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._commit_provisioner_target(); + }); + + // Active cores + const active_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + active_tile.tag().classify("card-title").text("Active Cores"); + const active_body = active_tile.tag().classify("tile-metrics"); + this._prov_tiles.active = active_body; + + // Estimated cores + const est_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + est_tile.tag().classify("card-title").text("Estimated Cores"); + const est_body = est_tile.tag().classify("tile-metrics"); + this._prov_tiles.estimated = est_body; + + // Agents + const agents_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + agents_tile.tag().classify("card-title").text("Agents"); + const agents_body = agents_tile.tag().classify("tile-metrics"); + this._prov_tiles.agents = agents_body; + + // Draining + const drain_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + drain_tile.tag().classify("card-title").text("Draining"); + const drain_body = drain_tile.tag().classify("tile-metrics"); + this._prov_tiles.draining = drain_body; + } + + // Update values + const input = this._prov_tiles.target_input; + if (!this._prov_target_dirty && document.activeElement !== input) + { + input.value = prov.target_cores; + } + this._prov_last_target = prov.target_cores; + + // Re-render metric tiles (clear and recreate content) + for (const key of ["active", "estimated", "agents", "draining"]) + { + this._prov_tiles[key].inner().innerHTML = ""; + } + this._metric(this._prov_tiles.active, Friendly.sep(prov.active_cores), "cores", true); + this._metric(this._prov_tiles.estimated, Friendly.sep(prov.estimated_cores), "cores", true); + this._metric(this._prov_tiles.agents, Friendly.sep(prov.agents), "active", true); + 
this._metric(this._prov_tiles.draining, Friendly.sep(prov.agents_draining || 0), "agents", true); + } + + async _commit_provisioner_target() + { + const input = this._prov_tiles?.target_input; + if (!input || this._prov_committing) + { + return; + } + const value = parseInt(input.value, 10); + if (isNaN(value) || value < 0) + { + return; + } + if (value === this._prov_last_target) + { + this._prov_target_dirty = false; + return; + } + this._prov_committing = true; + try + { + const resp = await fetch("/orch/provisioner/target", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ target_cores: value }), + }); + if (resp.ok) + { + this._prov_target_dirty = false; + console.log("Target cores set to", value); + } + else + { + const text = await resp.text(); + console.error("Failed to set target cores: HTTP", resp.status, text); + } + } + catch (e) + { + console.error("Failed to set target cores:", e); + } + finally + { + this._prov_committing = false; + } + } + _metric(parent, value, label, hero = false) { const m = parent.tag().classify("tile-metric"); diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index c5f8724ca..108685eb9 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -169,7 +168,12 @@ AppMain(int argc, char* argv[]) if (IsDir(ServerOptions.DataDir)) { ZEN_CONSOLE_INFO("Deleting files from '{}' ({})", ServerOptions.DataDir, DeleteReason); - DeleteDirectories(ServerOptions.DataDir); + std::error_code Ec; + DeleteDirectories(ServerOptions.DataDir, Ec); + if (Ec) + { + ZEN_WARN("could not fully clean '{}': {} (continuing anyway)", ServerOptions.DataDir, Ec.message()); + } } } -- cgit v1.2.3