diff options
| author | Stefan Boberg <[email protected]> | 2026-04-13 16:38:58 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-04-13 16:38:58 +0200 |
| commit | f387a069967e960305cc189827093111eb5b82e7 (patch) | |
| tree | 1bb9b5c79d87ba64a8a10c23958dfa98769950ba /src | |
| parent | Merge branch 'main' into sb/tourist (diff) | |
| parent | Compute OIDC auth, async Horde agents, and orchestrator improvements (#913) (diff) | |
| download | zen-sb/tourist.tar.xz zen-sb/tourist.zip | |
Merge branch 'main' into sb/touristsb/tourist
Diffstat (limited to 'src')
63 files changed, 3985 insertions, 4159 deletions
diff --git a/src/zen/cmds/compute_cmd.cpp b/src/zen/cmds/compute_cmd.cpp new file mode 100644 index 000000000..01166cb0e --- /dev/null +++ b/src/zen/cmds/compute_cmd.cpp @@ -0,0 +1,96 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "compute_cmd.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/logging.h> +# include <zenhttp/httpclient.h> + +using namespace std::literals; + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// ComputeRecordStartSubCmd + +ComputeRecordStartSubCmd::ComputeRecordStartSubCmd() : ZenSubCmdBase("record-start", "Start recording compute actions") +{ + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); +} + +void +ComputeRecordStartSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", SubOptions().help()); + } + + HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName); + if (HttpClient::Response Response = Http.Post("/compute/record/start"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{})) + { + CbObject Obj = Response.AsObject(); + std::string_view Path = Obj["path"sv].AsString(); + ZEN_CONSOLE("recording started: " ZEN_BRIGHT_GREEN("{}"), Path); + } + else + { + Response.ThrowError("Failed to start recording"); + } +} + +////////////////////////////////////////////////////////////////////////// +// ComputeRecordStopSubCmd + +ComputeRecordStopSubCmd::ComputeRecordStopSubCmd() : ZenSubCmdBase("record-stop", "Stop recording compute actions") +{ + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); +} + +void +ComputeRecordStopSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", SubOptions().help()); + } + + HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName); + if (HttpClient::Response Response = Http.Post("/compute/record/stop"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{})) + { + CbObject Obj = Response.AsObject(); + std::string_view Path = Obj["path"sv].AsString(); + ZEN_CONSOLE("recording stopped: " ZEN_BRIGHT_GREEN("{}"), Path); + } + else + { + Response.ThrowError("Failed to stop recording"); + } +} + +////////////////////////////////////////////////////////////////////////// +// ComputeCommand + +ComputeCommand::ComputeCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), ""); + m_Options.parse_positional({"subcommand"}); + + AddSubCommand(m_RecordStartSubCmd); + AddSubCommand(m_RecordStopSubCmd); +} + +ComputeCommand::~ComputeCommand() = default; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/compute_cmd.h b/src/zen/cmds/compute_cmd.h new file mode 100644 index 000000000..b26f639c4 --- /dev/null +++ b/src/zen/cmds/compute_cmd.h @@ -0,0 +1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +#include <string> + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen { + +class ComputeRecordStartSubCmd : public ZenSubCmdBase +{ +public: + ComputeRecordStartSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + std::string m_HostName; +}; + +class ComputeRecordStopSubCmd : public ZenSubCmdBase +{ +public: + ComputeRecordStopSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + std::string m_HostName; +}; + +class ComputeCommand : public ZenCmdWithSubCommands +{ +public: + static constexpr char Name[] = "compute"; + static constexpr char Description[] = "Compute service operations"; + + ComputeCommand(); + ~ComputeCommand(); + + cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{Name, Description}; + std::string m_SubCommand; + ComputeRecordStartSubCmd m_RecordStartSubCmd; + ComputeRecordStopSubCmd m_RecordStopSubCmd; +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 9719fce77..89bbf3638 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -23,6 +23,8 @@ #include <zenhttp/httpclient.h> #include <zenhttp/packageformat.h> +#include "../progressbar.h" + #include <EASTL/hash_map.h> #include <EASTL/hash_set.h> #include <EASTL/map.h> @@ -124,13 +126,14 @@ struct ExecSessionConfig std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce std::string_view OrchestratorUrl; const std::filesystem::path& OutputPath; - int Offset = 0; - int Stride = 1; - int Limit = 0; - bool Verbose = false; - bool Quiet = false; - bool DumpActions = false; - bool Binary = false; + int Offset = 0; + int Stride = 1; + int Limit = 0; + bool Verbose = false; + bool Quiet = false; + bool DumpActions = false; + bool Binary = false; + ProgressBar::Mode ProgressMode = ProgressBar::Mode::PrettyScroll; }; ////////////////////////////////////////////////////////////////////////// @@ -345,8 +348,6 @@ ExecSessionRunner::DrainCompletedJobs() } m_PendingJobs.Remove(CompleteLsn); - - ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize()); } } } @@ -897,17 +898,20 @@ ExecSessionRunner::Run() // Then submit work items - int FailedWorkCounter = 0; - size_t RemainingWorkItems = m_Config.RecordingReader.GetActionCount(); - int SubmittedWorkItems = 0; + std::atomic<int> FailedWorkCounter{0}; + std::atomic<size_t> RemainingWorkItems{m_Config.RecordingReader.GetActionCount()}; + std::atomic<int> SubmittedWorkItems{0}; + size_t TotalWorkItems = RemainingWorkItems.load(); - ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + ProgressBar SubmitProgress(m_Config.ProgressMode, "Submit"); + SubmitProgress.UpdateState({.Task = "Submitting work items", .TotalCount = TotalWorkItems, .RemainingCount = RemainingWorkItems.load()}, + false); int OffsetCounter = m_Config.Offset; int StrideCounter = m_Config.Stride; auto ShouldSchedule = [&]() -> bool { - if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit) + if (m_Config.Limit && SubmittedWorkItems.load() >= m_Config.Limit) { // Limit reached, ignore @@ -1005,17 +1009,14 @@ ExecSessionRunner::Run() { const int32_t LsnField = EnqueueResult.Lsn; - --RemainingWorkItems; - ++SubmittedWorkItems; + size_t Remaining = --RemainingWorkItems; + int Submitted = ++SubmittedWorkItems; - if (!m_Config.Quiet) - { - ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", - SubmittedWorkItems, - LsnField, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - RemainingWorkItems); - } + SubmitProgress.UpdateState({.Task = "Submitting work items", + .Details = fmt::format("#{} LSN {}", Submitted, LsnField), + .TotalCount = TotalWorkItems, + .RemainingCount = Remaining}, + false); if (!m_Config.OutputPath.empty()) { @@ -1055,22 +1056,36 @@ ExecSessionRunner::Run() }, TargetParallelism); + SubmitProgress.Finish(); + // Wait until all pending work is complete + size_t TotalPendingJobs = m_PendingJobs.GetSize(); + + ProgressBar CompletionProgress(m_Config.ProgressMode, "Execute"); + while (!m_PendingJobs.IsEmpty()) { - // TODO: improve this logic - zen::Sleep(500); + size_t PendingCount = m_PendingJobs.GetSize(); + CompletionProgress.UpdateState({.Task = "Executing work items", + .Details = fmt::format("{} completed, {} remaining", TotalPendingJobs - PendingCount, PendingCount), + .TotalCount = TotalPendingJobs, + .RemainingCount = PendingCount}, + false); + + zen::Sleep(GetUpdateDelayMS(m_Config.ProgressMode)); DrainCompletedJobs(); SendOrchestratorHeartbeat(); } + CompletionProgress.Finish(); + // Write summary files WriteSummaryFiles(); - if (FailedWorkCounter) + if (FailedWorkCounter.load()) { return 1; } @@ -1423,6 +1438,16 @@ ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) int ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl) { + ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; + if (m_VerboseLogging) + { + ProgressMode = ProgressBar::Mode::Plain; + } + else if (m_QuietLogging) + { + ProgressMode = ProgressBar::Mode::Quiet; + } + ExecSessionConfig Config{ .Resolver = *m_ChunkResolver, .RecordingReader = *m_RecordingReader, @@ -1437,6 +1462,7 @@ ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std .Quiet = m_QuietLogging, .DumpActions = m_DumpActions, .Binary = m_Binary, + .ProgressMode = ProgressMode, }; ExecSessionRunner Runner(ComputeSession, Config); diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 3277eb856..bbf6b4f8a 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -9,6 +9,7 @@ #include "cmds/bench_cmd.h" #include "cmds/builds_cmd.h" #include "cmds/cache_cmd.h" +#include "cmds/compute_cmd.h" #include "cmds/copy_cmd.h" #include "cmds/dedup_cmd.h" #include "cmds/exec_cmd.h" @@ -588,7 +589,8 @@ main(int argc, char** argv) DropCommand DropCmd; DropProjectCommand ProjectDropCmd; #if ZEN_WITH_COMPUTE_SERVICES - ExecCommand ExecCmd; + ComputeCommand ComputeCmd; + ExecCommand ExecCmd; #endif // ZEN_WITH_COMPUTE_SERVICES ExportOplogCommand ExportOplogCmd; FlushCommand FlushCmd; @@ -649,6 +651,7 @@ main(int argc, char** argv) {DownCommand::Name, &DownCmd, DownCommand::Description}, {DropCommand::Name, &DropCmd, DropCommand::Description}, #if ZEN_WITH_COMPUTE_SERVICES + {ComputeCommand::Name, &ComputeCmd, ComputeCommand::Description}, {ExecCommand::Name, &ExecCmd, ExecCommand::Description}, #endif {GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description}, diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index 750879d5a..bb574edc2 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -218,6 +218,13 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han **Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock. +**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks: +1. `m_ActionMapLock` (session action maps) +2. `QueueEntry::m_Lock` (per-queue state) +3. `m_ActionHistoryLock` (action history ring) + +Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`). + **Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. **Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index aaf34cbe2..852e93fa0 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -121,6 +121,8 @@ struct ComputeServiceSession::Impl , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) { + m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool); + // Create a non-expiring, non-deletable implicit queue for legacy endpoints auto Result = CreateQueue("implicit"sv, {}, {}); m_ImplicitQueueId = Result.QueueId; @@ -240,8 +242,9 @@ struct ComputeServiceSession::Impl // Recording - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; std::unique_ptr<ActionRecorder> m_Recorder; @@ -615,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() m_KnownWorkerUris.insert(UriStr); auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + NewRunner->SetRemoteHostname(Hostname); SyncWorkersToRunner(*NewRunner); m_RemoteRunnerGroup.AddRunner(NewRunner); } @@ -716,24 +720,44 @@ ComputeServiceSession::Impl::ShutdownRunners() m_RemoteRunnerGroup.Shutdown(); } -void +bool ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) { + if (m_Recorder) + { + ZEN_WARN("recording is already active"); + return false; + } + ZEN_INFO("starting recording to '{}'", RecordingPath); m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath); ZEN_INFO("started recording to '{}'", RecordingPath); + return true; } -void +bool ComputeServiceSession::Impl::StopRecording() { + if (!m_Recorder) + { + ZEN_WARN("no recording is active"); + return false; + } + ZEN_INFO("stopping recording"); m_Recorder = nullptr; ZEN_INFO("stopped recording"); + return true; +} + +bool +ComputeServiceSession::Impl::IsRecording() const +{ + return m_Recorder != nullptr; } std::vector<ComputeServiceSession::RunningActionInfo> @@ -1128,6 +1152,10 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) Cbo.BeginObject(); Cbo << "lsn"sv << Lsn; Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + if (!Action->FailureReason.empty()) + { + Cbo << "reason"sv << Action->FailureReason; + } Cbo.EndObject(); } }); @@ -1416,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) if (Queue) { - Queue->m_Lock.WithSharedLock([&] { - m_ActionMapLock.WithSharedLock([&] { + m_ActionMapLock.WithSharedLock([&] { + Queue->m_Lock.WithSharedLock([&] { for (int Lsn : Queue->FinishedLsns) { if (m_ResultsMap.contains(Lsn)) @@ -1530,12 +1558,12 @@ ComputeServiceSession::Impl::SchedulePendingActions() static Stopwatch DumpRunningTimer; auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); + ZEN_DEBUG("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); if (DumpRunningTimer.GetElapsedTimeMs() > 30000) { @@ -1584,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions() // Also note that the m_PendingActions list is not maintained // here, that's done periodically in SchedulePendingActions() - m_ActionMapLock.WithExclusiveLock([&] { - if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) - { - return; - } + // Extract pending actions under a shared lock — we only need to read + // the map and take Ref copies. ActionState() is atomic so this is safe. + // Sorting and capacity trimming happen outside the lock to avoid + // blocking HTTP handlers on O(N log N) work with large pending queues. - if (m_PendingActions.empty()) + m_ActionMapLock.WithSharedLock([&] { + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) { return; } @@ -1610,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions() case RunnerAction::State::Completed: case RunnerAction::State::Failed: case RunnerAction::State::Abandoned: + case RunnerAction::State::Rejected: case RunnerAction::State::Cancelled: break; @@ -1620,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - // Sort by priority descending, then by LSN ascending (FIFO within same priority) - std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { - if (A->Priority != B->Priority) - { - return A->Priority > B->Priority; - } - return A->ActionLsn < B->ActionLsn; - }); + PendingCount = m_PendingActions.size(); + }); - if (ActionsToSchedule.size() > Capacity) + // Sort by priority descending, then by LSN ascending (FIFO within same priority) + std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) { + if (A->Priority != B->Priority) { - ActionsToSchedule.resize(Capacity); + return A->Priority > B->Priority; } - - PendingCount = m_PendingActions.size(); + return A->ActionLsn < B->ActionLsn; }); + if (ActionsToSchedule.size() > Capacity) + { + ActionsToSchedule.resize(Capacity); + } + if (ActionsToSchedule.empty()) { _.Dismiss(); return; } - ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size()); Stopwatch SubmitTimer; std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule); @@ -1663,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions() } } - ZEN_INFO("scheduled {} pending actions in {} ({} rejected)", - ScheduledActionCount, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - NotAcceptedCount); + ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)", + ScheduledActionCount, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); ScheduledCount += ScheduledActionCount; PendingCount -= ScheduledActionCount; @@ -1975,6 +2004,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() break; } + // Rejected — runner was at capacity, reschedule without retry cost + case RunnerAction::State::Rejected: + { + Action->ResetActionStateToPending(); + ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn); + break; + } + // Terminal states — move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: @@ -2009,6 +2046,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() MaxRetries); break; } + else + { + ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling", + Action->ActionId, + ActionLsn, + RunnerAction::ToString(TerminalState), + Action->RetryCount.load(std::memory_order_relaxed)); + } } m_ActionMapLock.WithExclusiveLock([&] { @@ -2101,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); std::vector<SubmitResult> Results(Actions.size()); - // First try submitting the batch to local runners in parallel + // First try submitting the batch to local runners std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); - std::vector<size_t> RemoteIndices; std::vector<Ref<RunnerAction>> RemoteActions; for (size_t i = 0; i < Actions.size(); ++i) @@ -2115,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>& } else { - RemoteIndices.push_back(i); RemoteActions.push_back(Actions[i]); + Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"}; } } - // Submit remaining actions to remote runners in parallel + // Dispatch remaining actions to remote runners asynchronously. + // Mark actions as Submitting so the scheduler won't re-pick them. + // The remote runner will transition them to Running on success, or + // we mark them Failed on rejection so HandleActionUpdates retries. if (!RemoteActions.empty()) { - std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); - - for (size_t j = 0; j < RemoteIndices.size(); ++j) + for (const Ref<RunnerAction>& Action : RemoteActions) { - Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + Action->SetActionState(RunnerAction::State::Submitting); } + + m_RemoteSubmitPool.ScheduleWork( + [this, RemoteActions = std::move(RemoteActions)]() { + std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteResults.size(); ++j) + { + if (!RemoteResults[j].IsAccepted) + { + ZEN_DEBUG("remote submission rejected for action {} ({}): {}", + RemoteActions[j]->ActionId, + RemoteActions[j]->ActionLsn, + RemoteResults[j].Reason); + + RemoteActions[j]->SetActionState(RunnerAction::State::Rejected); + } + } + }, + WorkerThreadPool::EMode::EnableBacklog); } return Results; @@ -2194,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged() m_Impl->NotifyOrchestratorChanged(); } -void +bool ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) { - m_Impl->StartRecording(InResolver, RecordingPath); + return m_Impl->StartRecording(InResolver, RecordingPath); } -void +bool ComputeServiceSession::StopRecording() { - m_Impl->StopRecording(); + return m_Impl->StopRecording(); +} + +bool +ComputeServiceSession::IsRecording() const +{ + return m_Impl->IsRecording(); } ComputeServiceSession::ActionCounts diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index 6cb975dd3..8cbb25afd 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -62,6 +62,8 @@ struct HttpComputeService::Impl RwLock m_WsConnectionsLock; std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::function<void()> m_ShutdownCallback; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -190,6 +192,65 @@ HttpComputeService::Impl::RegisterRoutes() HttpVerb::kPost); m_Router.RegisterRoute( + "session/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Draining from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "session/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + auto Counts = m_ComputeService.GetActionCounts(); + Cbo << "actions_pending"sv << Counts.Pending; + Cbo << "actions_running"sv << Counts.Running; + Cbo << "actions_completed"sv << Counts.Completed; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "session/sunset", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset)) + { + CbObjectWriter Cbo; + Cbo << "state"sv << ToString(m_ComputeService.GetSessionState()); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + + if (m_ShutdownCallback) + { + m_ShutdownCallback(); + } + return; + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Sunset from current state"sv; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "workers", [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, HttpVerb::kGet); @@ -506,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StartRecording(m_CombinedResolver, m_BaseDir / "recording"); + std::filesystem::path RecordingPath = m_BaseDir / "recording"; - return HttpReq.WriteResponse(HttpResponseCode::OK); + if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath)) + { + CbObjectWriter Cbo; + Cbo << "error" + << "recording is already active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -522,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - m_ComputeService.StopRecording(); + std::filesystem::path RecordingPath = m_BaseDir / "recording"; + + if (!m_ComputeService.StopRecording()) + { + CbObjectWriter Cbo; + Cbo << "error" + << "no recording is active"; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } - return HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter Cbo; + Cbo << "path" << RecordingPath.string(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kPost); @@ -1066,6 +1147,12 @@ HttpComputeService::GetActionCounts() return m_Impl->m_ComputeService.GetActionCounts(); } +void +HttpComputeService::SetShutdownCallback(std::function<void()> Callback) +{ + m_Impl->m_ShutdownCallback = std::move(Callback); +} + const char* HttpComputeService::BaseUri() const { diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index d92af8716..1f51e560e 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -7,6 +7,7 @@ # include <zencompute/orchestratorservice.h> # include <zencore/compactbinarybuilder.h> # include <zencore/logging.h> +# include <zencore/session.h> # include <zencore/string.h> # include <zencore/system.h> @@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn return Ann.Id; } +static OrchestratorService::WorkerAnnotator +MakeWorkerAnnotator(IProvisionerStateProvider* Prov) +{ + if (!Prov) + { + return {}; + } + return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) { + AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId); + if (Status != AgentProvisioningStatus::Unknown) + { + const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active"; + Cbo << "provisioner_status" << std::string_view(StatusStr); + } + }; +} + +bool +HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId) +{ + std::string_view SessionStr = Data["coordinator_session"].AsString(""); + if (SessionStr.empty()) + { + return true; // backwards compatibility: accept announcements without a session + } + Oid Session = Oid::TryFromHexString(SessionStr); + if (Session == m_SessionId) + { + return true; + } + ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString()); + return false; +} + HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) : m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket)) , m_Hostname(GetMachineName()) { + m_SessionId = zen::GetSessionId(); + ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString()); + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); @@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, [this](HttpRouterRequest& Req) { CbObjectWriter Cbo; Cbo << "hostname" << std::string_view(m_Hostname); + Cbo << "session_id" << m_SessionId.ToString(); Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); }, HttpVerb::kGet); m_Router.RegisterRoute( "provision", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kPost); m_Router.RegisterRoute( @@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, "characters and uri must start with http:// or https://"); } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session"); + } + m_Service->AnnounceWorker(Ann); HttpReq.WriteResponse(HttpResponseCode::OK); @@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, m_Router.RegisterRoute( "agents", - [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + [this](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, + m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)))); + }, HttpVerb::kGet); m_Router.RegisterRoute( @@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, }, HttpVerb::kGet); + // Provisioner endpoints + + m_Router.RegisterRoute( + "provisioner/status", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + Cbo << "name" << Prov->GetName(); + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount(); + Cbo << "active_cores" << Prov->GetActiveCoreCount(); + Cbo << "agents" << Prov->GetAgentCount(); + Cbo << "agents_draining" << Prov->GetDrainingAgentCount(); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "provisioner/target", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + int32_t Cores = Data["target_cores"].AsInt32(-1); + + ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false); + + if (Cores < 0) + { + ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field"); + } + + IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire); + if (!Prov) + { + ZEN_WARN("provisioner/target: no provisioner configured"); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured"); + } + + ZEN_INFO("provisioner/target: setting target to {} cores", Cores); + Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores)); + + CbObjectWriter Cbo; + Cbo << "target_cores" << Prov->GetTargetCoreCount(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + // Client tracking endpoints m_Router.RegisterRoute( @@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +void +HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); + m_Service->SetProvisionerStateProvider(Provider); +} + ////////////////////////////////////////////////////////////////////////// // // IWebSocketHandler @@ -488,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms return {}; } + if (!ValidateCoordinatorSession(Data, WorkerId)) + { + return {}; + } + m_Service->AnnounceWorker(Ann); return std::string(WorkerId); } @@ -563,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction() } // Build combined JSON with worker list, provisioning history, clients, and client history - CbObject WorkerList = m_Service->GetWorkerList(); + CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))); CbObject History = m_Service->GetProvisioningHistory(50); CbObject ClientList = m_Service->GetClientList(); CbObject ClientHistory = m_Service->GetClientHistory(50); @@ -615,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction() JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); } + // Emit provisioner stats if available + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + JsonBuilder.Append( + fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}" + ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}", + Prov->GetName(), + Prov->GetTargetCoreCount(), + Prov->GetEstimatedCoreCount(), + Prov->GetActiveCoreCount(), + Prov->GetAgentCount(), + Prov->GetDrainingAgentCount())); + } + JsonBuilder.Append("}"); std::string_view Json = JsonBuilder.ToView(); diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index ad556f546..97de4321a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -279,7 +279,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[9] = {}; + uint64_t Timestamps[10] = {}; }; [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); @@ -305,8 +305,9 @@ public: // Recording - void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); + bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + bool StopRecording(); + bool IsRecording() const; private: void PostUpdate(RunnerAction* Action); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index db3fce3c2..32f54f293 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -35,6 +35,10 @@ public: void Shutdown(); + /** Set a callback to be invoked when the session/sunset endpoint is hit. + * Typically wired to HttpServer::RequestExit() to shut down the process. */ + void SetShutdownCallback(std::function<void()> Callback); + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); const char* BaseUri() const override; diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index 58b2c9152..ef0a1269a 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,10 +2,12 @@ #pragma once +#include <zencompute/provisionerstate.h> #include <zencompute/zencompute.h> #include <zencore/logging.h> #include <zencore/thread.h> +#include <zencore/uid.h> #include <zenhttp/httpserver.h> #include <zenhttp/websocket.h> @@ -65,6 +67,16 @@ public: */ void Shutdown(); + /** Return the session ID generated at construction time. Provisioners + * pass this to spawned workers so the orchestrator can reject stale + * announcements from previous sessions. */ + Oid GetSessionId() const { return m_SessionId; } + + /** Register a provisioner whose target core count can be read and changed + * via the orchestrator HTTP API and dashboard. Caller retains ownership; + * the provider must outlive this service. */ + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; @@ -81,6 +93,11 @@ private: std::unique_ptr<OrchestratorService> m_Service; std::string m_Hostname; + Oid m_SessionId; + bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId); + + std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr}; + // WebSocket push #if ZEN_WITH_WEBSOCKETS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h index 549ff8e3c..2c49e22df 100644 --- a/src/zencompute/include/zencompute/orchestratorservice.h +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -6,6 +6,7 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include <zencompute/provisionerstate.h> # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/logbase.h> @@ -90,9 +91,16 @@ public: std::string Hostname; }; - CbObject GetWorkerList(); + /** Per-worker callback invoked during GetWorkerList serialization. + * The callback receives the worker ID and a CbObjectWriter positioned + * inside the worker's object, allowing the caller to append extra fields. */ + using WorkerAnnotator = std::function<void(std::string_view WorkerId, CbObjectWriter& Cbo)>; + + CbObject GetWorkerList(const WorkerAnnotator& Annotate = {}); void AnnounceWorker(const WorkerAnnouncement& Announcement); + void SetProvisionerStateProvider(IProvisionerStateProvider* Provider); + bool IsWorkerWebSocketEnabled() const; void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); @@ -171,6 +179,8 @@ private: LoggerRef m_Log{"compute.orchestrator"}; bool m_EnableWorkerWebSocket = false; + std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr}; + std::thread m_ProbeThread; std::atomic<bool> m_ProbeThreadEnabled{true}; Event m_ProbeThreadEvent; diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h new file mode 100644 index 000000000..e9af8a635 --- /dev/null +++ b/src/zencompute/include/zencompute/provisionerstate.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cstdint> +#include <string_view> + +namespace zen::compute { + +/** Per-agent provisioning status as seen by the provisioner. */ +enum class AgentProvisioningStatus +{ + Unknown, ///< Not known to the provisioner + Active, ///< Running and allocated + Draining, ///< Being gracefully deprovisioned +}; + +/** Abstract interface for querying and controlling a provisioner from the HTTP layer. + * This decouples the orchestrator service from specific provisioner implementations. */ +class IProvisionerStateProvider +{ +public: + virtual ~IProvisionerStateProvider() = default; + + virtual std::string_view GetName() const = 0; ///< e.g. "horde", "nomad" + virtual uint32_t GetTargetCoreCount() const = 0; + virtual uint32_t GetEstimatedCoreCount() const = 0; + virtual uint32_t GetActiveCoreCount() const = 0; + virtual uint32_t GetAgentCount() const = 0; + virtual uint32_t GetDrainingAgentCount() const { return 0; } + virtual void SetTargetCoreCount(uint32_t Count) = 0; + + /** Return the provisioning status for a worker by its orchestrator ID + * (e.g. "horde-{LeaseId}"). Returns Unknown if the ID is not recognized. */ + virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; } +}; + +} // namespace zen::compute diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp index 9ea695305..aee8fa63a 100644 --- a/src/zencompute/orchestratorservice.cpp +++ b/src/zencompute/orchestratorservice.cpp @@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService() } CbObject -OrchestratorService::GetWorkerList() +OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate) { ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); CbObjectWriter Cbo; @@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList() Cbo << "ws_connected" << true; } Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + if (Annotate) + { + Annotate(WorkerId, Cbo); + } Cbo.EndObject(); } }); @@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) } } +void +OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider) +{ + m_Provisioner.store(Provider, std::memory_order_release); +} + bool OrchestratorService::IsWorkerWebSocketEnabled() const { @@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction() continue; } + // Check if the provisioner knows this worker is draining — if so, + // unreachability is expected and should not be logged as a warning. + bool IsDraining = false; + if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire)) + { + IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining; + } + ReachableState NewState = ReachableState::Unreachable; try @@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction() } catch (const std::exception& Ex) { - ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + if (!IsDraining) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } } ReachableState PrevState = ReachableState::Unknown; @@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction() { ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); } + else if (IsDraining) + { + ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri); + } else if (PrevState == ReachableState::Reachable) { ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 67e12b84e..ab22c6363 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -6,9 +6,15 @@ # include <zencore/compactbinary.h> # include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/logging.h> +# include <zencore/string.h> +# include <zencore/timer.h> # include <zencore/trace.h> +# include <zencore/workthreadpool.h> # include <fmt/format.h> +# include <future> # include <vector> namespace zen::compute { @@ -118,23 +124,34 @@ std::vector<SubmitResult> BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); - RwLock::SharedLockScope _(m_RunnersLock); - const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + // Snapshot runners and query capacity under the lock, then release + // before submitting — HTTP submissions to remote runners can take + // hundreds of milliseconds and we must not hold m_RunnersLock during I/O. - if (RunnerCount == 0) - { - return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); - } + std::vector<Ref<FunctionRunner>> Runners; + std::vector<size_t> Capacities; + std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions; + size_t TotalCapacity = 0; - // Query capacity per runner and compute total - std::vector<size_t> Capacities(RunnerCount); - size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + const int RunnerCount = gsl::narrow<int>(m_Runners.size()); + Runners.assign(m_Runners.begin(), m_Runners.end()); + Capacities.resize(RunnerCount); + PerRunnerActions.resize(RunnerCount); - for (int i = 0; i < RunnerCount; ++i) + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + }); + + const int RunnerCount = gsl::narrow<int>(Runners.size()); + + if (RunnerCount == 0) { - Capacities[i] = m_Runners[i]->QueryCapacity(); - TotalCapacity += Capacities[i]; + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); } if (TotalCapacity == 0) @@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } // Distribute actions across runners proportionally to their available capacity - std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount); - std::vector<size_t> ActionRunnerIndex(Actions.size()); - size_t ActionIdx = 0; + std::vector<size_t> ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; for (int i = 0; i < RunnerCount; ++i) { @@ -176,14 +192,74 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) } } - // Submit batches per runner + // Submit batches per runner — in parallel when a worker pool is available + std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount); + int ActiveRunnerCount = 0; for (int i = 0; i < RunnerCount; ++i) { if (!PerRunnerActions[i].empty()) { - PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + ++ActiveRunnerCount; + } + } + + static constexpr uint64_t SubmitWarnThresholdMs = 500; + + auto SubmitToRunner = [&](int RunnerIndex) { + auto& Runner = Runners[RunnerIndex]; + Runner->m_LastSubmitStats.Reset(); + + Stopwatch Timer; + + PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]); + + uint64_t ElapsedMs = Timer.GetElapsedTimeMs(); + if (ElapsedMs >= SubmitWarnThresholdMs) + { + size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed); + uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed); + + ZEN_WARN("submit of {} actions ({} attachments, {}) to '{}' took {}ms", + PerRunnerActions[RunnerIndex].size(), + Attachments, + NiceBytes(AttachmentBytes), + Runner->GetDisplayName(), + ElapsedMs); + } + }; + + if (m_WorkerPool && ActiveRunnerCount > 1) + { + std::vector<std::future<void>> Futures(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + std::packaged_task<void()> Task([&SubmitToRunner, i]() { SubmitToRunner(i); }); + + Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog); + } + } + + for (int i = 0; i < RunnerCount; ++i) + { + if (Futures[i].valid()) + { + Futures[i].get(); + } + } + } + else + { + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + SubmitToRunner(i); + } } } @@ -309,10 +385,11 @@ RunnerAction::RetractAction() bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed, Abandoned, or Retracted states + // Only allow reset from Failed, Abandoned, Rejected, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected && + CurrentState != State::Retracted) { return false; } @@ -333,11 +410,12 @@ RunnerAction::ResetActionStateToPending() // Clear execution fields ExecutionLocation.clear(); + FailureReason.clear(); CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count (skip for Retracted — nothing failed) - if (CurrentState != State::Retracted) + // Increment retry count (skip for Retracted/Rejected — nothing failed) + if (CurrentState != State::Retracted && CurrentState != State::Rejected) { RetryCount.fetch_add(1, std::memory_order_relaxed); } diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index 56c3f3af0..449f0e228 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -10,6 +10,10 @@ # include <filesystem> # include <vector> +namespace zen { +class WorkerThreadPool; +} + namespace zen::compute { struct SubmitResult @@ -37,6 +41,22 @@ public: [[nodiscard]] virtual bool IsHealthy() = 0; [[nodiscard]] virtual size_t QueryCapacity(); [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); + [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; } + + // Accumulated stats from the most recent SubmitActions call. + // Reset before each call, populated by the runner implementation. + struct SubmitStats + { + std::atomic<size_t> TotalAttachments{0}; + std::atomic<uint64_t> TotalAttachmentBytes{0}; + + void Reset() + { + TotalAttachments.store(0, std::memory_order_relaxed); + TotalAttachmentBytes.store(0, std::memory_order_relaxed); + } + }; + SubmitStats m_LastSubmitStats; // Best-effort cancellation of a specific in-flight action. Returns true if the // cancellation signal was successfully sent. The action will transition to Cancelled @@ -68,6 +88,8 @@ public: bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); + void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; } + size_t GetRunnerCount() { return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); @@ -79,6 +101,7 @@ protected: RwLock m_RunnersLock; std::vector<Ref<FunctionRunner>> m_Runners; std::atomic<int> m_NextSubmitIndex{0}; + WorkerThreadPool* m_WorkerPool = nullptr; }; /** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. @@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted CbObject ActionObj; int Priority = 0; std::string ExecutionLocation; // "local" or remote hostname + std::string FailureReason; // human-readable reason when action fails (empty on success) // CPU usage and total CPU time of the running process, sampled periodically by the local runner. // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. @@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted Completed, // Finished successfully with results available Failed, // Execution failed (transient error, eligible for retry) Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon) + Rejected, // Runner declined (e.g. at capacity) — rescheduled without retry cost Cancelled, // Intentional user cancellation (never retried) Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count @@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted return "Failed"; case State::Abandoned: return "Abandoned"; + case State::Rejected: + return "Rejected"; case State::Cancelled: return "Cancelled"; case State::Retracted: diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp index 9055005d9..ce5bbdcc8 100644 --- a/src/zencompute/runners/linuxrunner.cpp +++ b/src/zencompute/runners/linuxrunner.cpp @@ -430,7 +430,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process — lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { @@ -481,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 1b748c0e5..96cbdc134 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); - if (!std::filesystem::exists(WorkerDir)) + // worker.zcb is written as the last step of ManifestWorker, so its presence + // indicates a complete manifest. If the directory exists but the marker is + // missing, a previous manifest was interrupted and we need to start over. + bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb"); + + if (NeedsManifest) { _.ReleaseNow(); RwLock::ExclusiveLockScope $(m_WorkerLock); - if (!std::filesystem::exists(WorkerDir)) + if (!std::filesystem::exists(WorkerDir / "worker.zcb")) { + std::error_code Ec; + std::filesystem::remove_all(WorkerDir, Ec); ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); } } @@ -673,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& Com } catch (std::exception& Ex) { - ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what()); + ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); } } + else + { + Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode); + ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason); + } // Failed - clean up the sandbox in the background. diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp index c2ccca9a6..13c01d988 100644 --- a/src/zencompute/runners/macrunner.cpp +++ b/src/zencompute/runners/macrunner.cpp @@ -211,7 +211,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process — lower priority so workers don't starve the main server + nice(5); if (m_Sandboxed) { @@ -281,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action) // Clean up the sandbox in the background m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); - ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf); + ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason); Action->SetActionState(RunnerAction::State::Failed); return SubmitResult{.IsAccepted = false}; diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp index e4a7ba388..a4f586852 100644 --- a/src/zencompute/runners/managedrunner.cpp +++ b/src/zencompute/runners/managedrunner.cpp @@ -128,7 +128,7 @@ ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action) CreateProcOptions Options; Options.WorkingDirectory = &Prepared->SandboxPath; - Options.Flags = CreateProcOptions::Flag_NoConsole; + Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority; Options.Environment = std::move(EnvPairs); const int32_t ActionLsn = Prepared->ActionLsn; diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index ce6a81173..55f78fdd6 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -20,6 +20,7 @@ # include <zenstore/cidstore.h> # include <span> +# include <unordered_set> ////////////////////////////////////////////////////////////////////////// @@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, , m_ChunkResolver{InChunkResolver} , m_WorkerPool{InWorkerPool} , m_HostName{HostName} +, m_DisplayName{HostName} , m_BaseUrl{fmt::format("{}/compute", HostName)} , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) @@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } +void +RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname) +{ + if (!Hostname.empty()) + { + m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname); + } +} + RemoteHttpRunner::~RemoteHttpRunner() { Shutdown(); @@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown() for (auto& [RemoteLsn, HttpAction] : Remaining) { ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->FailureReason = "remote runner shutdown"; HttpAction.Action->SetActionState(RunnerAction::State::Failed); } } @@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity() return 0; } - // Estimate how much more work we're ready to accept + // Estimate how much more work we're ready to accept. + // Include actions currently being submitted over HTTP so we don't + // keep queueing new submissions while previous ones are still in flight. RwLock::SharedLockScope _{m_RunningLock}; - size_t RunningCount = m_RemoteRunningMap.size(); + size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed); if (RunningCount >= size_t(m_MaxRunningActions)) { @@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) { ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed); + auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); }); + if (Actions.size() <= 1) { std::vector<SubmitResult> Results; @@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) } } - // Enqueue job. If the remote returns FailedDependency (424), it means it - // cannot resolve the worker/function — re-register the worker and retry once. + // Submit the action to the remote. In eager-attach mode we build a + // CbPackage with all referenced attachments upfront to avoid the 404 + // round-trip. In the default mode we POST the bare object first and + // only upload missing attachments if the remote requests them. + // + // In both modes, FailedDependency (424) triggers a worker re-register + // and a single retry. CbObject Result; HttpClient::Response WorkResponse; HttpResponseCode WorkResponseCode{}; - for (int Attempt = 0; Attempt < 2; ++Attempt) - { - WorkResponse = m_Http.Post(SubmitUrl, ActionObj); - WorkResponseCode = WorkResponse.StatusCode; - - if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) - { - ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", - m_Http.GetBaseUri(), - ActionId); - - (void)RegisterWorker(Action->Worker.Descriptor); - } - else - { - break; - } - } - - if (WorkResponseCode == HttpResponseCode::OK) - { - Result = WorkResponse.AsObject(); - } - else if (WorkResponseCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Not all attachments are present - - // Build response package including all required attachments - CbPackage Pkg; Pkg.SetObject(ActionObj); - CbObject Response = WorkResponse.AsObject(); + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash AttachHash = Field.AsHash(); - for (auto& Item : Response["need"sv]) - { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + }); + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, Pkg); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + (void)RegisterWorker(Action->Worker.Descriptor); } else { - // No such attachment - - return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + break; } } + } + else + { + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; - // Post resulting package + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); - HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + (void)RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } - if (!PayloadResponse) + if (WorkResponseCode == HttpResponseCode::NotFound) { - ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + // Remote needs attachments — resolve them and retry with a CbPackage - // TODO: include more information about the failure in the response + CbPackage Pkg; + Pkg.SetObject(ActionObj); - return {.IsAccepted = false, .Reason = "HTTP request failed"}; - } - else if (PayloadResponse.StatusCode == HttpResponseCode::OK) - { - Result = PayloadResponse.AsObject(); - } - else - { - // Unexpected response - - const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - - ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", - ActionId, - m_Http.GetBaseUri(), - SubmitUrl, - ResponseStatusCode, - ToString(ResponseStatusCode)); - - return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}{}", - ResponseStatusCode, - ToString(ResponseStatusCode), - m_Http.GetBaseUri(), - SubmitUrl)}; + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + + WorkResponse = std::move(PayloadResponse); + WorkResponseCode = WorkResponse.StatusCode; } } + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (!WorkResponse) + { + ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl); + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (!IsHttpSuccessCode(WorkResponseCode)) + { + const int Code = static_cast<int>(WorkResponseCode); + ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code)); + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)}; + } + if (Result) { if (const int32_t LsnField = Result["lsn"].AsInt32(0)) @@ -512,82 +562,110 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec CbObjectWriter Body; Body.BeginArray("actions"sv); + std::unordered_set<IoHash, IoHash::Hasher> AttachmentsSeen; + for (const Ref<RunnerAction>& Action : Actions) { Action->ExecutionLocation = m_HostName; MaybeDumpAction(Action->ActionLsn, Action->ActionObj); Body.AddObject(Action->ActionObj); + + if (m_EagerAttach) + { + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); }); + } } Body.EndArray(); - // POST the batch - - HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); - - if (Response.StatusCode == HttpResponseCode::OK) - { - return ParseBatchResponse(Response, Actions); - } + // In eager-attach mode, build a CbPackage with all referenced attachments + // so the remote can accept in a single round-trip. Otherwise POST a bare + // CbObject and handle the 404 need-list flow. - if (Response.StatusCode == HttpResponseCode::NotFound) + if (m_EagerAttach) { - // Server needs attachments — resolve them and retry with a CbPackage - - CbObject NeedObj = Response.AsObject(); - CbPackage Pkg; Pkg.SetObject(Body.Save()); - for (auto& Item : NeedObj["need"sv]) + for (const IoHash& AttachHash : AttachmentsSeen) { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash)) { uint64_t DataRawSize = 0; IoHash DataRawHash; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - ZEN_ASSERT(DataRawHash == NeedHash); - - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); - } - else - { - ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); - return FallbackToIndividualSubmit(Actions); + Pkg.AddAttachment(CbAttachment(Compressed, AttachHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); } } - HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + } + else + { + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); - if (RetryResponse.StatusCode == HttpResponseCode::OK) + if (Response.StatusCode == HttpResponseCode::OK) { - return ParseBatchResponse(RetryResponse, Actions); + return ParseBatchResponse(Response, Actions); } - ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", - (int)RetryResponse.StatusCode, - ToString(RetryResponse.StatusCode)); - return FallbackToIndividualSubmit(Actions); + if (Response.StatusCode == HttpResponseCode::NotFound) + { + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed); + m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed); + } + else + { + ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return FallbackToIndividualSubmit(Actions); + } } // Unexpected status or connection error — fall back to individual submission - if (Response) - { - ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", - m_Http.GetBaseUri(), - SubmitUrl, - (int)Response.StatusCode, - ToString(Response.StatusCode)); - } - else - { - ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); - } + ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); return FallbackToIndividualSubmit(Actions); } @@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions() { for (auto& FieldIt : Completed["completed"sv]) { - CbObjectView EntryObj = FieldIt.AsObjectView(); - const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); - std::string_view StateName = EntryObj["state"sv].AsString(); + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); + std::string_view FailureReason = EntryObj["reason"sv].AsString(); RunnerAction::State RemoteState = RunnerAction::FromString(StateName); @@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions() { HttpRunningAction CompletedAction = std::move(CompleteIt->second); CompletedAction.RemoteState = RemoteState; + CompletedAction.FailureReason = std::string(FailureReason); if (RemoteState == RunnerAction::State::Completed && ResponseJob) { @@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions() { const int ActionLsn = HttpAction.Action->ActionLsn; - ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", - HttpAction.Action->ActionId, - ActionLsn, - HttpAction.RemoteActionLsn, - RunnerAction::ToString(HttpAction.RemoteState)); - if (HttpAction.RemoteState == RunnerAction::State::Completed) { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + m_HostName); HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); } + else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned) + { + HttpAction.Action->FailureReason = HttpAction.FailureReason; + if (HttpAction.FailureReason.empty()) + { + ZEN_WARN("action {} ({}) {} on remote {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName); + } + else + { + ZEN_WARN("action {} ({}) {} on remote {}: {}", + HttpAction.Action->ActionId, + ActionLsn, + RunnerAction::ToString(HttpAction.RemoteState), + m_HostName, + HttpAction.FailureReason); + } + } + else + { + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); + } HttpAction.Action->SetActionState(HttpAction.RemoteState); } diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index c17d0cf2a..fdf113c77 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -54,8 +54,10 @@ public: [[nodiscard]] virtual size_t QueryCapacity() override; [[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override; virtual void CancelRemoteQueue(int QueueId) override; + [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; } std::string_view GetHostName() const { return m_HostName; } + void SetRemoteHostname(std::string_view Hostname); protected: LoggerRef Log() { return m_Log; } @@ -65,12 +67,15 @@ private: ChunkResolver& m_ChunkResolver; WorkerThreadPool& m_WorkerPool; std::string m_HostName; + std::string m_DisplayName; std::string m_BaseUrl; HttpClient m_Http; - std::atomic<bool> m_AcceptNewActions{true}; - int32_t m_MaxRunningActions = 256; // arbitrary limit for testing - int32_t m_MaxBatchSize = 50; + std::atomic<bool> m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; + bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry + std::atomic<size_t> m_InFlightSubmissions{0}; // actions currently being submitted over HTTP struct HttpRunningAction { @@ -78,6 +83,7 @@ private: int RemoteActionLsn = 0; // Remote LSN RunnerAction::State RemoteState = RunnerAction::State::Failed; CbPackage ActionResults; + std::string FailureReason; }; RwLock m_RunningLock; diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp index 92ee65c2d..e643c9ce8 100644 --- a/src/zencompute/runners/windowsrunner.cpp +++ b/src/zencompute/runners/windowsrunner.cpp @@ -48,7 +48,9 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, if (m_JobObject) { JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{}; - ExtLimits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION; + ExtLimits.BasicLimitInformation.LimitFlags = + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS; + ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS; SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits)); JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{}; diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp index b4fafb467..593b19e55 100644 --- a/src/zencompute/runners/winerunner.cpp +++ b/src/zencompute/runners/winerunner.cpp @@ -96,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref<RunnerAction> Action) if (ChildPid == 0) { - // Child process + // Child process — lower priority so workers don't starve the main server + nice(5); + if (chdir(SandboxPathStr.c_str()) != 0) { _exit(127); diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md new file mode 100644 index 000000000..13beaa968 --- /dev/null +++ b/src/zenhorde/README.md @@ -0,0 +1,17 @@ +# Horde Compute integration + +Zen compute can use Horde to provision runner nodes. + +## Launch a coordinator instance + +Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface which allows zenserver instances to discover endpoints which they can submit actions to. + +```bash +zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio +``` + +## Use a coordinator + +```bash +zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558 +``` diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp index 819b2d0cb..275f5bd4c 100644 --- a/src/zenhorde/hordeagent.cpp +++ b/src/zenhorde/hordeagent.cpp @@ -8,290 +8,457 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <cstring> -#include <unordered_map> namespace zen::horde { -HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) -{ - ZEN_TRACE_CPU("HordeAgent::Connect"); +// --- AsyncHordeAgent --- - auto Transport = std::make_unique<TcpComputeTransport>(Info); - if (!Transport->IsValid()) +static const char* +GetStateName(AsyncHordeAgent::State S) +{ + switch (S) { - ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); - return; + case AsyncHordeAgent::State::Idle: + return "idle"; + case AsyncHordeAgent::State::Connecting: + return "connect"; + case AsyncHordeAgent::State::WaitAgentAttach: + return "agent-attach"; + case AsyncHordeAgent::State::SentFork: + return "fork"; + case AsyncHordeAgent::State::WaitChildAttach: + return "child-attach"; + case AsyncHordeAgent::State::Uploading: + return "upload"; + case AsyncHordeAgent::State::Executing: + return "execute"; + case AsyncHordeAgent::State::Polling: + return "poll"; + case AsyncHordeAgent::State::Done: + return "done"; + default: + return "unknown"; } +} - // The 64-byte nonce is always sent unencrypted as the first thing on the wire. - // The Horde agent uses this to identify which lease this connection belongs to. - Transport->Send(Info.Nonce, sizeof(Info.Nonce)); +AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async")) +{ +} - std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); - if (Info.EncryptionMode == Encryption::AES) +AsyncHordeAgent::~AsyncHordeAgent() +{ + Cancel(); +} + +void +AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone) +{ + m_Config = std::move(Config); + m_OnDone = std::move(OnDone); + m_State = State::Connecting; + DoConnect(); +} + +void +AsyncHordeAgent::Cancel() +{ + m_Cancelled = true; + if (m_Socket) { - FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); - if (!FinalTransport->IsValid()) - { - ZEN_WARN("failed to create AES transport"); - return; - } + m_Socket->Close(); + } + else if (m_Transport) + { + m_Transport->Close(); } +} - // Create multiplexed socket and channels - m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); +void +AsyncHordeAgent::DoConnect() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect"); - // Channel 0 is the agent control channel (handles Attach/Fork handshake). - // Channel 100 is the child I/O channel (handles file upload and remote execution). - Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); - Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext); + + auto Self = shared_from_this(); + m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); }); +} - if (!AgentComputeChannel || !ChildComputeChannel) +void +AsyncHordeAgent::OnConnected(const std::error_code& Ec) +{ + if (Ec || m_Cancelled) { - ZEN_WARN("failed to create compute channels"); + if (Ec) + { + ZEN_WARN("connect failed: {}", Ec.message()); + } + Finish(false); return; } - m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); - m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + // Optionally wrap with AES encryption + std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport); + if (m_Config.Machine.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext); + } + m_Transport = std::move(FinalTransport); + + // Create the multiplexed socket and register channels + m_Socket = std::make_shared<AsyncComputeSocket>(std::move(m_Transport), m_IoContext); + + m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext); + m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext); + + m_Socket->RegisterChannel( + 0, + [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); }, + [this]() { m_AgentChannel->OnDetach(); }); - m_IsValid = true; + m_Socket->RegisterChannel( + 100, + [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); }, + [this]() { m_ChildChannel->OnDetach(); }); + + m_Socket->StartRecvPump(); + + m_State = State::WaitAgentAttach; + DoWaitAgentAttach(); } -HordeAgent::~HordeAgent() +void +AsyncHordeAgent::DoWaitAgentAttach() { - CloseConnection(); + auto Self = shared_from_this(); + m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnAgentResponse(Type, Data, Size); + }); } -bool -HordeAgent::BeginCommunication() +void +AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) { - ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); - - if (!m_IsValid) + if (m_Cancelled) { - return false; + Finish(false); + return; } - // Start the send/recv pump threads - m_Socket->StartCommunication(); - - // Wait for Attach on agent channel - AgentMessageType Type = m_AgentChannel->ReadResponse(5000); if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on agent channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - // Fork tells the remote agent to create child channel 100 with a 4MB buffer. - // After this, the agent will send an Attach on the child channel. + m_State = State::SentFork; + DoSendFork(); +} + +void +AsyncHordeAgent::DoSendFork() +{ m_AgentChannel->Fork(100, 4 * 1024 * 1024); - // Wait for Attach on child channel - Type = m_ChildChannel->ReadResponse(5000); + m_State = State::WaitChildAttach; + DoWaitChildAttach(); +} + +void +AsyncHordeAgent::DoWaitChildAttach() +{ + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnChildAttachResponse(Type, Data, Size); + }); +} + +void +AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) +{ + if (m_Cancelled) + { + Finish(false); + return; + } + if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on child channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - return true; + m_State = State::Uploading; + m_CurrentBundleIndex = 0; + DoUploadNext(); } -bool -HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +void +AsyncHordeAgent::DoUploadNext() { - ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); - - m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + if (m_Cancelled) + { + Finish(false); + return; + } - std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + if (m_CurrentBundleIndex >= m_Config.Bundles.size()) + { + // All bundles uploaded — proceed to execute + m_State = State::Executing; + DoExecute(); + return; + } - auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { - std::string Key(Locator); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + m_ChildChannel->UploadFiles("", Locator.c_str()); - if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) - { - return It->second.get(); - } + // Enter the ReadBlob/Blob upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const std::filesystem::path Path = BundleDir / (Key + ".blob"); - std::error_code Ec; - auto File = std::make_unique<BasicFile>(); - File->Open(Path, BasicFile::Mode::kRead, Ec); +void +AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) +{ + if (m_Cancelled) + { + Finish(false); + return; + } - if (Ec) + if (Type == AgentMessageType::None) + { + if (m_ChildChannel->IsDetached()) { - ZEN_ERROR("cannot read blob file: '{}'", Path); - return nullptr; + ZEN_WARN("connection lost during upload"); + Finish(false); + return; } + // Timeout — retry read + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); + return; + } - BasicFile* Ptr = File.get(); - BlobFiles.emplace(std::move(Key), std::move(File)); - return Ptr; - }; + if (Type == AgentMessageType::WriteFilesResponse) + { + // This bundle upload is done — move to next + ++m_CurrentBundleIndex; + DoUploadNext(); + return; + } - // The upload protocol is request-driven: we send WriteFiles, then the remote agent - // sends ReadBlob requests for each blob it needs. We respond with Blob data until - // the agent sends WriteFilesResponse indicating the upload is complete. - constexpr int32_t ReadResponseTimeoutMs = 1000; + if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + AsyncAgentMessageChannel::ReadException(Data, Size, Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + return; + } - for (;;) + if (Type != AgentMessageType::ReadBlob) { - bool TimedOut = false; + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + Finish(false); + return; + } - if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) - { - if (TimedOut) - { - continue; - } - // End of stream - check if it was a successful upload - if (Type == AgentMessageType::WriteFilesResponse) - { - return true; - } - else if (Type == AgentMessageType::Exception) - { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); - } - else - { - ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); - } - return false; - } + // Handle ReadBlob request + BlobRequest Req; + AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req); - BlobRequest Req; - m_ChildChannel->ReadBlobRequest(Req); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob"); - BasicFile* File = FindOrOpenBlob(Req.Locator); - if (!File) - { - return false; - } + std::error_code FsEc; + BasicFile File; + File.Open(BlobPath, BasicFile::Mode::kRead, FsEc); - // Read from offset to end of file - const uint64_t TotalSize = File->FileSize(); - const uint64_t Offset = static_cast<uint64_t>(Req.Offset); - if (Offset >= TotalSize) - { - ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); - m_ChildChannel->Blob(nullptr, 0); - continue; - } + if (FsEc) + { + ZEN_ERROR("cannot read blob file: '{}'", BlobPath); + Finish(false); + return; + } + + const uint64_t TotalSize = File.FileSize(); + const uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + } + else + { + const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize()); + } + + // Continue the upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); - m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); +void +AsyncHordeAgent::DoExecute() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute"); + + std::vector<const char*> ArgPtrs; + ArgPtrs.reserve(m_Config.Args.size()); + for (const std::string& Arg : m_Config.Args) + { + ArgPtrs.push_back(Arg.c_str()); } + + m_ChildChannel->Execute(m_Config.Executable.c_str(), + ArgPtrs.data(), + ArgPtrs.size(), + nullptr, + nullptr, + 0, + m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + m_Config.Machine.GetConnectionAddress(), + m_Config.Machine.GetConnectionPort(), + m_Config.Machine.LeaseId); + + m_State = State::Polling; + DoPoll(); } void -HordeAgent::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - bool UseWine) +AsyncHordeAgent::DoPoll() { - ZEN_TRACE_CPU("HordeAgent::Execute"); - m_ChildChannel - ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + if (m_Cancelled) + { + Finish(false); + return; + } + + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnPollResponse(Type, Data, Size); + }); } -bool -HordeAgent::Poll(bool LogOutput) +void +AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) { - constexpr int32_t ReadResponseTimeoutMs = 100; - AgentMessageType Type; + if (m_Cancelled) + { + Finish(false); + return; + } - while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + switch (Type) { - switch (Type) - { - case AgentMessageType::ExecuteOutput: - { - if (LogOutput && m_ChildChannel->GetResponseSize() > 0) - { - const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); - size_t ResponseSize = m_ChildChannel->GetResponseSize(); - - // Trim trailing newlines - while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) - { - --ResponseSize; - } - - if (ResponseSize > 0) - { - const std::string_view Output(ResponseData, ResponseSize); - ZEN_INFO("[remote] {}", Output); - } - } - break; - } + case AgentMessageType::None: + if (m_ChildChannel->IsDetached()) + { + ZEN_WARN("connection lost during execution"); + Finish(false); + } + else + { + // Timeout — poll again + DoPoll(); + } + break; - case AgentMessageType::ExecuteResult: - { - if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) - { - int32_t ExitCode; - memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); - ZEN_INFO("remote process exited with code {}", ExitCode); - } - m_IsValid = false; - return false; - } + case AgentMessageType::ExecuteOutput: + // Silently consume remote stdout (matching LogOutput=false in provisioner) + DoPoll(); + break; - case AgentMessageType::Exception: + case AgentMessageType::ExecuteResult: + { + int32_t ExitCode = -1; + if (Size == sizeof(int32_t)) { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); - m_HasErrors = true; - break; + memcpy(&ExitCode, Data, sizeof(int32_t)); } + ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId); + Finish(ExitCode == 0, ExitCode); + } + break; - default: - break; - } - } + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + AsyncAgentMessageChannel::ReadException(Data, Size, Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + } + break; - return m_IsValid && !m_HasErrors; + default: + DoPoll(); + break; + } } void -HordeAgent::CloseConnection() +AsyncHordeAgent::Finish(bool Success, int32_t ExitCode) { - if (m_ChildChannel) + if (m_State == State::Done) { - m_ChildChannel->Close(); + return; // Already finished } - if (m_AgentChannel) + + if (!Success) { - m_AgentChannel->Close(); + ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId); } -} -bool -HordeAgent::IsValid() const -{ - return m_IsValid && !m_HasErrors; + m_State = State::Done; + + if (m_Socket) + { + m_Socket->Close(); + } + + if (m_OnDone) + { + AsyncAgentResult Result; + Result.Success = Success; + Result.ExitCode = ExitCode; + Result.CoreCount = m_Config.Machine.LogicalCores; + + auto Handler = std::move(m_OnDone); + m_OnDone = nullptr; + Handler(Result); + } } } // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h index e0ae89ead..b581a8da1 100644 --- a/src/zenhorde/hordeagent.h +++ b/src/zenhorde/hordeagent.h @@ -10,68 +10,108 @@ #include <zencore/logbase.h> #include <filesystem> +#include <functional> #include <memory> #include <string> +#include <vector> + +namespace asio { +class io_context; +} namespace zen::horde { -/** Manages the lifecycle of a single Horde compute agent. +class AsyncComputeTransport; + +/** Result passed to the completion handler when an async agent finishes. */ +struct AsyncAgentResult +{ + bool Success = false; + int32_t ExitCode = -1; + uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine +}; + +/** Completion handler for async agent lifecycle. */ +using AsyncAgentCompletionHandler = std::function<void(const AsyncAgentResult&)>; + +/** Configuration for launching a remote zenserver instance via an async agent. */ +struct AsyncAgentConfig +{ + MachineInfo Machine; + std::vector<std::pair<std::string, std::filesystem::path>> Bundles; ///< (locator, bundleDir) pairs + std::string Executable; + std::vector<std::string> Args; + bool UseWine = false; +}; + +/** Async agent that manages the full lifecycle of a single Horde compute connection. * - * Handles the full connection sequence for one provisioned machine: - * 1. Connect via TCP transport (with optional AES encryption wrapping) - * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) - * 3. Perform the Attach/Fork handshake to establish the child channel - * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol - * 5. Execute zenserver remotely via ExecuteV2 - * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + * Driven by a state machine using callbacks on a shared io_context — no dedicated + * threads. Call Start() to begin the connection/handshake/upload/execute/poll + * sequence. The completion handler is invoked when the remote process exits or + * an error occurs. */ -class HordeAgent +class AsyncHordeAgent : public std::enable_shared_from_this<AsyncHordeAgent> { public: - explicit HordeAgent(const MachineInfo& Info); - ~HordeAgent(); + AsyncHordeAgent(asio::io_context& IoContext); + ~AsyncHordeAgent(); - HordeAgent(const HordeAgent&) = delete; - HordeAgent& operator=(const HordeAgent&) = delete; + AsyncHordeAgent(const AsyncHordeAgent&) = delete; + AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete; - /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). - * Returns false if the handshake times out or receives an unexpected message. */ - bool BeginCommunication(); + /** Start the full agent lifecycle. The completion handler is called exactly once. */ + void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone); - /** Upload binary files to the remote agent. - * @param BundleDir Directory containing .blob files. - * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ - bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + /** Cancel in-flight operations. The completion handler is still called (with Success=false). */ + void Cancel(); - /** Execute a command on the remote machine. */ - void Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir = nullptr, - const char* const* EnvVars = nullptr, - size_t NumEnvVars = 0, - bool UseWine = false); + const MachineInfo& GetMachineInfo() const { return m_Config.Machine; } - /** Poll for output and results. Returns true if the agent is still running. - * When LogOutput is true, remote stdout is logged via ZEN_INFO. */ - bool Poll(bool LogOutput = true); - - void CloseConnection(); - bool IsValid() const; - - const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + enum class State + { + Idle, + Connecting, + WaitAgentAttach, + SentFork, + WaitChildAttach, + Uploading, + Executing, + Polling, + Done + }; private: LoggerRef Log() { return m_Log; } - std::unique_ptr<ComputeSocket> m_Socket; - std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control - std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O - - LoggerRef m_Log; - bool m_IsValid = false; - bool m_HasErrors = false; - MachineInfo m_MachineInfo; + void DoConnect(); + void OnConnected(const std::error_code& Ec); + void DoWaitAgentAttach(); + void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoSendFork(); + void DoWaitChildAttach(); + void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoUploadNext(); + void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoExecute(); + void DoPoll(); + void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void Finish(bool Success, int32_t ExitCode = -1); + + asio::io_context& m_IoContext; + LoggerRef m_Log; + State m_State = State::Idle; + bool m_Cancelled = false; + + AsyncAgentConfig m_Config; + AsyncAgentCompletionHandler m_OnDone; + size_t m_CurrentBundleIndex = 0; + + std::unique_ptr<AsyncTcpComputeTransport> m_TcpTransport; + std::unique_ptr<AsyncComputeTransport> m_Transport; + std::shared_ptr<AsyncComputeSocket> m_Socket; + std::unique_ptr<AsyncAgentMessageChannel> m_AgentChannel; + std::unique_ptr<AsyncAgentMessageChannel> m_ChildChannel; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp index 998134a96..31498972f 100644 --- a/src/zenhorde/hordeagentmessage.cpp +++ b/src/zenhorde/hordeagentmessage.cpp @@ -4,337 +4,403 @@ #include <zencore/intmath.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <cassert> #include <cstring> namespace zen::horde { -AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) +// --- AsyncAgentMessageChannel --- + +AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext) +: m_Socket(std::move(Socket)) +, m_ChannelId(ChannelId) +, m_IoContext(IoContext) +, m_TimeoutTimer(std::make_unique<asio::steady_timer>(IoContext)) { } -AgentMessageChannel::~AgentMessageChannel() = default; - -void -AgentMessageChannel::Close() +AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { - CreateMessage(AgentMessageType::None, 0); - FlushMessage(); + if (m_TimeoutTimer) + { + m_TimeoutTimer->cancel(); + } } -void -AgentMessageChannel::Ping() +// --- Message building helpers --- + +std::vector<uint8_t> +AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { - CreateMessage(AgentMessageType::Ping, 0); - FlushMessage(); + std::vector<uint8_t> Buf; + Buf.reserve(MessageHeaderLength + ReservePayload); + Buf.push_back(static_cast<uint8_t>(Type)); + Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder + return Buf; } void -AgentMessageChannel::Fork(int ChannelId, int BufferSize) +AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg) { - CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); - WriteInt32(ChannelId); - WriteInt32(BufferSize); - FlushMessage(); + const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength); + memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); + m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void -AgentMessageChannel::Attach() +AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value) { - CreateMessage(AgentMessageType::Attach, 0); - FlushMessage(); + const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value); + Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } -void -AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +int +AsyncAgentMessageChannel::ReadInt32(const uint8_t** Pos) { - CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); - WriteString(Path); - WriteString(Locator); - FlushMessage(); + int Value; + memcpy(&Value, *Pos, sizeof(int)); + *Pos += sizeof(int); + return Value; } void -AgentMessageChannel::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - ExecuteProcessFlags Flags) +AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length) { - size_t RequiredSize = 50 + strlen(Exe); - for (size_t i = 0; i < NumArgs; ++i) - { - RequiredSize += strlen(Args[i]) + 10; - } - if (WorkingDir) - { - RequiredSize += strlen(WorkingDir) + 10; - } - for (size_t i = 0; i < NumEnvVars; ++i) - { - RequiredSize += strlen(EnvVars[i]) + 20; - } - - CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); - WriteString(Exe); - - WriteUnsignedVarInt(NumArgs); - for (size_t i = 0; i < NumArgs; ++i) - { - WriteString(Args[i]); - } - - WriteOptionalString(WorkingDir); - - // ExecuteV2 protocol requires env vars as separate key/value pairs. - // Callers pass "KEY=VALUE" strings; we split on the first '=' here. - WriteUnsignedVarInt(NumEnvVars); - for (size_t i = 0; i < NumEnvVars; ++i) - { - const char* Eq = strchr(EnvVars[i], '='); - assert(Eq != nullptr); - - WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); - if (*(Eq + 1) == '\0') - { - WriteOptionalString(nullptr); - } - else - { - WriteOptionalString(Eq + 1); - } - } + Buf.insert(Buf.end(), Data, Data + Length); +} - WriteInt32(static_cast<int>(Flags)); - FlushMessage(); +const uint8_t* +AsyncAgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +{ + const uint8_t* Data = *Pos; + *Pos += Length; + return Data; } -void -AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +size_t +AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) { - // Blob responses are chunked to fit within the compute buffer's chunk size. - // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). - const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; - for (size_t ChunkOffset = 0; ChunkOffset < Length;) + if (Value == 0) { - const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); - - CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); - WriteInt32(static_cast<int>(ChunkOffset)); - WriteInt32(static_cast<int>(Length)); - WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); - FlushMessage(); - - ChunkOffset += ChunkLength; + return 1; } + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } -AgentMessageType -AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +void +AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value) { - // Deferred advance: the previous response's buffer is only released when the next - // ReadResponse is called. This allows callers to read response data between calls - // without copying, since the pointer comes directly from the ring buffer. - if (m_ResponseData) - { - m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); - m_ResponseData = nullptr; - m_ResponseLength = 0; - } + const size_t ByteCount = MeasureUnsignedVarInt(Value); + const size_t StartPos = Buf.size(); + Buf.resize(StartPos + ByteCount); - const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); - if (!Header) + uint8_t* Output = Buf.data() + StartPos; + for (size_t i = 1; i < ByteCount; ++i) { - return AgentMessageType::None; + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); +} - uint32_t Length; - memcpy(&Length, Header + 1, sizeof(uint32_t)); +size_t +AsyncAgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +{ + const uint8_t* Data = *Pos; + const uint8_t FirstByte = Data[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; - Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); - if (!Header) + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) { - return AgentMessageType::None; + Value <<= 8; + Value |= Data[i]; } - m_ResponseType = static_cast<AgentMessageType>(Header[0]); - m_ResponseData = Header + MessageHeaderLength; - m_ResponseLength = Length; - - return m_ResponseType; + *Pos += NumBytes; + return Value; } void -AgentMessageChannel::ReadException(ExceptionInfo& Ex) +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text) { - assert(m_ResponseType == AgentMessageType::Exception); - const uint8_t* Pos = m_ResponseData; - Ex.Message = ReadString(&Pos); - Ex.Description = ReadString(&Pos); + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); } -int -AgentMessageChannel::ReadExecuteResult() +void +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text) { - assert(m_ResponseType == AgentMessageType::ExecuteResult); - const uint8_t* Pos = m_ResponseData; - return ReadInt32(&Pos); + WriteUnsignedVarInt(Buf, Text.size()); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); } -void -AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +std::string_view +AsyncAgentMessageChannel::ReadString(const uint8_t** Pos) { - assert(m_ResponseType == AgentMessageType::ReadBlob); - const uint8_t* Pos = m_ResponseData; - Req.Locator = ReadString(&Pos); - Req.Offset = ReadUnsignedVarInt(&Pos); - Req.Length = ReadUnsignedVarInt(&Pos); + const size_t Length = ReadUnsignedVarInt(Pos); + const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); + return std::string_view(Start, Length); } void -AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text) { - m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); - m_RequestData[0] = static_cast<uint8_t>(Type); - m_MaxRequestSize = MaxLength; - m_RequestSize = 0; + if (!Text) + { + WriteUnsignedVarInt(Buf, 0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length + 1); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); + } } +// --- Send methods --- + void -AgentMessageChannel::FlushMessage() +AsyncAgentMessageChannel::Close() { - const uint32_t Size = static_cast<uint32_t>(m_RequestSize); - memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); - m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); - m_RequestSize = 0; - m_MaxRequestSize = 0; - m_RequestData = nullptr; + auto Msg = BeginMessage(AgentMessageType::None, 0); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteInt32(int Value) +AsyncAgentMessageChannel::Ping() { - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); + auto Msg = BeginMessage(AgentMessageType::Ping, 0); + FinalizeAndSend(std::move(Msg)); } -int -AgentMessageChannel::ReadInt32(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { - int Value; - memcpy(&Value, *Pos, sizeof(int)); - *Pos += sizeof(int); - return Value; + auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(Msg, ChannelId); + WriteInt32(Msg, BufferSize); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::Attach() { - assert(m_RequestSize + Length <= m_MaxRequestSize); - memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); - m_RequestSize += Length; + auto Msg = BeginMessage(AgentMessageType::Attach, 0); + FinalizeAndSend(std::move(Msg)); } -const uint8_t* -AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +void +AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { - const uint8_t* Data = *Pos; - *Pos += Length; - return Data; + auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Msg, Path); + WriteString(Msg, Locator); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +void +AsyncAgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) { - if (Value == 0) + size_t ReserveSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) { - return 1; + ReserveSize += strlen(Args[i]) + 10; } - return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; + if (WorkingDir) + { + ReserveSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + ReserveSize += strlen(EnvVars[i]) + 20; + } + + auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); + WriteString(Msg, Exe); + + WriteUnsignedVarInt(Msg, NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Msg, Args[i]); + } + + WriteOptionalString(Msg, WorkingDir); + + WriteUnsignedVarInt(Msg, NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(Msg, nullptr); + } + else + { + WriteOptionalString(Msg, Eq + 1); + } + } + + WriteInt32(Msg, static_cast<int>(Flags)); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteUnsignedVarInt(size_t Value) +AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { - const size_t ByteCount = MeasureUnsignedVarInt(Value); - assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + static constexpr size_t MaxBlobChunkSize = 512 * 1024; - uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; - for (size_t i = 1; i < ByteCount; ++i) + for (size_t ChunkOffset = 0; ChunkOffset < Length;) { - Output[ByteCount - i] = static_cast<uint8_t>(Value); - Value >>= 8; - } - Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); - m_RequestSize += ByteCount; + auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(Msg, static_cast<int>(ChunkOffset)); + WriteInt32(Msg, static_cast<int>(Length)); + WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); + FinalizeAndSend(std::move(Msg)); + + ChunkOffset += ChunkLength; + } } -size_t -AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +// --- Async response reading --- + +void +AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { - const uint8_t* Data = *Pos; - const uint8_t FirstByte = Data[0]; - const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + // If frames are already queued, dispatch immediately + if (!m_IncomingFrames.empty()) + { + std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front()); + m_IncomingFrames.pop_front(); - size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); - for (size_t i = 1; i < NumBytes; ++i) + if (Frame.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]); + const uint8_t* Data = Frame.data() + MessageHeaderLength; + size_t Size = Frame.size() - MessageHeaderLength; + asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { + // The Frame is captured to keep Data pointer valid + Handler(Type, Data, Size); + }); + } + else + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + } + return; + } + + if (m_Detached) { - Value <<= 8; - Value |= Data[i]; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + return; } - *Pos += NumBytes; - return Value; + // No frames queued — store pending handler and arm timeout + m_PendingHandler = std::move(Handler); + + if (TimeoutMs >= 0) + { + m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); + m_TimeoutTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; // Cancelled — frame arrived before timeout + } + + if (m_PendingHandler) + { + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } + }); + } } -size_t -AgentMessageChannel::MeasureString(const char* Text) const +void +AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data) { - const size_t Length = strlen(Text); - return MeasureUnsignedVarInt(Length) + Length; + if (m_PendingHandler) + { + // Cancel the timeout timer + m_TimeoutTimer->cancel(); + + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + + if (Data.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Data[0]); + const uint8_t* Payload = Data.data() + MessageHeaderLength; + size_t PayloadSize = Data.size() - MessageHeaderLength; + Handler(Type, Payload, PayloadSize); + } + else + { + Handler(AgentMessageType::None, nullptr, 0); + } + } + else + { + m_IncomingFrames.push_back(std::move(Data)); + } } void -AgentMessageChannel::WriteString(const char* Text) +AsyncAgentMessageChannel::OnDetach() { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + m_Detached = true; + + if (m_PendingHandler) + { + m_TimeoutTimer->cancel(); + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } } +// --- Response parsing helpers --- + void -AgentMessageChannel::WriteString(std::string_view Text) +AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t /*Size*/, ExceptionInfo& Ex) { - WriteUnsignedVarInt(Text.size()); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + const uint8_t* Pos = Data; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); } -std::string_view -AgentMessageChannel::ReadString(const uint8_t** Pos) +int +AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t /*Size*/) { - const size_t Length = ReadUnsignedVarInt(Pos); - const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); - return std::string_view(Start, Length); + const uint8_t* Pos = Data; + return ReadInt32(&Pos); } void -AgentMessageChannel::WriteOptionalString(const char* Text) +AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t /*Size*/, BlobRequest& Req) { - // Optional strings use length+1 encoding: 0 means null/absent, - // N>0 means a string of length N-1 follows. This matches the UE - // FAgentMessageChannel serialization convention. - if (!Text) - { - WriteUnsignedVarInt(0); - } - else - { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length + 1); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); - } + const uint8_t* Pos = Data; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); } } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h index 38c4375fd..0068fb468 100644 --- a/src/zenhorde/hordeagentmessage.h +++ b/src/zenhorde/hordeagentmessage.h @@ -4,14 +4,22 @@ #include <zenbase/zenbase.h> -#include "hordecomputechannel.h" +#include "hordecomputesocket.h" #include <cstddef> #include <cstdint> +#include <deque> +#include <functional> +#include <memory> #include <string> #include <string_view> +#include <system_error> #include <vector> +namespace asio { +class io_context; +} // namespace asio + namespace zen::horde { /** Agent message types matching the UE EAgentMessageType byte values. @@ -55,45 +63,34 @@ struct BlobRequest size_t Length = 0; }; -/** Channel for sending and receiving agent messages over a ComputeChannel. +/** Handler for async response reads. Receives the message type and a view of the payload data. + * The payload vector is valid until the next AsyncReadResponse call. */ +using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>; + +/** Async channel for sending and receiving agent messages over an AsyncComputeSocket. * - * Implements the Horde agent message protocol, matching the UE - * FAgentMessageChannel serialization format exactly. Messages are framed as - * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8; - * integers use variable-length encoding. + * Send methods build messages into vectors and submit them via AsyncComputeSocket. + * Receives are delivered via the socket's FrameHandler callback and queued internally. + * AsyncReadResponse checks the queue and invokes the handler, with optional timeout. * - * The protocol has two directions: - * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob - * - Responses (remote -> initiator): ReadResponse returns the type, then call the - * appropriate Read* method to parse the payload. + * All operations must be externally serialized (e.g. via the socket's strand). */ -class AgentMessageChannel +class AsyncAgentMessageChannel { public: - explicit AgentMessageChannel(Ref<ComputeChannel> Channel); - ~AgentMessageChannel(); + AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext); + ~AsyncAgentMessageChannel(); - AgentMessageChannel(const AgentMessageChannel&) = delete; - AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete; + AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete; - // --- Requests (Initiator -> Remote) --- + // --- Requests (fire-and-forget sends) --- - /** Close the channel. */ void Close(); - - /** Send a keepalive ping. */ void Ping(); - - /** Fork communication to a new channel with the given ID and buffer size. */ void Fork(int ChannelId, int BufferSize); - - /** Send an attach request (used during channel setup handshake). */ void Attach(); - - /** Request the remote agent to write files from the given bundle locator. */ void UploadFiles(const char* Path, const char* Locator); - - /** Execute a process on the remote machine. */ void Execute(const char* Exe, const char* const* Args, size_t NumArgs, @@ -101,61 +98,61 @@ public: const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags = ExecuteProcessFlags::None); - - /** Send blob data in response to a ReadBlob request. */ void Blob(const uint8_t* Data, size_t Length); - // --- Responses (Remote -> Initiator) --- - - /** Read the next response message. Returns the message type, or None on timeout. - * After this returns, use GetResponseData()/GetResponseSize() or the typed - * Read* methods to access the payload. */ - AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + // --- Async response reading --- - const void* GetResponseData() const { return m_ResponseData; } - size_t GetResponseSize() const { return m_ResponseLength; } + /** Read the next response. If a frame is already queued, the handler is posted immediately. + * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler + * with AgentMessageType::None. */ + void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler); - /** Parse an Exception response payload. */ - void ReadException(ExceptionInfo& Ex); + /** Called by the socket's FrameHandler when a frame arrives for this channel. */ + void OnFrame(std::vector<uint8_t> Data); - /** Parse an ExecuteResult response payload. Returns the exit code. */ - int ReadExecuteResult(); + /** Called by the socket's DetachHandler. */ + void OnDetach(); - /** Parse a ReadBlob response payload into a BlobRequest. */ - void ReadBlobRequest(BlobRequest& Req); - -private: - static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + /** Returns true if the channel has been detached (connection lost). */ + bool IsDetached() const { return m_Detached; } - Ref<ComputeChannel> m_Channel; + // --- Response parsing helpers --- - uint8_t* m_RequestData = nullptr; - size_t m_RequestSize = 0; - size_t m_MaxRequestSize = 0; + static void ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex); + static int ReadExecuteResult(const uint8_t* Data, size_t Size); + static void ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req); - AgentMessageType m_ResponseType = AgentMessageType::None; - const uint8_t* m_ResponseData = nullptr; - size_t m_ResponseLength = 0; +private: + static constexpr size_t MessageHeaderLength = 5; - void CreateMessage(AgentMessageType Type, size_t MaxLength); - void FlushMessage(); + // Message building helpers + std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload); + void FinalizeAndSend(std::vector<uint8_t> Msg); - void WriteInt32(int Value); - static int ReadInt32(const uint8_t** Pos); + static void WriteInt32(std::vector<uint8_t>& Buf, int Value); + static int ReadInt32(const uint8_t** Pos); - void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); + static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length); static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); static size_t MeasureUnsignedVarInt(size_t Value); - void WriteUnsignedVarInt(size_t Value); + static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value); static size_t ReadUnsignedVarInt(const uint8_t** Pos); - size_t MeasureString(const char* Text) const; - void WriteString(const char* Text); - void WriteString(std::string_view Text); + static void WriteString(std::vector<uint8_t>& Buf, const char* Text); + static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text); static std::string_view ReadString(const uint8_t** Pos); - void WriteOptionalString(const char* Text); + static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text); + + std::shared_ptr<AsyncComputeSocket> m_Socket; + int m_ChannelId; + asio::io_context& m_IoContext; + + std::deque<std::vector<uint8_t>> m_IncomingFrames; + AsyncResponseHandler m_PendingHandler; + std::unique_ptr<asio::steady_timer> m_TimeoutTimer; + bool m_Detached = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp index d3974bc28..af6b97e59 100644 --- a/src/zenhorde/hordebundle.cpp +++ b/src/zenhorde/hordebundle.cpp @@ -57,7 +57,7 @@ MeasureVarInt(size_t Value) { return 1; } - return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1; + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } static void diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp index 0eefc57c6..618a85e0e 100644 --- a/src/zenhorde/hordeclient.cpp +++ b/src/zenhorde/hordeclient.cpp @@ -4,6 +4,7 @@ #include <zencore/iobuffer.h> #include <zencore/logging.h> #include <zencore/memoryview.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhorde/hordeclient.h> #include <zenhttp/httpclient.h> @@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client")) +HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client") { } @@ -32,7 +33,11 @@ HordeClient::Initialize() Settings.RetryCount = 1; Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests}; - if (!m_Config.AuthToken.empty()) + if (m_Config.AccessTokenProvider) + { + Settings.AccessTokenProvider = m_Config.AccessTokenProvider; + } + else if (!m_Config.AuthToken.empty()) { Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken { return HttpClientAccessToken(token, HttpClientAccessToken::Clock::now() + std::chrono::hours{24}); @@ -41,7 +46,7 @@ HordeClient::Initialize() m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings); - if (!m_Config.AuthToken.empty()) + if (Settings.AccessTokenProvider) { if (!m_Http->Authenticate()) { @@ -63,24 +68,21 @@ HordeClient::BuildRequestBody() const Requirements["pool"] = m_Config.Pool; } - std::string Condition; -#if ZEN_PLATFORM_WINDOWS ExtendableStringBuilder<256> CondBuf; +#if ZEN_PLATFORM_WINDOWS CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')"; - Condition = std::string(CondBuf); #elif ZEN_PLATFORM_MAC - Condition = "OSFamily == 'MacOS'"; + CondBuf << "OSFamily == 'MacOS'"; #else - Condition = "OSFamily == 'Linux'"; + CondBuf << "OSFamily == 'Linux'"; #endif if (!m_Config.Condition.empty()) { - Condition += " "; - Condition += m_Config.Condition; + CondBuf << " " << m_Config.Condition; } - Requirements["condition"] = Condition; + Requirements["condition"] = std::string(CondBuf); Requirements["exclusive"] = true; json11::Json::object Connection; @@ -157,37 +159,8 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus } OutCluster.ClusterId = ClusterIdVal.string_value(); - return true; -} - -bool -HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize) -{ - if (Hex.size() != OutSize * 2) - { - return false; - } - for (size_t i = 0; i < OutSize; ++i) - { - auto HexToByte = [](char c) -> int { - if (c >= '0' && c <= '9') - return c - '0'; - if (c >= 'a' && c <= 'f') - return c - 'a' + 10; - if (c >= 'A' && c <= 'F') - return c - 'A' + 10; - return -1; - }; - - const int Hi = HexToByte(Hex[i * 2]); - const int Lo = HexToByte(Hex[i * 2 + 1]); - if (Hi < 0 || Lo < 0) - { - return false; - } - Out[i] = static_cast<uint8_t>((Hi << 4) | Lo); - } + ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId); return true; } @@ -197,8 +170,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C { ZEN_TRACE_CPU("HordeClient::RequestMachine"); - ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str()); - ExtendableStringBuilder<128> ResourcePath; ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str()); @@ -324,6 +295,10 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C { PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14)); } + else if (Prop.starts_with("Pool=")) + { + OutMachine.Pool = Prop.substr(5); + } } } @@ -367,10 +342,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C OutMachine.LeaseId = LeaseIdVal.string_value(); } - ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}", + ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}", OutMachine.GetConnectionAddress(), OutMachine.GetConnectionPort(), + ToString(OutMachine.Mode), OutMachine.LogicalCores, + OutMachine.Pool, OutMachine.LeaseId); return true; diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp deleted file mode 100644 index 0d032b5d5..000000000 --- a/src/zenhorde/hordecomputebuffer.cpp +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "hordecomputebuffer.h" - -#include <algorithm> -#include <cassert> -#include <chrono> -#include <condition_variable> -#include <cstring> - -namespace zen::horde { - -// Simplified ring buffer implementation for in-process use only. -// Uses a single contiguous buffer with write/read cursors and -// mutex+condvar for synchronization. This is simpler than the UE version -// which uses lock-free atomics and shared memory, but sufficient for our -// use case where we're the initiator side of the compute protocol. - -struct ComputeBuffer::Detail : TRefCounted<Detail> -{ - std::vector<uint8_t> Data; - size_t NumChunks = 0; - size_t ChunkLength = 0; - - // Current write state - size_t WriteChunkIdx = 0; - size_t WriteOffset = 0; - bool WriteComplete = false; - - // Current read state - size_t ReadChunkIdx = 0; - size_t ReadOffset = 0; - bool Detached = false; - - // Per-chunk written length - std::vector<size_t> ChunkWrittenLength; - std::vector<bool> ChunkFinished; // Writer moved to next chunk - - std::mutex Mutex; - std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes - std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space - - bool HasWriter = false; - bool HasReader = false; - - uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } - const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } -}; - -// ComputeBuffer - -ComputeBuffer::ComputeBuffer() -{ -} -ComputeBuffer::~ComputeBuffer() -{ -} - -bool -ComputeBuffer::CreateNew(const Params& InParams) -{ - auto* NewDetail = new Detail(); - NewDetail->NumChunks = InParams.NumChunks; - NewDetail->ChunkLength = InParams.ChunkLength; - NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); - NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); - NewDetail->ChunkFinished.resize(InParams.NumChunks, false); - - m_Detail = NewDetail; - return true; -} - -void -ComputeBuffer::Close() -{ - m_Detail = nullptr; -} - -bool -ComputeBuffer::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -ComputeBufferReader -ComputeBuffer::CreateReader() -{ - assert(m_Detail); - m_Detail->HasReader = true; - return ComputeBufferReader(m_Detail); -} - -ComputeBufferWriter -ComputeBuffer::CreateWriter() -{ - assert(m_Detail); - m_Detail->HasWriter = true; - return ComputeBufferWriter(m_Detail); -} - -// ComputeBufferReader - -ComputeBufferReader::ComputeBufferReader() -{ -} -ComputeBufferReader::~ComputeBufferReader() -{ -} - -ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; -ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; -ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; -ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; - -ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferReader::Close() -{ - m_Detail = nullptr; -} - -void -ComputeBufferReader::Detach() -{ - if (m_Detail) - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - m_Detail->Detached = true; - m_Detail->ReadCV.notify_all(); - } -} - -bool -ComputeBufferReader::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -bool -ComputeBufferReader::IsComplete() const -{ - if (!m_Detail) - { - return true; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - if (m_Detail->Detached) - { - return true; - } - return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && - m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; -} - -void -ComputeBufferReader::AdvanceReadPosition(size_t Size) -{ - if (!m_Detail) - { - return; - } - - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - - m_Detail->ReadOffset += Size; - - // Check if we need to move to next chunk - const size_t ReadChunk = m_Detail->ReadChunkIdx; - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - } - - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferReader::GetMaxReadSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t ReadChunk = m_Detail->ReadChunkIdx; - return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; -} - -const uint8_t* -ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock<std::mutex> Lock(m_Detail->Mutex); - - auto Predicate = [&]() -> bool { - if (m_Detail->Detached) - { - return true; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available >= MinSize) - { - return true; - } - - // If chunk is finished and we've read everything, try to move to next - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - // Move to next chunk - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - return false; // Re-check with new chunk - } - - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - - return false; - }; - - if (TimeoutMs < 0) - { - m_Detail->ReadCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - if (OutTimedOut) - { - *OutTimedOut = true; - } - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available < MinSize) - { - return nullptr; // End of stream - } - - return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; -} - -size_t -ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) -{ - const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); - if (!Data) - { - return 0; - } - - const size_t Available = GetMaxReadSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Buffer, Data, ToCopy); - AdvanceReadPosition(ToCopy); - return ToCopy; -} - -// ComputeBufferWriter - -ComputeBufferWriter::ComputeBufferWriter() = default; -ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; -ComputeBufferWriter::~ComputeBufferWriter() = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; - -ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferWriter::Close() -{ - if (m_Detail) - { - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - if (!m_Detail->WriteComplete) - { - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } - } - m_Detail = nullptr; - } -} - -bool -ComputeBufferWriter::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -void -ComputeBufferWriter::MarkComplete() -{ - if (m_Detail) - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } -} - -void -ComputeBufferWriter::AdvanceWritePosition(size_t Size) -{ - if (!m_Detail || Size == 0) - { - return; - } - - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - m_Detail->ChunkWrittenLength[WriteChunk] += Size; - m_Detail->WriteOffset += Size; - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferWriter::GetMaxWriteSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; -} - -size_t -ComputeBufferWriter::GetChunkMaxLength() const -{ - if (!m_Detail) - { - return 0; - } - return m_Detail->ChunkLength; -} - -size_t -ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) -{ - uint8_t* Dest = WaitToWrite(1, TimeoutMs); - if (!Dest) - { - return 0; - } - - const size_t Available = GetMaxWriteSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Dest, Buffer, ToCopy); - AdvanceWritePosition(ToCopy); - return ToCopy; -} - -uint8_t* -ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock<std::mutex> Lock(m_Detail->Mutex); - - if (m_Detail->WriteComplete) - { - return nullptr; - } - - const size_t WriteChunk = m_Detail->WriteChunkIdx; - const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; - - // If current chunk has enough space, return pointer - if (Available >= MinSize) - { - return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; - } - - // Current chunk is full - mark it as finished and move to next. - // The writer cannot advance until the reader has fully consumed the next chunk, - // preventing the writer from overwriting data the reader hasn't processed yet. - m_Detail->ChunkFinished[WriteChunk] = true; - m_Detail->ReadCV.notify_all(); - - const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; - - // Wait until reader has consumed the next chunk - auto Predicate = [&]() -> bool { - // Check if read has moved past this chunk - return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; - }; - - if (TimeoutMs < 0) - { - m_Detail->WriteCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - // Reset next chunk - m_Detail->ChunkWrittenLength[NextChunk] = 0; - m_Detail->ChunkFinished[NextChunk] = false; - m_Detail->WriteChunkIdx = NextChunk; - m_Detail->WriteOffset = 0; - - return m_Detail->ChunkPtr(NextChunk); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h deleted file mode 100644 index 64ef91b7a..000000000 --- a/src/zenhorde/hordecomputebuffer.h +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zenbase/refcount.h> - -#include <cstddef> -#include <cstdint> -#include <mutex> -#include <vector> - -namespace zen::horde { - -class ComputeBufferReader; -class ComputeBufferWriter; - -/** Simplified in-process ring buffer for the Horde compute protocol. - * - * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files, - * this implementation uses plain heap-allocated memory since we only need in-process - * communication between channel and transport threads. The buffer is divided into - * fixed-size chunks; readers and writers block when no space is available. - */ -class ComputeBuffer -{ -public: - struct Params - { - size_t NumChunks = 2; - size_t ChunkLength = 512 * 1024; - }; - - ComputeBuffer(); - ~ComputeBuffer(); - - ComputeBuffer(const ComputeBuffer&) = delete; - ComputeBuffer& operator=(const ComputeBuffer&) = delete; - - bool CreateNew(const Params& InParams); - void Close(); - - bool IsValid() const; - - ComputeBufferReader CreateReader(); - ComputeBufferWriter CreateWriter(); - -private: - struct Detail; - Ref<Detail> m_Detail; - - friend class ComputeBufferReader; - friend class ComputeBufferWriter; -}; - -/** Read endpoint for a ComputeBuffer. - * - * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceReadPosition() after consuming the data. - */ -class ComputeBufferReader -{ -public: - ComputeBufferReader(); - ComputeBufferReader(const ComputeBufferReader&); - ComputeBufferReader(ComputeBufferReader&&) noexcept; - ~ComputeBufferReader(); - - ComputeBufferReader& operator=(const ComputeBufferReader&); - ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; - - void Close(); - void Detach(); - bool IsValid() const; - bool IsComplete() const; - - void AdvanceReadPosition(size_t Size); - size_t GetMaxReadSize() const; - - /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ - size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - - /** Wait until at least MinSize bytes are available and return a direct pointer. - * Returns nullptr on timeout or if the writer has completed. */ - const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - -private: - friend class ComputeBuffer; - explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail); - - Ref<ComputeBuffer::Detail> m_Detail; -}; - -/** Write endpoint for a ComputeBuffer. - * - * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal - * that no more data will be written. - */ -class ComputeBufferWriter -{ -public: - ComputeBufferWriter(); - ComputeBufferWriter(const ComputeBufferWriter&); - ComputeBufferWriter(ComputeBufferWriter&&) noexcept; - ~ComputeBufferWriter(); - - ComputeBufferWriter& operator=(const ComputeBufferWriter&); - ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept; - - void Close(); - bool IsValid() const; - - /** Signal that no more data will be written. Unblocks any waiting readers. */ - void MarkComplete(); - - void AdvanceWritePosition(size_t Size); - size_t GetMaxWriteSize() const; - size_t GetChunkMaxLength() const; - - /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */ - size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1); - - /** Wait until at least MinSize bytes of write space are available and return a direct pointer. - * Returns nullptr on timeout. */ - uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1); - -private: - friend class ComputeBuffer; - explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail); - - Ref<ComputeBuffer::Detail> m_Detail; -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp deleted file mode 100644 index ee2a6f327..000000000 --- a/src/zenhorde/hordecomputechannel.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "hordecomputechannel.h" - -namespace zen::horde { - -ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) -: Reader(std::move(InReader)) -, Writer(std::move(InWriter)) -{ -} - -bool -ComputeChannel::IsValid() const -{ - return Reader.IsValid() && Writer.IsValid(); -} - -size_t -ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) -{ - return Writer.Write(Data, Size, TimeoutMs); -} - -size_t -ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) -{ - return Reader.Read(Data, Size, TimeoutMs); -} - -void -ComputeChannel::MarkComplete() -{ - Writer.MarkComplete(); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h deleted file mode 100644 index c1dff20e4..000000000 --- a/src/zenhorde/hordecomputechannel.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "hordecomputebuffer.h" - -namespace zen::horde { - -/** Bidirectional communication channel using a pair of compute buffers. - * - * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter - * (for sending data). Used by ComputeSocket to represent one logical channel - * within a multiplexed connection. - */ -class ComputeChannel : public TRefCounted<ComputeChannel> -{ -public: - ComputeBufferReader Reader; - ComputeBufferWriter Writer; - - ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); - - bool IsValid() const; - - size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); - size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); - - /** Signal that no more data will be sent on this channel. */ - void MarkComplete(); -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp index 6ef67760c..8a6fc40a9 100644 --- a/src/zenhorde/hordecomputesocket.cpp +++ b/src/zenhorde/hordecomputesocket.cpp @@ -6,198 +6,326 @@ namespace zen::horde { -ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) -: m_Log(zen::logging::Get("horde.socket")) +AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext) +: m_Log(zen::logging::Get("horde.socket.async")) , m_Transport(std::move(Transport)) +, m_Strand(asio::make_strand(IoContext)) +, m_PingTimer(m_Strand) { } -ComputeSocket::~ComputeSocket() +AsyncComputeSocket::~AsyncComputeSocket() { - // Shutdown order matters: first stop the ping thread, then unblock send threads - // by detaching readers, then join send threads, and finally close the transport - // to unblock the recv thread (which is blocked on RecvMessage). - { - std::lock_guard<std::mutex> Lock(m_PingMutex); - m_PingShouldStop = true; - m_PingCV.notify_all(); - } - - for (auto& Reader : m_Readers) - { - Reader.Detach(); - } - - for (auto& [Id, Thread] : m_SendThreads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } + Close(); +} - m_Transport->Close(); +void +AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach) +{ + m_FrameHandlers[ChannelId] = std::move(OnFrame); + m_DetachHandlers[ChannelId] = std::move(OnDetach); +} - if (m_RecvThread.joinable()) - { - m_RecvThread.join(); - } - if (m_PingThread.joinable()) - { - m_PingThread.join(); - } +void +AsyncComputeSocket::StartRecvPump() +{ + StartPingTimer(); + DoRecvHeader(); } -Ref<ComputeChannel> -ComputeSocket::CreateChannel(int ChannelId) +void +AsyncComputeSocket::DoRecvHeader() { - ComputeBuffer::Params Params; + auto Self = shared_from_this(); + m_Transport->AsyncRead(&m_RecvHeader, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv header error: {}", Ec.message()); + HandleError(); + } + return; + } - ComputeBuffer RecvBuffer; - if (!RecvBuffer.CreateNew(Params)) - { - return {}; - } + if (m_Closed) + { + return; + } - ComputeBuffer SendBuffer; - if (!SendBuffer.CreateNew(Params)) - { - return {}; - } + if (m_RecvHeader.Size >= 0) + { + DoRecvPayload(m_RecvHeader); + } + else if (m_RecvHeader.Size == ControlDetach) + { + if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second) + { + It->second(); + } + DoRecvHeader(); + } + else if (m_RecvHeader.Size == ControlPing) + { + DoRecvHeader(); + } + else + { + ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size); + } + })); +} - Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); +void +AsyncComputeSocket::DoRecvPayload(FrameHeader Header) +{ + auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size)); + auto Self = shared_from_this(); - // Attach recv buffer writer (transport recv thread writes into this) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); - } + m_Transport->AsyncRead(PayloadBuf->data(), + PayloadBuf->size(), + asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message()); + HandleError(); + } + return; + } - // Attach send buffer reader (send thread reads from this) - { - ComputeBufferReader Reader = SendBuffer.CreateReader(); - m_Readers.push_back(Reader); - m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); - } + if (m_Closed) + { + return; + } + + if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second) + { + It->second(std::move(*PayloadBuf)); + } + else + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + } - return Channel; + DoRecvHeader(); + })); } void -ComputeSocket::StartCommunication() +AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler) { - m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); - m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable { + if (m_Closed) + { + if (Handler) + { + Handler(asio::error::make_error_code(asio::error::operation_aborted)); + } + return; + } + + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = static_cast<int32_t>(Data.size()); + Write.Data = std::move(Data); + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::PingThreadProc() +AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler) { - while (true) - { + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable { + if (m_Closed) { - std::unique_lock<std::mutex> Lock(m_PingMutex); - if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + if (Handler) { - break; + Handler(asio::error::make_error_code(asio::error::operation_aborted)); } + return; } - std::lock_guard<std::mutex> Lock(m_SendMutex); - FrameHeader Header; - Header.Channel = 0; - Header.Size = ControlPing; - m_Transport->SendMessage(&Header, sizeof(Header)); - } + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = ControlDetach; + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::RecvThreadProc() +AsyncComputeSocket::FlushNextSend() { - // Writers are cached locally to avoid taking m_WritersMutex on every frame. - // The shared m_Writers map is only accessed when a channel is seen for the first time. - std::unordered_map<int, ComputeBufferWriter> CachedWriters; + if (m_SendQueue.empty() || m_Closed) + { + return; + } - FrameHeader Header; - while (m_Transport->RecvMessage(&Header, sizeof(Header))) + PendingWrite& Front = m_SendQueue.front(); + + if (Front.Data.empty()) { - if (Header.Size >= 0) - { - // Data frame - auto It = CachedWriters.find(Header.Channel); - if (It == CachedWriters.end()) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto WIt = m_Writers.find(Header.Channel); - if (WIt == m_Writers.end()) - { - ZEN_WARN("recv frame for unknown channel {}", Header.Channel); - // Skip the data - std::vector<uint8_t> Discard(Header.Size); - m_Transport->RecvMessage(Discard.data(), Header.Size); - continue; - } - It = CachedWriters.emplace(Header.Channel, WIt->second).first; - } + // Control frame — header only + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); - ComputeBufferWriter& Writer = It->second; - uint8_t* Dest = Writer.WaitToWrite(Header.Size); - if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) - { - ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); - return; - } - Writer.AdvanceWritePosition(Header.Size); - } - else if (Header.Size == ControlDetach) - { - // Detach the recv buffer for this channel - CachedWriters.erase(Header.Channel); + if (Handler) + { + Handler(Ec); + } - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto It = m_Writers.find(Header.Channel); - if (It != m_Writers.end()) - { - It->second.MarkComplete(); - m_Writers.erase(It); - } - } - else if (Header.Size == ControlPing) + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + } + else + { + // Data frame — write header first, then payload + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + if (Handler) + { + Handler(Ec); + } + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send header error: {}", Ec.message()); + HandleError(); + } + return; + } + + PendingWrite& Payload = m_SendQueue.front(); + m_Transport->AsyncWrite( + Payload.Data.data(), + Payload.Data.size(), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + + if (Handler) + { + Handler(Ec); + } + + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send payload error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + })); + } +} + +void +AsyncComputeSocket::StartPingTimer() +{ + if (m_Closed) + { + return; + } + + m_PingTimer.expires_after(std::chrono::seconds(2)); + + auto Self = shared_from_this(); + m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) { + if (Ec || m_Closed) { - // Ping response - ignore + return; } - else + + // Enqueue a ping control frame + PendingWrite Write; + Write.Header.Channel = 0; + Write.Header.Size = ControlPing; + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) { - ZEN_WARN("invalid frame header size: {}", Header.Size); - return; + FlushNextSend(); } - } + + StartPingTimer(); + })); } void -ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +AsyncComputeSocket::HandleError() { - // Each channel has its own send thread. All send threads share m_SendMutex - // to serialize writes to the transport, since TCP requires atomic frame writes. - FrameHeader Header; - Header.Channel = Channel; + if (m_Closed) + { + return; + } + + Close(); - const uint8_t* Data; - while ((Data = Reader.WaitToRead(1)) != nullptr) + // Notify all channels that the connection is gone so agents can clean up + for (auto& [ChannelId, Handler] : m_DetachHandlers) { - std::lock_guard<std::mutex> Lock(m_SendMutex); + if (Handler) + { + Handler(); + } + } +} - Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); - m_Transport->SendMessage(&Header, sizeof(Header)); - m_Transport->SendMessage(Data, Header.Size); - Reader.AdvanceReadPosition(Header.Size); +void +AsyncComputeSocket::Close() +{ + if (m_Closed) + { + return; } - if (Reader.IsComplete()) + m_Closed = true; + m_PingTimer.cancel(); + + if (m_Transport) { - std::lock_guard<std::mutex> Lock(m_SendMutex); - Header.Size = ControlDetach; - m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->Close(); } } diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h index 0c3cb4195..45b3418b7 100644 --- a/src/zenhorde/hordecomputesocket.h +++ b/src/zenhorde/hordecomputesocket.h @@ -2,45 +2,69 @@ #pragma once -#include "hordecomputebuffer.h" -#include "hordecomputechannel.h" #include "hordetransport.h" #include <zencore/logbase.h> -#include <condition_variable> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +#include <deque> +#include <functional> #include <memory> -#include <mutex> -#include <thread> +#include <system_error> #include <unordered_map> #include <vector> namespace zen::horde { -/** Multiplexed socket that routes data between multiple channels over a single transport. +class AsyncComputeTransport; + +/** Handler called when a data frame arrives for a channel. */ +using FrameHandler = std::function<void(std::vector<uint8_t> Data)>; + +/** Handler called when a channel is detached by the remote peer. */ +using DetachHandler = std::function<void()>; + +/** Handler for async send completion. */ +using SendHandler = std::function<void(const std::error_code&)>; + +/** Async multiplexed socket that routes data between channels over a single transport. * - * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. - * A recv thread demultiplexes incoming frames to channel-specific buffers, while - * per-channel send threads multiplex outgoing data onto the shared transport. + * Uses an async recv pump, a serialized send queue, and a periodic ping timer — + * all running on a shared io_context. * - * Wire format per frame: [channelId (4B)][size (4B)][data] - * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. + * Wire format per frame: [channelId(4B)][size(4B)][data]. + * Control messages use negative sizes: -2 = detach, -3 = ping. */ -class ComputeSocket +class AsyncComputeSocket : public std::enable_shared_from_this<AsyncComputeSocket> { public: - explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport); - ~ComputeSocket(); + AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext); + ~AsyncComputeSocket(); + + AsyncComputeSocket(const AsyncComputeSocket&) = delete; + AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete; - ComputeSocket(const ComputeSocket&) = delete; - ComputeSocket& operator=(const ComputeSocket&) = delete; + /** Register callbacks for a channel. Must be called before StartRecvPump(). */ + void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach); - /** Create a channel with the given ID. - * Allocates anonymous in-process buffers and spawns a send thread for the channel. */ - Ref<ComputeChannel> CreateChannel(int ChannelId); + /** Begin the async recv pump and ping timer. */ + void StartRecvPump(); - /** Start the recv pump and ping threads. Must be called after all channels are created. */ - void StartCommunication(); + /** Enqueue a data frame for async transmission. */ + void AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler = {}); + + /** Send a control frame (detach) for a channel. */ + void AsyncSendDetach(int ChannelId, SendHandler Handler = {}); + + /** Close the transport and cancel all pending operations. */ + void Close(); private: struct FrameHeader @@ -49,31 +73,35 @@ private: int32_t Size = 0; }; + struct PendingWrite + { + FrameHeader Header; + std::vector<uint8_t> Data; + SendHandler Handler; + }; + static constexpr int32_t ControlDetach = -2; static constexpr int32_t ControlPing = -3; LoggerRef Log() { return m_Log; } - void RecvThreadProc(); - void SendThreadProc(int Channel, ComputeBufferReader Reader); - void PingThreadProc(); - - LoggerRef m_Log; - std::unique_ptr<ComputeTransport> m_Transport; - std::mutex m_SendMutex; ///< Serializes writes to the transport - - std::mutex m_WritersMutex; - std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID + void DoRecvHeader(); + void DoRecvPayload(FrameHeader Header); + void FlushNextSend(); + void StartPingTimer(); + void HandleError(); - std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction - std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel + LoggerRef m_Log; + std::unique_ptr<AsyncComputeTransport> m_Transport; + asio::strand<asio::any_io_executor> m_Strand; + asio::steady_timer m_PingTimer; - std::thread m_RecvThread; - std::thread m_PingThread; + std::unordered_map<int, FrameHandler> m_FrameHandlers; + std::unordered_map<int, DetachHandler> m_DetachHandlers; - bool m_PingShouldStop = false; - std::mutex m_PingMutex; - std::condition_variable m_PingCV; + FrameHeader m_RecvHeader; + std::deque<PendingWrite> m_SendQueue; + bool m_Closed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp index 2dca228d9..9f6125c64 100644 --- a/src/zenhorde/hordeconfig.cpp +++ b/src/zenhorde/hordeconfig.cpp @@ -1,5 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <zencore/logging.h> +#include <zencore/string.h> #include <zenhorde/hordeconfig.h> namespace zen::horde { @@ -9,12 +11,14 @@ HordeConfig::Validate() const { if (ServerUrl.empty()) { + ZEN_WARN("Horde server URL is not configured"); return false; } // Relay mode implies AES encryption if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) { + ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode)); return false; } @@ -52,37 +56,39 @@ ToString(Encryption Enc) bool FromString(ConnectionMode& OutMode, std::string_view Str) { - if (Str == "direct") + if (StrCaseCompare(Str, "direct") == 0) { OutMode = ConnectionMode::Direct; return true; } - if (Str == "tunnel") + if (StrCaseCompare(Str, "tunnel") == 0) { OutMode = ConnectionMode::Tunnel; return true; } - if (Str == "relay") + if (StrCaseCompare(Str, "relay") == 0) { OutMode = ConnectionMode::Relay; return true; } + ZEN_WARN("unrecognized Horde connection mode: '{}'", Str); return false; } bool FromString(Encryption& OutEnc, std::string_view Str) { - if (Str == "none") + if (StrCaseCompare(Str, "none") == 0) { OutEnc = Encryption::None; return true; } - if (Str == "aes") + if (StrCaseCompare(Str, "aes") == 0) { OutEnc = Encryption::AES; return true; } + ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str); return false; } diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp index f88c95da2..b08544d1a 100644 --- a/src/zenhorde/hordeprovisioner.cpp +++ b/src/zenhorde/hordeprovisioner.cpp @@ -6,49 +6,82 @@ #include "hordeagent.h" #include "hordebundle.h" +#include <zencore/compactbinary.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/scopeguard.h> #include <zencore/thread.h> #include <zencore/trace.h> +#include <zenhttp/httpclient.h> +#include <zenutil/workerpools.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> #include <chrono> #include <thread> namespace zen::horde { -struct HordeProvisioner::AgentWrapper -{ - std::thread Thread; - std::atomic<bool> ShouldExit{false}; -}; - HordeProvisioner::HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint) + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_BinariesPath(BinariesPath) , m_WorkingDir(WorkingDir) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_Log(zen::logging::Get("horde.provisioner")) { + m_IoContext = std::make_unique<asio::io_context>(); + + auto Work = asio::make_work_guard(*m_IoContext); + for (int i = 0; i < IoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i, Work] { + zen::SetCurrentThreadName(fmt::format("horde_io_{}", i)); + m_IoContext->run(); + }); + } } HordeProvisioner::~HordeProvisioner() { - std::lock_guard<std::mutex> Lock(m_AgentsLock); - for (auto& Agent : m_Agents) + m_AskForAgents.store(false); + + // Shut down async agents and io_context { - Agent->ShouldExit.store(true); + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (auto& Entry : m_AsyncAgents) + { + Entry.Agent->Cancel(); + } + m_AsyncAgents.clear(); } - for (auto& Agent : m_Agents) + + m_IoContext->stop(); + + for (auto& Thread : m_IoThreads) { - if (Agent->Thread.joinable()) + if (Thread.joinable()) { - Agent->Thread.join(); + Thread.join(); } } + + // Wait for all pool work items to finish before destroying members they reference + if (m_PendingWorkItems.load() > 0) + { + m_AllWorkDone.Wait(); + } } void @@ -56,9 +89,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) { ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); - m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores))); + const uint32_t ClampedCount = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)); + const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount); + + if (ClampedCount != PreviousTarget) + { + ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})", + PreviousTarget, + ClampedCount, + m_ActiveCoreCount.load(), + m_EstimatedCoreCount.load()); + } - while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + // Only provision if the gap is at least one agent-sized chunk. Without + // this, draining a 32-core agent to cover a 28-core excess would leave a + // 4-core gap that triggers a 32-core provision, which triggers another + // drain, ad infinitum. + while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load()) { if (!m_AskForAgents.load()) { @@ -67,21 +114,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) RequestAgent(); } - // Clean up finished agent threads - std::lock_guard<std::mutex> Lock(m_AgentsLock); - for (auto It = m_Agents.begin(); It != m_Agents.end();) + // Scale down async agents { - if ((*It)->ShouldExit.load()) + std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock); + + uint32_t AsyncActive = m_ActiveCoreCount.load(); + uint32_t AsyncTarget = m_TargetCoreCount.load(); + + uint32_t AlreadyDrainingCores = 0; + for (const auto& Entry : m_AsyncAgents) { - if ((*It)->Thread.joinable()) + if (Entry.Draining) { - (*It)->Thread.join(); + AlreadyDrainingCores += Entry.CoreCount; } - It = m_Agents.erase(It); } - else + + uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? AsyncActive - AlreadyDrainingCores : 0; + + if (EffectiveAsync > AsyncTarget) { - ++It; + struct Candidate + { + AsyncAgentEntry* Entry; + int Workload; + }; + std::vector<Candidate> Candidates; + + for (auto& Entry : m_AsyncAgents) + { + if (Entry.Draining || Entry.RemoteEndpoint.empty()) + { + continue; + } + + int Workload = 0; + bool Reachable = false; + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{2000}; + Settings.Timeout = std::chrono::milliseconds{3000}; + try + { + HttpClient Client(Entry.RemoteEndpoint, Settings); + HttpClient::Response Resp = Client.Get("/compute/session/status"); + if (Resp.IsSuccess()) + { + CbObject Status = Resp.AsObject(); + Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0); + Reachable = true; + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what()); + } + + if (Reachable) + { + Candidates.push_back({&Entry, Workload}); + } + } + + const uint32_t ExcessCores = EffectiveAsync - AsyncTarget; + uint32_t CoresDrained = 0; + + while (CoresDrained < ExcessCores && !Candidates.empty()) + { + const uint32_t Remaining = ExcessCores - CoresDrained; + + Candidates.erase(std::remove_if(Candidates.begin(), + Candidates.end(), + [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }), + Candidates.end()); + + if (Candidates.empty()) + { + break; + } + + Candidate* Best = &Candidates[0]; + for (auto& C : Candidates) + { + if (C.Entry->CoreCount > Best->Entry->CoreCount || + (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload)) + { + Best = &C; + } + } + + ZEN_INFO("draining async agent lease={} ({} cores, workload={})", + Best->Entry->LeaseId, + Best->Entry->CoreCount, + Best->Workload); + + DrainAsyncAgent(*Best->Entry); + CoresDrained += Best->Entry->CoreCount; + + AsyncAgentEntry* Drained = Best->Entry; + Candidates.erase( + std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }), + Candidates.end()); + } } } } @@ -101,266 +235,380 @@ HordeProvisioner::GetStats() const uint32_t HordeProvisioner::GetAgentCount() const { - std::lock_guard<std::mutex> Lock(m_AgentsLock); - return static_cast<uint32_t>(m_Agents.size()); + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + return static_cast<uint32_t>(m_AsyncAgents.size()); } -void -HordeProvisioner::RequestAgent() +compute::AgentProvisioningStatus +HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const { - m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + // Worker IDs are "horde-{LeaseId}" — strip the prefix to match lease ID + constexpr std::string_view Prefix = "horde-"; + if (!WorkerId.starts_with(Prefix)) + { + return compute::AgentProvisioningStatus::Unknown; + } + std::string_view LeaseId = WorkerId.substr(Prefix.size()); - std::lock_guard<std::mutex> Lock(m_AgentsLock); + std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.LeaseId == LeaseId) + { + if (Entry.Draining) + { + return compute::AgentProvisioningStatus::Draining; + } + return compute::AgentProvisioningStatus::Active; + } + } - auto Wrapper = std::make_unique<AgentWrapper>(); - AgentWrapper& Ref = *Wrapper; - Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + // Check recently-drained agents that have already been cleaned up + std::string WorkerIdStr(WorkerId); + if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0) + { + return compute::AgentProvisioningStatus::Draining; + } - m_Agents.push_back(std::move(Wrapper)); + return compute::AgentProvisioningStatus::Unknown; } -void -HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +std::vector<std::string> +HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const { - ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + std::vector<std::string> Args; + Args.emplace_back("compute"); + Args.emplace_back("--http=asio"); + Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen"); - static std::atomic<uint32_t> ThreadIndex{0}; - const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + if (m_CleanStart) + { + Args.emplace_back("--clean"); + } - zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + Args.emplace_back(CoordArg.ToView()); + } - std::unique_ptr<HordeAgent> Agent; - uint32_t MachineCoreCount = 0; + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + Args.emplace_back(IdArg.ToView()); + } - auto _ = MakeGuard([&] { - if (Agent) - { - Agent->CloseConnection(); - } - Wrapper.ShouldExit.store(true); - }); + if (!m_CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << m_CoordinatorSession; + Args.emplace_back(SessionArg.ToView()); + } + if (!m_TraceHost.empty()) { - // EstimatedCoreCount is incremented speculatively when the agent is requested - // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. - auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << m_TraceHost; + Args.emplace_back(TraceArg.ToView()); + } + // In relay mode, the remote zenserver's local address is not reachable from the + // orchestrator. Pass the relay-visible endpoint so it announces the correct URL. + if (Machine.Mode == ConnectionMode::Relay) + { + const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (Addr.find(':') != std::string::npos) + { + Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port)); + } + else { - ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port)); + } + } - std::lock_guard<std::mutex> BundleLock(m_BundleLock); + return Args; +} - if (!m_BundlesCreated) - { - const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; +bool +HordeProvisioner::InitializeHordeClient() +{ + ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient"); + + std::lock_guard<std::mutex> BundleLock(m_BundleLock); - std::vector<BundleFile> Files; + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector<BundleFile> Files; #if ZEN_PLATFORM_WINDOWS - Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.pdb", true); #elif ZEN_PLATFORM_LINUX - Files.emplace_back(m_BinariesPath / "zenserver", false); - Files.emplace_back(m_BinariesPath / "zenserver.debug", true); + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); #elif ZEN_PLATFORM_MAC - Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver", false); #endif - BundleResult Result; - if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) - { - ZEN_WARN("failed to create bundle, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - - m_Bundles.emplace_back(Result.Locator, Result.BundleDir); - m_BundlesCreated = true; - } - - if (!m_HordeClient) - { - m_HordeClient = std::make_unique<HordeClient>(m_Config); - if (!m_HordeClient->Initialize()) - { - ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - } + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + return false; } - if (!m_AskForAgents.load()) + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique<HordeClient>(m_Config); + if (!m_HordeClient->Initialize()) { - return; + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + return false; } + } - m_AgentsRequesting.fetch_add(1); - auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + return true; +} - // Simple backoff: if the last machine request failed, wait up to 5 seconds - // before trying again. - // - // Note however that it's possible that multiple threads enter this code at - // the same time if multiple agents are requested at once, and they will all - // see the same last failure time and back off accordingly. We might want to - // use a semaphore or similar to limit the number of concurrent requests. +void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); - if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) - { - auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); - const uint64_t ElapsedNs = Now - LastFail; - const uint64_t ElapsedMs = ElapsedNs / 1'000'000; - if (ElapsedMs < 5000) - { - const uint64_t WaitMs = 5000 - ElapsedMs; - for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } + if (m_PendingWorkItems.fetch_add(1) == 0) + { + m_AllWorkDone.Reset(); + } - if (Wrapper.ShouldExit.load()) + GetSmallWorkerPool(EWorkloadType::Background) + .ScheduleWork( + [this] { + ProvisionAgent(); + if (m_PendingWorkItems.fetch_sub(1) == 1) { - return; + m_AllWorkDone.Set(); } - } - } - - if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) - { - return; - } - - std::string RequestBody = m_HordeClient->BuildRequestBody(); + }, + WorkerThreadPool::EMode::EnableBacklog); +} - // Resolve cluster if needed - std::string ClusterId = m_Config.Cluster; - if (ClusterId == HordeConfig::ClusterAuto) - { - ClusterInfo Cluster; - if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) - { - ZEN_WARN("failed to resolve cluster"); - m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); - return; - } - ClusterId = Cluster.ClusterId; - } +void +HordeProvisioner::ProvisionAgent() +{ + ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent"); - MachineInfo Machine; - if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) - { - m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); - return; - } + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. + auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); - m_LastRequestFailTime.store(0); + if (!InitializeHordeClient()) + { + return; + } - if (Wrapper.ShouldExit.load()) - { - return; - } + if (!m_AskForAgents.load()) + { + return; + } - // Connect to agent and perform handshake - Agent = std::make_unique<HordeAgent>(Machine); - if (!Agent->IsValid()) - { - ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); - return; - } + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); - if (!Agent->BeginCommunication()) - { - ZEN_WARN("BeginCommunication failed"); - return; - } + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. - for (auto& [Locator, BundleDir] : m_Bundles) + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) { - if (Wrapper.ShouldExit.load()) + const uint64_t WaitMs = 5000 - ElapsedMs; + for (uint64_t Waited = 0; Waited < WaitMs && !!m_AskForAgents.load(); Waited += 100) { - return; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } - if (!Agent->UploadBinaries(BundleDir, Locator)) + if (!m_AskForAgents.load()) { - ZEN_WARN("UploadBinaries failed"); return; } } + } - if (Wrapper.ShouldExit.load()) + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) { + ZEN_WARN("failed to resolve cluster"); + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); return; } + ClusterId = Cluster.ClusterId; + } - // Build command line for remote zenserver - std::vector<std::string> ArgStrings; - ArgStrings.push_back("compute"); - ArgStrings.push_back("--http=asio"); + ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})", + ClusterId.empty() ? "default" : ClusterId.c_str(), + m_ActiveCoreCount.load(), + m_TargetCoreCount.load()); - // TEMP HACK - these should be made fully dynamic - // these are currently here to allow spawning the compute agent locally - // for debugging purposes (i.e with a local Horde Server+Agent setup) - ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); - ArgStrings.push_back("--data-dir=c:\\temp\\123"); + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } - if (!m_OrchestratorEndpoint.empty()) - { - ExtendableStringBuilder<256> CoordArg; - CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; - ArgStrings.emplace_back(CoordArg.ToView()); - } + m_LastRequestFailTime.store(0); - { - ExtendableStringBuilder<128> IdArg; - IdArg << "--instance-id=horde-" << Machine.LeaseId; - ArgStrings.emplace_back(IdArg.ToView()); - } + if (!m_AskForAgents.load()) + { + return; + } - std::vector<const char*> Args; - Args.reserve(ArgStrings.size()); - for (const std::string& Arg : ArgStrings) - { - Args.push_back(Arg.c_str()); - } + AsyncAgentConfig AgentConfig; + AgentConfig.Machine = Machine; + AgentConfig.Bundles = m_Bundles; + AgentConfig.Args = BuildAgentArgs(Machine); #if ZEN_PLATFORM_WINDOWS - const bool UseWine = !Machine.IsWindows; - const char* AppName = "zenserver.exe"; + AgentConfig.UseWine = !Machine.IsWindows; + AgentConfig.Executable = "zenserver.exe"; #else - const bool UseWine = false; - const char* AppName = "zenserver"; + AgentConfig.UseWine = false; + AgentConfig.Executable = "zenserver"; #endif - Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); + auto AsyncAgent = std::make_shared<AsyncHordeAgent>(*m_IoContext); + + AsyncAgentEntry Entry; + Entry.Agent = AsyncAgent; + Entry.LeaseId = Machine.LeaseId; + Entry.CoreCount = Machine.LogicalCores; - ZEN_INFO("remote execution started on [{}:{}] lease={}", - Machine.GetConnectionAddress(), - Machine.GetConnectionPort(), - Machine.LeaseId); + const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (EndpointAddr.find(':') != std::string::npos) + { + Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort); + } + else + { + Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort); + } - MachineCoreCount = Machine.LogicalCores; - m_EstimatedCoreCount.fetch_add(MachineCoreCount); - m_ActiveCoreCount.fetch_add(MachineCoreCount); - m_AgentsActive.fetch_add(1); + { + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + m_AsyncAgents.push_back(std::move(Entry)); } - // Agent poll loop + AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) { + if (Result.CoreCount > 0) + { + // Only subtract estimated cores if not already subtracted by DrainAsyncAgent + bool WasDraining = false; + { + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.Agent == AsyncAgent) + { + WasDraining = Entry.Draining; + break; + } + } + } - auto ActiveGuard = MakeGuard([&]() { - m_EstimatedCoreCount.fetch_sub(MachineCoreCount); - m_ActiveCoreCount.fetch_sub(MachineCoreCount); - m_AgentsActive.fetch_sub(1); + if (!WasDraining) + { + m_EstimatedCoreCount.fetch_sub(Result.CoreCount); + } + m_ActiveCoreCount.fetch_sub(Result.CoreCount); + m_AgentsActive.fetch_sub(1); + } + OnAsyncAgentDone(AsyncAgent); }); - while (Agent->IsValid() && !Wrapper.ShouldExit.load()) + // Track active cores (estimated was already added by RequestAgent) + m_EstimatedCoreCount.fetch_add(Machine.LogicalCores); + m_ActiveCoreCount.fetch_add(Machine.LogicalCores); + m_AgentsActive.fetch_add(1); +} + +void +HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry) +{ + Entry.Draining = true; + m_EstimatedCoreCount.fetch_sub(Entry.CoreCount); + m_AgentsDraining.fetch_add(1); + + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{5000}; + Settings.Timeout = std::chrono::milliseconds{10000}; + + try { - const bool LogOutput = false; - if (!Agent->Poll(LogOutput)) + HttpClient Client(Entry.RemoteEndpoint, Settings); + + HttpClient::Response Response = Client.Post("/compute/session/drain"); + if (!Response.IsSuccess()) { + ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast<int>(Response.StatusCode)); + return; + } + + ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId); + (void)Client.Post("/compute/session/sunset"); + } + catch (const std::exception& Ex) + { + ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what()); + } +} + +void +HordeProvisioner::OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent) +{ + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It) + { + if (It->Agent == Agent) + { + if (It->Draining) + { + m_AgentsDraining.fetch_sub(1); + m_RecentlyDrainedWorkerIds.insert("horde-" + It->LeaseId); + } + m_AsyncAgents.erase(It); break; } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp index 69766e73e..65eaea477 100644 --- a/src/zenhorde/hordetransport.cpp +++ b/src/zenhorde/hordetransport.cpp @@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif - namespace zen::horde { -// ComputeTransport base +// --- AsyncTcpComputeTransport --- -bool -ComputeTransport::SendMessage(const void* Data, size_t Size) +struct AsyncTcpComputeTransport::Impl { - const uint8_t* Ptr = static_cast<const uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Sent = Send(Ptr, Remaining); - if (Sent == 0) - { - return false; - } - Ptr += Sent; - Remaining -= Sent; - } + asio::io_context& IoContext; + asio::ip::tcp::socket Socket; - return true; -} + explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {} +}; -bool -ComputeTransport::RecvMessage(void* Data, size_t Size) +AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext) +: m_Impl(std::make_unique<Impl>(IoContext)) +, m_Log(zen::logging::Get("horde.transport.async")) { - uint8_t* Ptr = static_cast<uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Received = Recv(Ptr, Remaining); - if (Received == 0) - { - return false; - } - Ptr += Received; - Remaining -= Received; - } - - return true; } -// TcpComputeTransport - ASIO pimpl - -struct TcpComputeTransport::Impl +AsyncTcpComputeTransport::~AsyncTcpComputeTransport() { - asio::io_context IoContext; - asio::ip::tcp::socket Socket; - - Impl() : Socket(IoContext) {} -}; + Close(); +} -// Uses ASIO in synchronous mode only — no async operations or io_context::run(). -// The io_context is only needed because ASIO sockets require one to be constructed. -TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) -: m_Impl(std::make_unique<Impl>()) -, m_Log(zen::logging::Get("horde.transport")) +void +AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler) { - ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect"); asio::error_code Ec; @@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) { ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); m_HasErrors = true; + asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); }); return; } const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); - m_Impl->Socket.connect(Endpoint, Ec); - if (Ec) - { - ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); - m_HasErrors = true; - return; - } + // Copy the nonce so it survives past this scope into the async callback + auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize); - // Disable Nagle's algorithm for lower latency - m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); -} + m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("async connect failed: {}", Ec.message()); + m_HasErrors = true; + Handler(Ec); + return; + } -TcpComputeTransport::~TcpComputeTransport() -{ - Close(); + asio::error_code SetOptEc; + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc); + + // Send the 64-byte nonce as the first thing on the wire + asio::async_write(m_Impl->Socket, + asio::buffer(*NonceBuf), + [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) { + if (Ec) + { + ZEN_WARN("nonce write failed: {}", Ec.message()); + m_HasErrors = true; + } + Handler(Ec); + }); + }); } bool -TcpComputeTransport::IsValid() const +AsyncTcpComputeTransport::IsValid() const { return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; } -size_t -TcpComputeTransport::Send(const void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - m_HasErrors = true; - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Sent; + asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } -size_t -TcpComputeTransport::Recv(void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Received; -} - -void -TcpComputeTransport::MarkComplete() -{ + asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } void -TcpComputeTransport::Close() +AsyncTcpComputeTransport::Close() { if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) { diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h index 1b178dc0f..b5e841d7a 100644 --- a/src/zenhorde/hordetransport.h +++ b/src/zenhorde/hordetransport.h @@ -8,55 +8,60 @@ #include <cstddef> #include <cstdint> +#include <functional> #include <memory> +#include <system_error> -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif +namespace asio { +class io_context; +} namespace zen::horde { -/** Abstract base interface for compute transports. +/** Handler types for async transport operations. */ +using AsyncConnectHandler = std::function<void(const std::error_code&)>; +using AsyncIoHandler = std::function<void(const std::error_code&, size_t)>; + +/** Abstract base for asynchronous compute transports. * - * Matches the UE FComputeTransport pattern. Concrete implementations handle - * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides - * blocking message helpers on top. + * All callbacks are invoked on the io_context that was provided at construction. + * Callers are responsible for strand serialization if needed. */ -class ComputeTransport +class AsyncComputeTransport { public: - virtual ~ComputeTransport() = default; + virtual ~AsyncComputeTransport() = default; + + virtual bool IsValid() const = 0; - virtual bool IsValid() const = 0; - virtual size_t Send(const void* Data, size_t Size) = 0; - virtual size_t Recv(void* Data, size_t Size) = 0; - virtual void MarkComplete() = 0; - virtual void Close() = 0; + /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */ + virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking send that loops until all bytes are transferred. Returns false on error. */ - bool SendMessage(const void* Data, size_t Size); + /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */ + virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking receive that loops until all bytes are transferred. Returns false on error. */ - bool RecvMessage(void* Data, size_t Size); + virtual void Close() = 0; }; -/** TCP socket transport using ASIO. +/** Async TCP transport using ASIO. * - * Connects to the Horde compute endpoint specified by MachineInfo and provides - * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the - * header clean. + * Connects to the Horde compute endpoint and provides async send/receive. + * The socket is created on a caller-provided io_context (shared across agents). */ -class TcpComputeTransport final : public ComputeTransport +class AsyncTcpComputeTransport final : public AsyncComputeTransport { public: - explicit TcpComputeTransport(const MachineInfo& Info); - ~TcpComputeTransport() override; - - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + /** Construct a transport on the given io_context. Does not connect yet. */ + explicit AsyncTcpComputeTransport(asio::io_context& IoContext); + ~AsyncTcpComputeTransport() override; + + /** Asynchronously connect to the endpoint and send the nonce. */ + void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler); + + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: LoggerRef Log() { return m_Log; } diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..c71866e8c 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -5,6 +5,10 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <algorithm> #include <cstring> #include <random> @@ -22,274 +26,281 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -struct AesComputeTransport::CryptoContext -{ - uint8_t Key[KeySize] = {}; - uint8_t EncryptNonce[NonceBytes] = {}; - uint8_t DecryptNonce[NonceBytes] = {}; - bool HasErrors = false; +namespace { + + static constexpr size_t AesNonceBytes = 12; + static constexpr size_t AesTagBytes = 16; + + /** AES-256-GCM crypto context. Not exposed outside this translation unit. */ + struct AesCryptoContext + { + static constexpr size_t NonceBytes = AesNonceBytes; + static constexpr size_t TagBytes = AesTagBytes; + + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; #if !ZEN_PLATFORM_WINDOWS - EVP_CIPHER_CTX* EncCtx = nullptr; - EVP_CIPHER_CTX* DecCtx = nullptr; + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; #endif - CryptoContext(const uint8_t (&InKey)[KeySize]) - { - memcpy(Key, InKey, KeySize); - - // The encrypt nonce is randomly initialized and then deterministically mutated - // per message via UpdateNonce(). The decrypt nonce is not used — it comes from - // the wire (each received message carries its own nonce in the header). - std::random_device Rd; - std::mt19937 Gen(Rd()); - std::uniform_int_distribution<int> Dist(0, 255); - for (auto& Byte : EncryptNonce) + AesCryptoContext(const uint8_t (&InKey)[KeySize]) { - Byte = static_cast<uint8_t>(Dist(Gen)); - } + memcpy(Key, InKey, KeySize); + + std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution<int> Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast<uint8_t>(Dist(Gen)); + } #if !ZEN_PLATFORM_WINDOWS - // Drain any stale OpenSSL errors - while (ERR_get_error() != 0) - { - } + while (ERR_get_error() != 0) + { + } - EncCtx = EVP_CIPHER_CTX_new(); - EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); - DecCtx = EVP_CIPHER_CTX_new(); - EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); #endif - } + } - ~CryptoContext() - { + ~AesCryptoContext() + { #if ZEN_PLATFORM_WINDOWS - SecureZeroMemory(Key, sizeof(Key)); - SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); - SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); #else - OPENSSL_cleanse(Key, sizeof(Key)); - OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); - OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); - - if (EncCtx) - { - EVP_CIPHER_CTX_free(EncCtx); + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif } - if (DecCtx) + + void UpdateNonce() { - EVP_CIPHER_CTX_free(DecCtx); + uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; } -#endif - } - - void UpdateNonce() - { - uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); - N32[0]++; - N32[1]--; - N32[2] = N32[0] ^ N32[1]; - } - // Returns total encrypted message size, or 0 on failure - // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] - int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) - { - UpdateNonce(); + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); - // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than - // caching but has some overhead. For our use case (relatively large, infrequent messages) - // this is acceptable. #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = EncryptNonce; - AuthInfo.cbNonce = NonceBytes; - uint8_t Tag[TagBytes] = {}; - AuthInfo.pbTag = Tag; - AuthInfo.cbTag = TagBytes; - - ULONG CipherLen = 0; - NTSTATUS Status = - BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = BCryptEncrypt(hKey, + (PUCHAR)In, + (ULONG)InLength, + &AuthInfo, + nullptr, + 0, + Out + 4 + NonceBytes, + (ULONG)InLength, + &CipherLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + BCryptDestroyKey(hKey); BCryptCloseAlgorithmProvider(hAlg, 0); - return 0; - } - - // Write header: length + nonce - memcpy(Out, &InLength, 4); - memcpy(Out + 4, EncryptNonce, NonceBytes); - // Write tag after ciphertext - memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; #else - if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) - { - HasErrors = true; - return 0; - } - - int32_t Offset = 0; - // Write length - memcpy(Out + Offset, &InLength, 4); - Offset += 4; - // Write nonce - memcpy(Out + Offset, EncryptNonce, NonceBytes); - Offset += NonceBytes; - - // Encrypt - int OutLen = 0; - if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) - { - HasErrors = true; - return 0; - } - Offset += OutLen; - - // Finalize - int FinalLen = 0; - if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) - { - HasErrors = true; - return 0; + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif } - Offset += FinalLen; - // Get tag - if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { - HasErrors = true; - return 0; - } - Offset += TagBytes; - - return Offset; -#endif - } - - // Decrypt a message. Returns decrypted data length, or 0 on failure. - // Input must be [ciphertext][tag], with nonce provided separately. - int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) - { #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); - AuthInfo.cbNonce = NonceBytes; - AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); - AuthInfo.cbTag = TagBytes; - - ULONG PlainLen = 0; - NTSTATUS Status = BCryptDecrypt(hKey, - (PUCHAR)CipherAndTag, - (ULONG)DataLength, - &AuthInfo, - nullptr, - 0, - (PUCHAR)Out, - (ULONG)DataLength, - &PlainLen, - 0); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; - return 0; - } + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); - return static_cast<int32_t>(PlainLen); -#else - if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) - { - HasErrors = true; - return 0; - } + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); - int OutLen = 0; - if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) - { - HasErrors = true; - return 0; - } + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } - // Set the tag for verification - if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) - { - HasErrors = true; - return 0; + return static_cast<int32_t>(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif } + }; - int FinalLen = 0; - if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) - { - HasErrors = true; - return 0; - } +} // anonymous namespace - return OutLen + FinalLen; -#endif - } +struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext +{ + using AesCryptoContext::AesCryptoContext; }; -AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +// --- AsyncAesComputeTransport --- + +AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext) : m_Crypto(std::make_unique<CryptoContext>(Key)) , m_Inner(std::move(InnerTransport)) +, m_IoContext(IoContext) { } -AesComputeTransport::~AesComputeTransport() +AsyncAesComputeTransport::~AsyncAesComputeTransport() { Close(); } bool -AesComputeTransport::IsValid() const +AsyncAesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } -size_t -AesComputeTransport::Send(const void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { - ZEN_TRACE_CPU("AesComputeTransport::Send"); - if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - std::lock_guard<std::mutex> Lock(m_Lock); - const int32_t DataLength = static_cast<int32_t>(Size); - const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes; if (m_EncryptBuffer.size() < MessageLength) { @@ -299,38 +310,36 @@ AesComputeTransport::Send(const void* Data, size_t Size) const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); if (EncryptedLen == 0) { - return 0; + asio::post(m_IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); }); + return; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) - { - return 0; - } + auto EncBuf = std::make_shared<std::vector<uint8_t>>(m_EncryptBuffer.begin(), m_EncryptBuffer.begin() + EncryptedLen); - return Size; + m_Inner->AsyncWrite( + EncBuf->data(), + EncBuf->size(), + [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); }); } -size_t -AesComputeTransport::Recv(void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes - // than the decrypted message contains. Excess bytes are buffered in m_RemainingData - // and returned on subsequent Recv calls without another decryption round-trip. - ZEN_TRACE_CPU("AesComputeTransport::Recv"); - - std::lock_guard<std::mutex> Lock(m_Lock); + uint8_t* Dest = static_cast<uint8_t*>(Data); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); - memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) @@ -339,78 +348,96 @@ AesComputeTransport::Recv(void* Data, size_t Size) m_RemainingOffset = 0; } - return ToCopy; - } - - // Receive packet header: [length(4B)][nonce(12B)] - struct PacketHeader - { - int32_t DataLength = 0; - uint8_t Nonce[NonceBytes] = {}; - } Header; - - if (!m_Inner->RecvMessage(&Header, sizeof(Header))) - { - return 0; - } - - // Validate DataLength to prevent OOM from malicious/corrupt peers - static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB - - if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) - { - ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); - return 0; - } - - // Receive ciphertext + tag - const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; - - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } - - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) - { - return 0; - } - - // Decrypt - const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); - - // We need a temporary buffer for decryption if we can't decrypt directly into output - std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); - - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); - if (Decrypted == 0) - { - return 0; - } - - memcpy(Data, DecryptedBuf.data(), BytesToReturn); + if (ToCopy == Size) + { + asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); }); + return; + } - // Store remaining data if we couldn't return everything - if (static_cast<size_t>(Header.DataLength) > BytesToReturn) - { - m_RemainingOffset = 0; - m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler)); + return; } - return BytesToReturn; + DoRecvMessage(Dest, Size, std::move(Handler)); } void -AesComputeTransport::MarkComplete() +AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler) { - if (IsValid()) - { - m_Inner->MarkComplete(); - } + static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes; + auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>(); + + m_Inner->AsyncRead(HeaderBuf->data(), + HeaderSize, + [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + int32_t DataLength = 0; + memcpy(&DataLength, HeaderBuf->data(), 4); + + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; + if (DataLength <= 0 || DataLength > MaxDataLength) + { + Handler(asio::error::make_error_code(asio::error::invalid_argument), 0); + return; + } + + const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes; + if (m_DecryptBuffer.size() < MessageLength) + { + m_DecryptBuffer.resize(MessageLength); + } + + auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>(); + memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes); + + m_Inner->AsyncRead( + m_DecryptBuffer.data(), + MessageLength, + [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec, + size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength)); + const int32_t Decrypted = + m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength); + if (Decrypted == 0) + { + Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); + return; + } + + const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size); + memcpy(Dest, PlaintextBuf.data(), BytesToReturn); + + if (static_cast<size_t>(Decrypted) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted); + } + + if (BytesToReturn < Size) + { + DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler)); + } + else + { + Handler(std::error_code{}, Size); + } + }); + }); } void -AesComputeTransport::Close() +AsyncAesComputeTransport::Close() { if (!m_IsClosed) { diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h index efcad9835..a1800c684 100644 --- a/src/zenhorde/hordetransportaes.h +++ b/src/zenhorde/hordetransportaes.h @@ -6,47 +6,53 @@ #include <cstdint> #include <memory> -#include <mutex> #include <vector> +namespace asio { +class io_context; +} + namespace zen::horde { -/** AES-256-GCM encrypted transport wrapper. +/** Async AES-256-GCM encrypted transport wrapper. * - * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting - * all incoming data using AES-256-GCM. The nonce is mutated per message using - * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming + * data using AES-256-GCM. The nonce is mutated per message using the Horde + * nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. * * Wire format per encrypted message: * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] * * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). + * + * Thread safety: all operations must be serialized by the caller (e.g. via a strand). */ -class AesComputeTransport final : public ComputeTransport +class AsyncAesComputeTransport final : public AsyncComputeTransport { public: - AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport); - ~AesComputeTransport() override; + AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext); + ~AsyncAesComputeTransport() override; - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: - static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size - static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler); struct CryptoContext; - std::unique_ptr<CryptoContext> m_Crypto; - std::unique_ptr<ComputeTransport> m_Inner; - std::vector<uint8_t> m_EncryptBuffer; - std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv - size_t m_RemainingOffset = 0; - std::mutex m_Lock; - bool m_IsClosed = false; + std::unique_ptr<CryptoContext> m_Crypto; + std::unique_ptr<AsyncComputeTransport> m_Inner; + asio::io_context& m_IoContext; + std::vector<uint8_t> m_EncryptBuffer; + std::vector<uint8_t> m_DecryptBuffer; + std::vector<uint8_t> m_RemainingData; + size_t m_RemainingOffset = 0; + bool m_IsClosed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h index 201d68b83..87caec019 100644 --- a/src/zenhorde/include/zenhorde/hordeclient.h +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -45,14 +45,15 @@ struct MachineInfo uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) bool IsWindows = false; std::string LeaseId; + std::string Pool; std::map<std::string, PortInfo> Ports; /** Return the address to connect to, accounting for connection mode. */ - const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } /** Return the port to connect to, accounting for connection mode and port mapping. */ - uint16_t GetConnectionPort() const + [[nodiscard]] uint16_t GetConnectionPort() const { if (Mode == ConnectionMode::Relay) { @@ -65,7 +66,20 @@ struct MachineInfo return Port; } - bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } + /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */ + [[nodiscard]] std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const + { + if (Mode == ConnectionMode::Relay) + { + if (auto It = Ports.find("ZenPort"); It != Ports.end()) + { + return {ConnectionAddress, It->second.Port}; + } + } + return {Ip, DefaultPort}; + } + + [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } }; /** Result of cluster auto-resolution via the Horde API. */ @@ -83,31 +97,29 @@ struct ClusterInfo class HordeClient { public: - explicit HordeClient(const HordeConfig& Config); + explicit HordeClient(HordeConfig Config); ~HordeClient(); HordeClient(const HordeClient&) = delete; HordeClient& operator=(const HordeClient&) = delete; /** Initialize the underlying HTTP client. Must be called before other methods. */ - bool Initialize(); + [[nodiscard]] bool Initialize(); /** Build the JSON request body for cluster resolution and machine requests. * Encodes pool, condition, connection mode, encryption, and port requirements. */ - std::string BuildRequestBody() const; + [[nodiscard]] std::string BuildRequestBody() const; /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ - bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. * On success, populates OutMachine with connection details and credentials. */ - bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); LoggerRef Log() { return m_Log; } private: - bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); - HordeConfig m_Config; std::unique_ptr<zen::HttpClient> m_Http; LoggerRef m_Log; diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h index dd70f9832..3a4dfb386 100644 --- a/src/zenhorde/include/zenhorde/hordeconfig.h +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -4,6 +4,10 @@ #include <zenhorde/zenhorde.h> +#include <zenhttp/httpclient.h> + +#include <functional> +#include <optional> #include <string> namespace zen::horde { @@ -33,20 +37,25 @@ struct HordeConfig static constexpr const char* ClusterDefault = "default"; static constexpr const char* ClusterAuto = "_auto"; - bool Enabled = false; ///< Whether Horde provisioning is active - std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") - std::string AuthToken; ///< Authentication token for the Horde API - std::string Pool; ///< Pool name to request machines from - std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve - std::string Condition; ///< Agent filter expression for machine selection - std::string HostAddress; ///< Address that provisioned agents use to connect back to us - std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload - uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication - - int MaxCores = 2048; - bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents - ConnectionMode Mode = ConnectionMode::Direct; - Encryption EncryptionMode = Encryption::None; + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API (static fallback) + + /// Optional token provider with automatic refresh (e.g. from OidcToken executable). + /// When set, takes priority over the static AuthToken string. + std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min) + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; /** Validate the configuration. Returns false if the configuration is invalid * (e.g. Relay mode without AES encryption). */ diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h index 4e2e63bbd..1dd12936b 100644 --- a/src/zenhorde/include/zenhorde/hordeprovisioner.h +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -2,9 +2,12 @@ #pragma once +#include <zenhorde/hordeclient.h> #include <zenhorde/hordeconfig.h> +#include <zencompute/provisionerstate.h> #include <zencore/logbase.h> +#include <zencore/thread.h> #include <atomic> #include <cstdint> @@ -12,11 +15,18 @@ #include <memory> #include <mutex> #include <string> +#include <thread> +#include <unordered_set> #include <vector> +namespace asio { +class io_context; +} + namespace zen::horde { class HordeClient; +class AsyncHordeAgent; /** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */ struct ProvisioningStats @@ -35,13 +45,12 @@ struct ProvisioningStats * binary, and executing it remotely. Each provisioned machine runs zenserver * in compute mode, which announces itself back to the orchestrator. * - * Spawns one thread per agent. Each thread handles the full lifecycle: - * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> - * channel setup -> binary upload -> remote execution -> poll until exit. + * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread + * pool rather than spawning a dedicated thread per agent. * * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. */ -class HordeProvisioner +class HordeProvisioner : public compute::IProvisionerStateProvider { public: /** Construct a provisioner. @@ -52,38 +61,48 @@ public: HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint); + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); - /** Signals all agent threads to exit and joins them. */ - ~HordeProvisioner(); + /** Signals all agents to exit and waits for completion. */ + ~HordeProvisioner() override; HordeProvisioner(const HordeProvisioner&) = delete; HordeProvisioner& operator=(const HordeProvisioner&) = delete; /** Set the target number of cores to provision. - * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the - * estimated core count is below the target. Also joins any finished - * agent threads. */ - void SetTargetCoreCount(uint32_t Count); + * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the + * estimated core count is below the target. Also removes finished agents. */ + void SetTargetCoreCount(uint32_t Count) override; /** Return a snapshot of the current provisioning counters. */ ProvisioningStats GetStats() const; - uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } - uint32_t GetAgentCount() const; + // IProvisionerStateProvider + std::string_view GetName() const override { return "horde"; } + uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); } + uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); } + uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const override; + uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); } + compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override; private: LoggerRef Log() { return m_Log; } - struct AgentWrapper; - void RequestAgent(); - void ThreadAgent(AgentWrapper& Wrapper); + void ProvisionAgent(); + bool InitializeHordeClient(); HordeConfig m_Config; std::filesystem::path m_BinariesPath; std::filesystem::path m_WorkingDir; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr<HordeClient> m_HordeClient; @@ -91,20 +110,43 @@ private: std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs bool m_BundlesCreated = false; - mutable std::mutex m_AgentsLock; - std::vector<std::unique_ptr<AgentWrapper>> m_Agents; - std::atomic<uint64_t> m_LastRequestFailTime{0}; std::atomic<uint32_t> m_TargetCoreCount{0}; std::atomic<uint32_t> m_EstimatedCoreCount{0}; std::atomic<uint32_t> m_ActiveCoreCount{0}; std::atomic<uint32_t> m_AgentsActive{0}; + std::atomic<uint32_t> m_AgentsDraining{0}; std::atomic<uint32_t> m_AgentsRequesting{0}; std::atomic<bool> m_AskForAgents{true}; + std::atomic<uint32_t> m_PendingWorkItems{0}; + Event m_AllWorkDone; LoggerRef m_Log; + // Async I/O + std::unique_ptr<asio::io_context> m_IoContext; + std::vector<std::thread> m_IoThreads; + + struct AsyncAgentEntry + { + std::shared_ptr<AsyncHordeAgent> Agent; + std::string RemoteEndpoint; + std::string LeaseId; + uint16_t CoreCount = 0; + bool Draining = false; + }; + + mutable std::mutex m_AsyncAgentsLock; + std::vector<AsyncAgentEntry> m_AsyncAgents; + mutable std::unordered_set<std::string> m_RecentlyDrainedWorkerIds; ///< Worker IDs of agents that completed after draining + + void OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent); + void DrainAsyncAgent(AsyncAgentEntry& Entry); + + std::vector<std::string> BuildAgentArgs(const MachineInfo& Machine) const; + static constexpr uint32_t EstimatedCoresPerAgent = 32; + static constexpr int IoThreadCount = 3; }; } // namespace zen::horde diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index b9af9bd52..56b9c39c5 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -228,6 +228,13 @@ CurlHttpClient::Session::Perform() curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + char* EffectiveUrl = nullptr; + curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl); + if (EffectiveUrl) + { + Result.Url = EffectiveUrl; + } + return Result; } @@ -294,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'", SessionId, + Result.Url, static_cast<int>(Result.ErrorCode), Result.ErrorMessage); } @@ -443,6 +451,7 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result) case CURLE_RECV_ERROR: case CURLE_SEND_ERROR: case CURLE_OPERATION_TIMEDOUT: + case CURLE_PARTIAL_FILE: return true; default: return false; @@ -489,10 +498,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult { if (Result.ErrorCode != CURLE_OK) { - ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}", SessionId, static_cast<int>(MapCurlError(Result.ErrorCode)), Result.ErrorMessage, + static_cast<int>(Result.ErrorCode), Attempt, m_ConnectionSettings.RetryCount + 1); } diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index bdeb46633..ea9193e65 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -73,6 +73,7 @@ private: int64_t DownloadedBytes = 0; CURLcode ErrorCode = CURLE_OK; std::string ErrorMessage; + std::string Url; }; struct Session diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index c42841922..0432e50ef 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -94,7 +94,8 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Unattended, bool Quiet, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { Stopwatch Timer; @@ -117,8 +118,9 @@ namespace zen { namespace httpclientauth { } }); - const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}", + const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}", OidcExecutablePath, + IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl", CloudHost, AuthTokenPath, Unattended ? "true"sv : "false"sv); @@ -193,7 +195,7 @@ namespace zen { namespace httpclientauth { } else { - ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode); + ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode); } return HttpClientAccessToken{}; } @@ -202,9 +204,10 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { - HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); if (InitialToken.IsValid()) { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), @@ -212,12 +215,13 @@ namespace zen { namespace httpclientauth { Token = InitialToken, Quiet, Unattended, - Hidden]() mutable { + Hidden, + IsHordeUrl]() mutable { if (!Token.NeedsRefresh()) { return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); }; } return {}; diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index ce646ebd7..9220a50b6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -33,7 +33,8 @@ namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden); + bool Hidden, + bool IsHordeUrl = false); } // namespace httpclientauth } // namespace zen diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h index 0a3411ace..cebf217e1 100644 --- a/src/zennomad/include/zennomad/nomadclient.h +++ b/src/zennomad/include/zennomad/nomadclient.h @@ -52,7 +52,11 @@ public: /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint. * The JSON structure varies based on the configured driver and distribution mode. */ - std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const; + std::string BuildJobJson(const std::string& JobId, + const std::string& OrchestratorEndpoint, + const std::string& CoordinatorSession = {}, + bool CleanStart = false, + const std::string& TraceHost = {}) const; /** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */ bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob); diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h index 750693b3f..a8368e3dc 100644 --- a/src/zennomad/include/zennomad/nomadprovisioner.h +++ b/src/zennomad/include/zennomad/nomadprovisioner.h @@ -47,7 +47,11 @@ public: /** Construct a provisioner. * @param Config Nomad connection and job configuration. * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */ - NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint); + NomadProvisioner(const NomadConfig& Config, + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); /** Signals the management thread to exit and stops all tracked jobs. */ ~NomadProvisioner(); @@ -83,6 +87,9 @@ private: NomadConfig m_Config; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr<NomadClient> m_Client; diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp index 9edcde125..4bb09a930 100644 --- a/src/zennomad/nomadclient.cpp +++ b/src/zennomad/nomadclient.cpp @@ -58,7 +58,11 @@ NomadClient::Initialize() } std::string -NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const +NomadClient::BuildJobJson(const std::string& JobId, + const std::string& OrchestratorEndpoint, + const std::string& CoordinatorSession, + bool CleanStart, + const std::string& TraceHost) const { ZEN_TRACE_CPU("NomadClient::BuildJobJson"); @@ -94,6 +98,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra IdArg << "--instance-id=nomad-" << JobId; Args.push_back(std::string(IdArg.ToView())); } + if (!CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << CoordinatorSession; + Args.push_back(std::string(SessionArg.ToView())); + } + if (CleanStart) + { + Args.push_back("--clean"); + } + if (!TraceHost.empty()) + { + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << TraceHost; + Args.push_back(std::string(TraceArg.ToView())); + } TaskConfig["args"] = Args; } else @@ -115,6 +135,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra IdArg << "--instance-id=nomad-" << JobId; Args.push_back(std::string(IdArg.ToView())); } + if (!CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << CoordinatorSession; + Args.push_back(std::string(SessionArg.ToView())); + } + if (CleanStart) + { + Args.push_back("--clean"); + } + if (!TraceHost.empty()) + { + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << TraceHost; + Args.push_back(std::string(TraceArg.ToView())); + } TaskConfig["args"] = Args; } diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp index 3fe9c0ac3..e07ce155e 100644 --- a/src/zennomad/nomadprovisioner.cpp +++ b/src/zennomad/nomadprovisioner.cpp @@ -14,9 +14,16 @@ namespace zen::nomad { -NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint) +NomadProvisioner::NomadProvisioner(const NomadConfig& Config, + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId())) , m_Log(zen::logging::Get("nomad.provisioner")) { @@ -154,7 +161,7 @@ NomadProvisioner::SubmitNewJobs() ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load()); - const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint); + const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint, m_CoordinatorSession, m_CleanStart, m_TraceHost); NomadJobInfo JobInfo; JobInfo.Id = JobId; diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index 1f8b96cc4..3a41cd7eb 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -1664,21 +1664,34 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) } if (!BlockBuffer) { - BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); - if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache) + try + { + BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } + } + if (!m_AbortFlag) + { + if (!BlockBuffer) + { + throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash)); + } + + if (m_Storage.CacheStorage && m_Options.PopulateCache) { m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockDescription.BlockHash, ZenContentType::kCompressedBinary, CompositeBuffer(SharedBuffer(BlockBuffer))); } - } - if (!BlockBuffer) - { - throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash)); - } - if (!m_AbortFlag) - { + uint64_t BlockSize = BlockBuffer.GetSize(); m_DownloadStats.DownloadedBlockCount++; m_DownloadStats.DownloadedBlockByteCount += BlockSize; @@ -3293,31 +3306,45 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde } else { - BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash); - if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache) + try { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BuildBlob))); + BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash); } - if (!BuildBlob) + catch (const std::exception&) { - throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash)); + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } - if (!m_Options.PrimeCacheOnly) + if (!m_AbortFlag) { - if (!m_AbortFlag) + if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache) { - uint64_t BlobSize = BuildBlob.GetSize(); - m_DownloadStats.DownloadedChunkCount++; - m_DownloadStats.DownloadedChunkByteCount += BlobSize; - if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(BuildBlob))); + } + if (!BuildBlob) + { + throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash)); + } + if (!m_Options.PrimeCacheOnly) + { + if (!m_AbortFlag) { - FilteredDownloadedBytesPerSecond.Stop(); - } + uint64_t BlobSize = BuildBlob.GetSize(); + m_DownloadStats.DownloadedChunkCount++; + m_DownloadStats.DownloadedChunkByteCount += BlobSize; + if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } - OnDownloaded(std::move(BuildBlob)); + OnDownloaded(std::move(BuildBlob)); + } } } } @@ -3519,64 +3546,77 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); - BuildStorageBase::BuildBlobRanges RangeBuffers = - m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); - if (m_AbortFlag) + BuildStorageBase::BuildBlobRanges RangeBuffers; + + try { - break; + RangeBuffers = m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); } - if (RangeBuffers.PayloadBuffer) + catch (const std::exception&) { - if (RangeBuffers.Ranges.empty()) + // Silence http errors due to abort + if (!m_AbortFlag) { - // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 - // Upload to cache (if enabled) and use the whole payload for the remaining ranges + throw; + } + } - if (m_Storage.CacheStorage && m_Options.PopulateCache) + if (!m_AbortFlag) + { + if (RangeBuffers.PayloadBuffer) + { + if (RangeBuffers.Ranges.empty()) { - m_Storage.CacheStorage->PutBuildBlob(m_BuildId, - BlockDescription.BlockHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer})); - if (m_AbortFlag) + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 + // Upload to cache (if enabled) and use the whole payload for the remaining ranges + + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - break; + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer})); + if (m_AbortFlag) + { + break; + } } - } - SubRangeCount = Ranges.size() - SubRangeCountComplete; - ProcessDownload(BlockDescription, - std::move(RangeBuffers.PayloadBuffer), - SubRangeStartIndex, - RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), - TotalRequestCount, - FilteredDownloadedBytesPerSecond, - OnDownloaded); + SubRangeCount = Ranges.size() - SubRangeCountComplete; + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + TotalRequestCount, + FilteredDownloadedBytesPerSecond, + OnDownloaded); + } + else + { + if (RangeBuffers.Ranges.size() != SubRanges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + SubRanges.size(), + BlockDescription.BlockHash, + RangeBuffers.Ranges.size())); + } + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangeBuffers.Ranges, + TotalRequestCount, + FilteredDownloadedBytesPerSecond, + OnDownloaded); + } } else { - if (RangeBuffers.Ranges.size() != SubRanges.size()) - { - throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", - SubRanges.size(), - BlockDescription.BlockHash, - RangeBuffers.Ranges.size())); - } - ProcessDownload(BlockDescription, - std::move(RangeBuffers.PayloadBuffer), - SubRangeStartIndex, - RangeBuffers.Ranges, - TotalRequestCount, - FilteredDownloadedBytesPerSecond, - OnDownloaded); + throw std::runtime_error( + fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount)); } - } - else - { - throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount)); - } - SubRangeCountComplete += SubRangeCount; + SubRangeCountComplete += SubRangeCount; + } } } @@ -5150,48 +5190,80 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent& Payload.GetCompressed()); } - m_Storage.BuildStorage->PutBuildBlob(m_BuildId, - BlockHash, - ZenContentType::kCompressedBinary, - std::move(Payload).GetCompressed()); - UploadStats.BlocksBytes += CompressedBlockSize; - - if (m_Options.IsVerbose) + try { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} ({}) containing {} chunks", - BlockHash, - NiceBytes(CompressedBlockSize), - OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); + m_Storage.BuildStorage->PutBuildBlob(m_BuildId, + BlockHash, + ZenContentType::kCompressedBinary, + std::move(Payload).GetCompressed()); } - - if (m_Storage.CacheStorage && m_Options.PopulateCache) + catch (const std::exception&) { - m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } - bool MetadataSucceeded = - m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); - if (MetadataSucceeded) + if (!m_AbortFlag) { + UploadStats.BlocksBytes += CompressedBlockSize; + if (m_Options.IsVerbose) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} metadata ({})", - BlockHash, - NiceBytes(BlockMetaData.GetSize())); + ZEN_OPERATION_LOG_INFO( + m_LogOutput, + "Uploaded block {} ({}) containing {} chunks", + BlockHash, + NiceBytes(CompressedBlockSize), + OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); } - OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; - UploadStats.BlocksBytes += BlockMetaData.GetSize(); - } + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); + } - UploadStats.BlockCount++; - if (UploadStats.BlockCount == NewBlockCount) - { - FilteredUploadedBytesPerSecond.Stop(); + bool MetadataSucceeded = false; + try + { + MetadataSucceeded = + m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } + + if (!m_AbortFlag) + { + if (MetadataSucceeded) + { + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Uploaded block {} metadata ({})", + BlockHash, + NiceBytes(BlockMetaData.GetSize())); + } + + OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; + UploadStats.BlocksBytes += BlockMetaData.GetSize(); + } + + UploadStats.BlockCount++; + if (UploadStats.BlockCount == NewBlockCount) + { + FilteredUploadedBytesPerSecond.Stop(); + } + } } } } @@ -6215,44 +6287,76 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co { m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); } - m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); - if (m_Options.IsVerbose) + + try { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} ({}) containing {} chunks", - BlockHash, - NiceBytes(PayloadSize), - NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); + m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); } - UploadedBlockSize += PayloadSize; - TempUploadStats.BlocksBytes += PayloadSize; - - if (m_Storage.CacheStorage && m_Options.PopulateCache) + catch (const std::exception&) { - m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, - std::vector<IoHash>({BlockHash}), - std::vector<CbObject>({BlockMetaData})); + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } - bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); - if (MetadataSucceeded) + + if (!m_AbortFlag) { if (m_Options.IsVerbose) { ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Uploaded block {} metadata ({})", + "Uploaded block {} ({}) containing {} chunks", BlockHash, - NiceBytes(BlockMetaData.GetSize())); + NiceBytes(PayloadSize), + NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); } + UploadedBlockSize += PayloadSize; + TempUploadStats.BlocksBytes += PayloadSize; - NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; - TempUploadStats.BlocksBytes += BlockMetaData.GetSize(); - } + if (m_Storage.CacheStorage && m_Options.PopulateCache) + { + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector<IoHash>({BlockHash}), + std::vector<CbObject>({BlockMetaData})); + } + + bool MetadataSucceeded = false; + try + { + MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } + if (!m_AbortFlag) + { + if (MetadataSucceeded) + { + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Uploaded block {} metadata ({})", + BlockHash, + NiceBytes(BlockMetaData.GetSize())); + } + + NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true; + TempUploadStats.BlocksBytes += BlockMetaData.GetSize(); + } - TempUploadStats.BlockCount++; + TempUploadStats.BlockCount++; - if (UploadedBlockCount.fetch_add(1) + 1 == UploadBlockCount && UploadedChunkCount == UploadChunkCount) - { - FilteredUploadedBytesPerSecond.Stop(); + if (UploadedBlockCount.fetch_add(1) + 1 == UploadBlockCount && UploadedChunkCount == UploadChunkCount) + { + FilteredUploadedBytesPerSecond.Stop(); + } + } } } }); @@ -6302,72 +6406,100 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co { ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart"); TempUploadStats.MultipartAttachmentCount++; - std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob( - m_BuildId, - RawHash, - ZenContentType::kCompressedBinary, - PayloadSize, - [Payload = std::move(Payload), &FilteredUploadedBytesPerSecond](uint64_t Offset, - uint64_t Size) mutable -> IoBuffer { - FilteredUploadedBytesPerSecond.Start(); - - IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer(); - PartPayload.SetContentType(ZenContentType::kBinary); - return PartPayload; - }, - [RawSize, - &TempUploadStats, - &UploadedCompressedChunkSize, - &UploadChunkPool, - &UploadedBlockCount, - UploadBlockCount, - &UploadedChunkCount, - UploadChunkCount, - &FilteredUploadedBytesPerSecond, - &UploadedRawChunkSize](uint64_t SentBytes, bool IsComplete) { - TempUploadStats.ChunksBytes += SentBytes; - UploadedCompressedChunkSize += SentBytes; - if (IsComplete) - { - TempUploadStats.ChunkCount++; - if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && - UploadedBlockCount == UploadBlockCount) + try + { + std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob( + m_BuildId, + RawHash, + ZenContentType::kCompressedBinary, + PayloadSize, + [Payload = std::move(Payload), &FilteredUploadedBytesPerSecond](uint64_t Offset, + uint64_t Size) mutable -> IoBuffer { + FilteredUploadedBytesPerSecond.Start(); + + IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer(); + PartPayload.SetContentType(ZenContentType::kBinary); + return PartPayload; + }, + [RawSize, + &TempUploadStats, + &UploadedCompressedChunkSize, + &UploadChunkPool, + &UploadedBlockCount, + UploadBlockCount, + &UploadedChunkCount, + UploadChunkCount, + &FilteredUploadedBytesPerSecond, + &UploadedRawChunkSize](uint64_t SentBytes, bool IsComplete) { + TempUploadStats.ChunksBytes += SentBytes; + UploadedCompressedChunkSize += SentBytes; + if (IsComplete) { - FilteredUploadedBytesPerSecond.Stop(); + TempUploadStats.ChunkCount++; + if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && + UploadedBlockCount == UploadBlockCount) + { + FilteredUploadedBytesPerSecond.Stop(); + } + UploadedRawChunkSize += RawSize; } - UploadedRawChunkSize += RawSize; - } - }); - for (auto& WorkPart : MultipartWork) - { - Work.ScheduleWork(UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) { - ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work"); - if (!AbortFlag) - { - Work(); - } - }); + }); + for (auto& WorkPart : MultipartWork) + { + Work.ScheduleWork(UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) { + ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work"); + if (!AbortFlag) + { + Work(); + } + }); + } + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Uploaded multipart chunk {} ({})", + RawHash, + NiceBytes(PayloadSize)); + } } - if (m_Options.IsVerbose) + catch (const std::exception&) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded multipart chunk {} ({})", RawHash, NiceBytes(PayloadSize)); + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } } else { ZEN_TRACE_CPU("AsyncUploadLooseChunk_Singlepart"); - m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); - if (m_Options.IsVerbose) + try { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize)); + m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } } - TempUploadStats.ChunksBytes += Payload.GetSize(); - TempUploadStats.ChunkCount++; - UploadedCompressedChunkSize += Payload.GetSize(); - UploadedRawChunkSize += RawSize; - if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && UploadedBlockCount == UploadBlockCount) + if (!m_AbortFlag) { - FilteredUploadedBytesPerSecond.Stop(); + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize)); + } + TempUploadStats.ChunksBytes += Payload.GetSize(); + TempUploadStats.ChunkCount++; + UploadedCompressedChunkSize += Payload.GetSize(); + UploadedRawChunkSize += RawSize; + if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && UploadedBlockCount == UploadBlockCount) + { + FilteredUploadedBytesPerSecond.Stop(); + } } } } @@ -7147,10 +7279,23 @@ BuildsOperationPrimeCache::Execute() } else { - IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); - m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); - m_DownloadStats.RequestsCompleteCount++; + IoBuffer Payload; + try + { + Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); + + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); + m_DownloadStats.RequestsCompleteCount++; + } + catch (const std::exception&) + { + // Silence http errors due to abort + if (!m_AbortFlag) + { + throw; + } + } if (!m_AbortFlag) { @@ -7161,10 +7306,10 @@ BuildsOperationPrimeCache::Execute() ZenContentType::kCompressedBinary, CompositeBuffer(SharedBuffer(std::move(Payload)))); } - } - if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount) - { - FilteredDownloadedBytesPerSecond.Stop(); + if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } } } } diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 7296098e0..8cd8b4cfe 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -22,6 +22,8 @@ # if ZEN_WITH_HORDE # include <zenhorde/hordeconfig.h> # include <zenhorde/hordeprovisioner.h> +# include <zenhttp/httpclientauth.h> +# include <zenutil/authutils.h> # endif # if ZEN_WITH_NOMAD # include <zennomad/nomadconfig.h> @@ -67,6 +69,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("compute", "", + "coordinator-session", + "Session ID of the orchestrator (for stale-instance rejection)", + cxxopts::value<std::string>(m_ServerOptions.CoordinatorSession)->default_value(""), + ""); + + Options.add_option("compute", + "", + "announce-url", + "Override URL announced to the coordinator (e.g. relay-visible endpoint)", + cxxopts::value<std::string>(m_ServerOptions.AnnounceUrl)->default_value(""), + ""); + + Options.add_option("compute", + "", "idms", "Enable IDMS cloud detection; optionally specify a custom probe endpoint", cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"), @@ -79,6 +95,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"), ""); + Options.add_option("compute", + "", + "provision-clean", + "Pass --clean to provisioned worker instances so they wipe state on startup", + cxxopts::value<bool>(m_ServerOptions.ProvisionClean)->default_value("false"), + ""); + + Options.add_option("compute", + "", + "provision-tracehost", + "Pass --tracehost to provisioned worker instances for remote trace collection", + cxxopts::value<std::string>(m_ServerOptions.ProvisionTraceHost)->default_value(""), + ""); + # if ZEN_WITH_HORDE // Horde provisioning options Options.add_option("horde", @@ -139,6 +169,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("horde", "", + "horde-drain-grace-period", + "Grace period in seconds for draining agents before force-kill", + cxxopts::value<int>(m_ServerOptions.HordeConfig.DrainGracePeriodSeconds)->default_value("300"), + ""); + + Options.add_option("horde", + "", "horde-host", "Host address for Horde agents to connect back to", cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""), @@ -164,6 +201,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Port number for Zen service communication", cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"), ""); + + Options.add_option("horde", + "", + "horde-oidctoken-exe-path", + "Path to OidcToken executable for automatic Horde authentication", + cxxopts::value<std::string>(m_HordeOidcTokenExePath)->default_value(""), + ""); # endif # if ZEN_WITH_NOMAD @@ -313,6 +357,30 @@ ZenComputeServerConfigurator::ValidateOptions() # if ZEN_WITH_HORDE horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr); horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr); + + // Set up OidcToken-based authentication if no static token was provided + if (m_ServerOptions.HordeConfig.AuthToken.empty() && !m_ServerOptions.HordeConfig.ServerUrl.empty()) + { + std::filesystem::path OidcExePath = FindOidcTokenExePath(m_HordeOidcTokenExePath); + if (!OidcExePath.empty()) + { + ZEN_INFO("using OidcToken executable for Horde authentication: {}", OidcExePath); + auto Provider = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath, + m_ServerOptions.HordeConfig.ServerUrl, + /*Quiet=*/true, + /*Unattended=*/false, + /*Hidden=*/true, + /*IsHordeUrl=*/true); + if (Provider) + { + m_ServerOptions.HordeConfig.AccessTokenProvider = std::move(*Provider); + } + else + { + ZEN_WARN("OidcToken authentication failed; Horde requests will be unauthenticated"); + } + } + } # endif # if ZEN_WITH_NOMAD @@ -347,6 +415,8 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ } m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_CoordinatorSession = ServerConfig.CoordinatorSession; + m_AnnounceUrl = ServerConfig.AnnounceUrl; m_InstanceId = ServerConfig.InstanceId; m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; @@ -379,7 +449,14 @@ ZenComputeServer::Cleanup() m_AnnounceTimer.cancel(); # if ZEN_WITH_HORDE - // Shut down Horde provisioner first — this signals all agent threads + // Disconnect the provisioner state provider before destroying the + // provisioner so the orchestrator HTTP layer cannot call into it. + if (m_OrchestratorService) + { + m_OrchestratorService->SetProvisionerStateProvider(nullptr); + } + + // Shut down Horde provisioner — this signals all agent threads // to exit and joins them before we tear down HTTP services. m_HordeProvisioner.reset(); # endif @@ -482,6 +559,7 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) m_StatsService, ServerConfig.DataDir / "functions", ServerConfig.MaxConcurrentActions); + m_ComputeService->SetShutdownCallback([this] { RequestExit(0); }); m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService); @@ -506,7 +584,11 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) OrchestratorEndpoint << '/'; } - m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint); + m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, + OrchestratorEndpoint, + m_OrchestratorService->GetSessionId().ToString(), + ServerConfig.ProvisionClean, + ServerConfig.ProvisionTraceHost); } } # endif @@ -537,7 +619,14 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) : std::filesystem::path(HordeConfig.BinariesPath); std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde"; - m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint); + m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, + BinariesPath, + WorkingDir, + OrchestratorEndpoint, + m_OrchestratorService->GetSessionId().ToString(), + ServerConfig.ProvisionClean, + ServerConfig.ProvisionTraceHost); + m_OrchestratorService->SetProvisionerStateProvider(m_HordeProvisioner.get()); } } # endif @@ -565,6 +654,10 @@ ZenComputeServer::GetInstanceId() const std::string ZenComputeServer::GetAnnounceUrl() const { + if (!m_AnnounceUrl.empty()) + { + return m_AnnounceUrl; + } return m_Http->GetServiceUri(nullptr); } @@ -635,6 +728,11 @@ ZenComputeServer::BuildAnnounceBody() << "nomad"; } + if (!m_CoordinatorSession.empty()) + { + AnnounceBody << "coordinator_session" << m_CoordinatorSession; + } + ResolveCloudMetadata(); if (m_CloudMetadata) { @@ -781,8 +879,10 @@ ZenComputeServer::ProvisionerMaintenanceTick() # if ZEN_WITH_HORDE if (m_HordeProvisioner) { - m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + // Re-apply current target to spawn agent threads for any that have + // exited since the last tick, without overwriting a user-set target. auto Stats = m_HordeProvisioner->GetStats(); + m_HordeProvisioner->SetTargetCoreCount(Stats.TargetCoreCount); ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}", Stats.TargetCoreCount, Stats.EstimatedCoreCount, diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 38f93bc36..63db7e9b3 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -49,9 +49,13 @@ struct ZenComputeServerConfig : public ZenServerConfig std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications std::string CoordinatorEndpoint; + std::string CoordinatorSession; ///< Session ID for stale-instance rejection + std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint) std::string IdmsEndpoint; int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link + bool ProvisionClean = false; // Pass --clean to provisioned workers + std::string ProvisionTraceHost; // Pass --tracehost to provisioned workers # if ZEN_WITH_HORDE horde::HordeConfig HordeConfig; @@ -84,6 +88,7 @@ private: # if ZEN_WITH_HORDE std::string m_HordeModeStr = "direct"; std::string m_HordeEncryptionStr = "none"; + std::string m_HordeOidcTokenExePath; # endif # if ZEN_WITH_NOMAD @@ -147,6 +152,8 @@ private: # endif SystemMetricsTracker m_MetricsTracker; std::string m_CoordinatorEndpoint; + std::string m_CoordinatorSession; + std::string m_AnnounceUrl; std::string m_InstanceId; asio::steady_timer m_AnnounceTimer{m_IoContext}; diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index daad154bc..6449159fd 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -12,6 +12,7 @@ #include <zencore/compactbinaryutil.h> #include <zencore/compactbinaryvalidation.h> #include <zencore/except.h> +#include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> @@ -478,15 +479,27 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir)); } - ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); - ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); - ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); - ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); - ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + SystemRootDir = ExpandEnvironmentVariables(SystemRootDir); + ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); + + DataDir = ExpandEnvironmentVariables(DataDir); + ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); + + ContentDir = ExpandEnvironmentVariables(ContentDir); + ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); + + ConfigFile = ExpandEnvironmentVariables(ConfigFile); + ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); + + BaseSnapshotDir = ExpandEnvironmentVariables(BaseSnapshotDir); + ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + + ExpandEnvironmentVariables(SecurityConfigPath); ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); if (!UnixSocketPath.empty()) { + UnixSocketPath = ExpandEnvironmentVariables(UnixSocketPath); ServerOptions.HttpConfig.UnixSocketPath = MakeSafeAbsolutePath(UnixSocketPath); } diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html deleted file mode 100644 index c07bbb692..000000000 --- a/src/zenserver/frontend/html/compute/compute.html +++ /dev/null @@ -1,925 +0,0 @@ -<!DOCTYPE html> -<html lang="en"> -<head> - <meta charset="UTF-8"> - <meta name="viewport" content="width=device-width, initial-scale=1.0"> - <title>Zen Compute Dashboard</title> - <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script> - <link rel="stylesheet" type="text/css" href="../zen.css" /> - <script src="../util/sanitize.js"></script> - <script src="../theme.js"></script> - <script src="../banner.js" defer></script> - <script src="../nav.js" defer></script> - <style> - .grid { - grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); - } - - .chart-container { - position: relative; - height: 300px; - margin-top: 20px; - } - - .stats-row { - display: flex; - justify-content: space-between; - margin-bottom: 12px; - padding: 8px 0; - border-bottom: 1px solid var(--theme_border_subtle); - } - - .stats-row:last-child { - border-bottom: none; - margin-bottom: 0; - } - - .stats-label { - color: var(--theme_g1); - font-size: 13px; - } - - .stats-value { - color: var(--theme_bright); - font-weight: 600; - font-size: 13px; - } - - .rate-stats { - display: grid; - grid-template-columns: repeat(3, 1fr); - gap: 16px; - margin-top: 16px; - } - - .rate-item { - text-align: center; - } - - .rate-value { - font-size: 20px; - font-weight: 600; - color: var(--theme_p0); - } - - .rate-label { - font-size: 11px; - color: var(--theme_g1); - margin-top: 4px; - text-transform: uppercase; - } - - .worker-row { - cursor: pointer; - transition: background 0.15s; - } - - .worker-row:hover { - background: var(--theme_p4); - } - - .worker-row.selected { - background: var(--theme_p3); - } - - .worker-detail { - margin-top: 20px; - border-top: 1px solid var(--theme_g2); - padding-top: 16px; - } - - .worker-detail-title { - font-size: 15px; - font-weight: 600; - color: var(--theme_bright); - margin-bottom: 12px; - } - - .detail-section { - margin-bottom: 16px; - } - - .detail-section-label { - font-size: 11px; - font-weight: 600; - color: var(--theme_g1); - text-transform: uppercase; - letter-spacing: 0.5px; - margin-bottom: 6px; - } - - .detail-table { - width: 100%; - border-collapse: collapse; - font-size: 12px; - } - - .detail-table td { - padding: 4px 8px; - color: var(--theme_g0); - border-bottom: 1px solid var(--theme_border_subtle); - vertical-align: top; - } - - .detail-table td:first-child { - color: var(--theme_g1); - width: 40%; - font-family: monospace; - } - - .detail-table tr:last-child td { - border-bottom: none; - } - - .detail-mono { - font-family: monospace; - font-size: 11px; - color: var(--theme_g1); - } - - .detail-tag { - display: inline-block; - padding: 2px 8px; - border-radius: 4px; - background: var(--theme_border_subtle); - color: var(--theme_g0); - font-size: 11px; - margin: 2px 4px 2px 0; - } - </style> -</head> -<body> - <div class="container" style="max-width: 1400px; margin: 0 auto;"> - <zen-banner cluster-status="nominal" load="0" tagline="Node Overview" logo-src="../favicon.ico"></zen-banner> - <zen-nav> - <a href="/dashboard/">Home</a> - <a href="compute.html">Node</a> - <a href="orchestrator.html">Orchestrator</a> - </zen-nav> - <div class="timestamp">Last updated: <span id="last-update">Never</span></div> - - <div id="error-container"></div> - - <!-- Action Queue Stats --> - <div class="section-title">Action Queue</div> - <div class="grid"> - <div class="card"> - <div class="card-title">Pending Actions</div> - <div class="metric-value" id="actions-pending">-</div> - <div class="metric-label">Waiting to be scheduled</div> - </div> - <div class="card"> - <div class="card-title">Running Actions</div> - <div class="metric-value" id="actions-running">-</div> - <div class="metric-label">Currently executing</div> - </div> - <div class="card"> - <div class="card-title">Completed Actions</div> - <div class="metric-value" id="actions-complete">-</div> - <div class="metric-label">Results available</div> - </div> - </div> - - <!-- Action Queue Chart --> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Action Queue History</div> - <div class="chart-container"> - <canvas id="queue-chart"></canvas> - </div> - </div> - - <!-- Performance Metrics --> - <div class="section-title">Performance Metrics</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Completion Rate</div> - <div class="rate-stats"> - <div class="rate-item"> - <div class="rate-value" id="rate-1">-</div> - <div class="rate-label">1 min rate</div> - </div> - <div class="rate-item"> - <div class="rate-value" id="rate-5">-</div> - <div class="rate-label">5 min rate</div> - </div> - <div class="rate-item"> - <div class="rate-value" id="rate-15">-</div> - <div class="rate-label">15 min rate</div> - </div> - </div> - <div style="margin-top: 20px;"> - <div class="stats-row"> - <span class="stats-label">Total Retired</span> - <span class="stats-value" id="retired-count">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Mean Rate</span> - <span class="stats-value" id="rate-mean">-</span> - </div> - </div> - </div> - - <!-- Workers --> - <div class="section-title">Workers</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Worker Status</div> - <div class="stats-row"> - <span class="stats-label">Registered Workers</span> - <span class="stats-value" id="worker-count">-</span> - </div> - <div id="worker-table-container" style="margin-top: 16px; display: none;"> - <table id="worker-table"> - <thead> - <tr> - <th>Name</th> - <th>Platform</th> - <th style="text-align: right;">Cores</th> - <th style="text-align: right;">Timeout</th> - <th style="text-align: right;">Functions</th> - <th>Worker ID</th> - </tr> - </thead> - <tbody id="worker-table-body"></tbody> - </table> - <div id="worker-detail" class="worker-detail" style="display: none;"></div> - </div> - </div> - - <!-- Queues --> - <div class="section-title">Queues</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Queue Status</div> - <div id="queue-list-empty" class="empty-state" style="text-align: left;">No queues.</div> - <div id="queue-list-container" style="display: none;"> - <table id="queue-list-table"> - <thead> - <tr> - <th style="text-align: right; width: 60px;">ID</th> - <th style="text-align: center; width: 80px;">Status</th> - <th style="text-align: right;">Active</th> - <th style="text-align: right;">Completed</th> - <th style="text-align: right;">Failed</th> - <th style="text-align: right;">Abandoned</th> - <th style="text-align: right;">Cancelled</th> - <th>Token</th> - </tr> - </thead> - <tbody id="queue-list-body"></tbody> - </table> - </div> - </div> - - <!-- Action History --> - <div class="section-title">Recent Actions</div> - <div class="card" style="margin-bottom: 30px;"> - <div class="card-title">Action History</div> - <div id="action-history-empty" class="empty-state" style="text-align: left;">No actions recorded yet.</div> - <div id="action-history-container" style="display: none;"> - <table id="action-history-table"> - <thead> - <tr> - <th style="text-align: right; width: 60px;">LSN</th> - <th style="text-align: right; width: 60px;">Queue</th> - <th style="text-align: center; width: 70px;">Status</th> - <th>Function</th> - <th style="text-align: right; width: 80px;">Started</th> - <th style="text-align: right; width: 80px;">Finished</th> - <th style="text-align: right; width: 80px;">Duration</th> - <th>Worker ID</th> - <th>Action ID</th> - </tr> - </thead> - <tbody id="action-history-body"></tbody> - </table> - </div> - </div> - - <!-- System Resources --> - <div class="section-title">System Resources</div> - <div class="grid"> - <div class="card"> - <div class="card-title">CPU Usage</div> - <div class="metric-value" id="cpu-usage">-</div> - <div class="metric-label">Percent</div> - <div class="progress-bar"> - <div class="progress-fill" id="cpu-progress" style="width: 0%"></div> - </div> - <div style="position: relative; height: 60px; margin-top: 12px;"> - <canvas id="cpu-chart"></canvas> - </div> - <div style="margin-top: 12px;"> - <div class="stats-row"> - <span class="stats-label">Packages</span> - <span class="stats-value" id="cpu-packages">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Physical Cores</span> - <span class="stats-value" id="cpu-cores">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Logical Processors</span> - <span class="stats-value" id="cpu-lp">-</span> - </div> - </div> - </div> - <div class="card"> - <div class="card-title">Memory</div> - <div class="stats-row"> - <span class="stats-label">Used</span> - <span class="stats-value" id="memory-used">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Total</span> - <span class="stats-value" id="memory-total">-</span> - </div> - <div class="progress-bar"> - <div class="progress-fill" id="memory-progress" style="width: 0%"></div> - </div> - </div> - <div class="card"> - <div class="card-title">Disk</div> - <div class="stats-row"> - <span class="stats-label">Used</span> - <span class="stats-value" id="disk-used">-</span> - </div> - <div class="stats-row"> - <span class="stats-label">Total</span> - <span class="stats-value" id="disk-total">-</span> - </div> - <div class="progress-bar"> - <div class="progress-fill" id="disk-progress" style="width: 0%"></div> - </div> - </div> - </div> - </div> - - <script> - // Configuration - const BASE_URL = window.location.origin; - const REFRESH_INTERVAL = 2000; // 2 seconds - const MAX_HISTORY_POINTS = 60; // Show last 2 minutes - - // Data storage - const history = { - timestamps: [], - pending: [], - running: [], - completed: [], - cpu: [] - }; - - // CPU sparkline chart - const cpuCtx = document.getElementById('cpu-chart').getContext('2d'); - const cpuChart = new Chart(cpuCtx, { - type: 'line', - data: { - labels: [], - datasets: [{ - data: [], - borderColor: '#58a6ff', - backgroundColor: 'rgba(88, 166, 255, 0.15)', - borderWidth: 1.5, - tension: 0.4, - fill: true, - pointRadius: 0 - }] - }, - options: { - responsive: true, - maintainAspectRatio: false, - animation: false, - plugins: { legend: { display: false }, tooltip: { enabled: false } }, - scales: { - x: { display: false }, - y: { display: false, min: 0, max: 100 } - } - } - }); - - // Queue chart setup - const ctx = document.getElementById('queue-chart').getContext('2d'); - const chart = new Chart(ctx, { - type: 'line', - data: { - labels: [], - datasets: [ - { - label: 'Pending', - data: [], - borderColor: '#f0883e', - backgroundColor: 'rgba(240, 136, 62, 0.1)', - tension: 0.4, - fill: true - }, - { - label: 'Running', - data: [], - borderColor: '#58a6ff', - backgroundColor: 'rgba(88, 166, 255, 0.1)', - tension: 0.4, - fill: true - }, - { - label: 'Completed', - data: [], - borderColor: '#3fb950', - backgroundColor: 'rgba(63, 185, 80, 0.1)', - tension: 0.4, - fill: true - } - ] - }, - options: { - responsive: true, - maintainAspectRatio: false, - plugins: { - legend: { - display: true, - labels: { - color: '#8b949e' - } - } - }, - scales: { - x: { - display: false - }, - y: { - beginAtZero: true, - ticks: { - color: '#8b949e' - }, - grid: { - color: '#21262d' - } - } - } - } - }); - - // Helper functions - - function formatBytes(bytes) { - if (bytes === 0) return '0 B'; - const k = 1024; - const sizes = ['B', 'KB', 'MB', 'GB', 'TB']; - const i = Math.floor(Math.log(bytes) / Math.log(k)); - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; - } - - function formatRate(rate) { - return rate.toFixed(2) + '/s'; - } - - function showError(message) { - const container = document.getElementById('error-container'); - container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`; - } - - function clearError() { - document.getElementById('error-container').innerHTML = ''; - } - - function updateTimestamp() { - const now = new Date(); - document.getElementById('last-update').textContent = now.toLocaleTimeString(); - } - - // Fetch functions - async function fetchJSON(endpoint) { - const response = await fetch(`${BASE_URL}${endpoint}`, { - headers: { - 'Accept': 'application/json' - } - }); - if (!response.ok) { - throw new Error(`HTTP ${response.status}: ${response.statusText}`); - } - return await response.json(); - } - - async function fetchHealth() { - try { - const response = await fetch(`${BASE_URL}/compute/ready`); - const isHealthy = response.status === 200; - - const banner = document.querySelector('zen-banner'); - - if (isHealthy) { - banner.setAttribute('cluster-status', 'nominal'); - banner.setAttribute('load', '0'); - } else { - banner.setAttribute('cluster-status', 'degraded'); - banner.setAttribute('load', '0'); - } - - return isHealthy; - } catch (error) { - const banner = document.querySelector('zen-banner'); - banner.setAttribute('cluster-status', 'degraded'); - banner.setAttribute('load', '0'); - throw error; - } - } - - async function fetchStats() { - const data = await fetchJSON('/stats/compute'); - - // Update action counts - document.getElementById('actions-pending').textContent = data.actions_pending || 0; - document.getElementById('actions-running').textContent = data.actions_submitted || 0; - document.getElementById('actions-complete').textContent = data.actions_complete || 0; - - // Update completion rates - if (data.actions_retired) { - document.getElementById('rate-1').textContent = formatRate(data.actions_retired.rate_1 || 0); - document.getElementById('rate-5').textContent = formatRate(data.actions_retired.rate_5 || 0); - document.getElementById('rate-15').textContent = formatRate(data.actions_retired.rate_15 || 0); - document.getElementById('retired-count').textContent = data.actions_retired.count || 0; - document.getElementById('rate-mean').textContent = formatRate(data.actions_retired.rate_mean || 0); - } - - // Update chart - const now = new Date().toLocaleTimeString(); - history.timestamps.push(now); - history.pending.push(data.actions_pending || 0); - history.running.push(data.actions_submitted || 0); - history.completed.push(data.actions_complete || 0); - - // Keep only last N points - if (history.timestamps.length > MAX_HISTORY_POINTS) { - history.timestamps.shift(); - history.pending.shift(); - history.running.shift(); - history.completed.shift(); - } - - chart.data.labels = history.timestamps; - chart.data.datasets[0].data = history.pending; - chart.data.datasets[1].data = history.running; - chart.data.datasets[2].data = history.completed; - chart.update('none'); - } - - async function fetchSysInfo() { - const data = await fetchJSON('/compute/sysinfo'); - - // Update CPU - const cpuUsage = data.cpu_usage || 0; - document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%'; - document.getElementById('cpu-progress').style.width = cpuUsage + '%'; - - const banner = document.querySelector('zen-banner'); - banner.setAttribute('load', cpuUsage.toFixed(1)); - - history.cpu.push(cpuUsage); - if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift(); - cpuChart.data.labels = history.cpu.map(() => ''); - cpuChart.data.datasets[0].data = history.cpu; - cpuChart.update('none'); - - document.getElementById('cpu-packages').textContent = data.cpu_count ?? '-'; - document.getElementById('cpu-cores').textContent = data.core_count ?? '-'; - document.getElementById('cpu-lp').textContent = data.lp_count ?? '-'; - - // Update Memory - const memUsed = data.memory_used || 0; - const memTotal = data.memory_total || 1; - const memPercent = (memUsed / memTotal) * 100; - document.getElementById('memory-used').textContent = formatBytes(memUsed); - document.getElementById('memory-total').textContent = formatBytes(memTotal); - document.getElementById('memory-progress').style.width = memPercent + '%'; - - // Update Disk - const diskUsed = data.disk_used || 0; - const diskTotal = data.disk_total || 1; - const diskPercent = (diskUsed / diskTotal) * 100; - document.getElementById('disk-used').textContent = formatBytes(diskUsed); - document.getElementById('disk-total').textContent = formatBytes(diskTotal); - document.getElementById('disk-progress').style.width = diskPercent + '%'; - } - - // Persists the selected worker ID across refreshes - let selectedWorkerId = null; - - function renderWorkerDetail(id, desc) { - const panel = document.getElementById('worker-detail'); - - if (!desc) { - panel.style.display = 'none'; - return; - } - - function field(label, value) { - return `<tr><td>${label}</td><td>${value ?? '-'}</td></tr>`; - } - - function monoField(label, value) { - return `<tr><td>${label}</td><td class="detail-mono">${value ?? '-'}</td></tr>`; - } - - // Functions - const functions = desc.functions || []; - const functionsHtml = functions.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - `<table class="detail-table">${functions.map(f => - `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>` - ).join('')}</table>`; - - // Executables - const executables = desc.executables || []; - const totalExecSize = executables.reduce((sum, e) => sum + (e.size || 0), 0); - const execHtml = executables.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - `<table class="detail-table"> - <tr style="font-size:11px;"> - <td style="color:var(--theme_faint);padding-bottom:4px;">Path</td> - <td style="color:var(--theme_faint);padding-bottom:4px;">Hash</td> - <td style="color:var(--theme_faint);padding-bottom:4px;text-align:right;">Size</td> - </tr> - ${executables.map(e => - `<tr> - <td>${escapeHtml(e.name || '-')}</td> - <td class="detail-mono">${escapeHtml(e.hash || '-')}</td> - <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td> - </tr>` - ).join('')} - <tr style="border-top:1px solid var(--theme_g2);"> - <td style="color:var(--theme_g1);padding-top:6px;">Total</td> - <td></td> - <td style="text-align:right;white-space:nowrap;padding-top:6px;color:var(--theme_bright);font-weight:600;">${formatBytes(totalExecSize)}</td> - </tr> - </table>`; - - // Files - const files = desc.files || []; - const filesHtml = files.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - `<table class="detail-table">${files.map(f => - `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>` - ).join('')}</table>`; - - // Dirs - const dirs = desc.dirs || []; - const dirsHtml = dirs.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join(''); - - // Environment - const env = desc.environment || []; - const envHtml = env.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' : - env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join(''); - - panel.innerHTML = ` - <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div> - <div class="detail-section"> - <table class="detail-table"> - ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)} - ${field('Path', escapeHtml(desc.path || '-'))} - ${field('Platform', escapeHtml(desc.host || '-'))} - ${monoField('Build System', desc.buildsystem_version)} - ${field('Cores', desc.cores)} - ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)} - </table> - </div> - <div class="detail-section"> - <div class="detail-section-label">Functions</div> - ${functionsHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Executables</div> - ${execHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Files</div> - ${filesHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Directories</div> - ${dirsHtml} - </div> - <div class="detail-section"> - <div class="detail-section-label">Environment</div> - ${envHtml} - </div> - `; - panel.style.display = 'block'; - } - - async function fetchWorkers() { - const data = await fetchJSON('/compute/workers'); - const workerIds = data.workers || []; - - document.getElementById('worker-count').textContent = workerIds.length; - - const container = document.getElementById('worker-table-container'); - const tbody = document.getElementById('worker-table-body'); - - if (workerIds.length === 0) { - container.style.display = 'none'; - selectedWorkerId = null; - return; - } - - const descriptors = await Promise.all( - workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null)) - ); - - // Build a map for quick lookup by ID - const descriptorMap = {}; - workerIds.forEach((id, i) => { descriptorMap[id] = descriptors[i]; }); - - tbody.innerHTML = ''; - descriptors.forEach((desc, i) => { - const id = workerIds[i]; - const name = desc ? (desc.name || '-') : '-'; - const host = desc ? (desc.host || '-') : '-'; - const cores = desc ? (desc.cores != null ? desc.cores : '-') : '-'; - const timeout = desc ? (desc.timeout != null ? desc.timeout + 's' : '-') : '-'; - const functions = desc ? (desc.functions ? desc.functions.length : 0) : '-'; - - const tr = document.createElement('tr'); - tr.className = 'worker-row' + (id === selectedWorkerId ? ' selected' : ''); - tr.dataset.workerId = id; - tr.innerHTML = ` - <td style="color: var(--theme_bright);">${escapeHtml(name)}</td> - <td>${escapeHtml(host)}</td> - <td style="text-align: right;">${escapeHtml(String(cores))}</td> - <td style="text-align: right;">${escapeHtml(String(timeout))}</td> - <td style="text-align: right;">${escapeHtml(String(functions))}</td> - <td style="color: var(--theme_g1); font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td> - `; - tr.addEventListener('click', () => { - document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected')); - if (selectedWorkerId === id) { - // Toggle off - selectedWorkerId = null; - document.getElementById('worker-detail').style.display = 'none'; - } else { - selectedWorkerId = id; - tr.classList.add('selected'); - renderWorkerDetail(id, descriptorMap[id]); - } - }); - tbody.appendChild(tr); - }); - - // Re-render detail if selected worker is still present - if (selectedWorkerId && descriptorMap[selectedWorkerId]) { - renderWorkerDetail(selectedWorkerId, descriptorMap[selectedWorkerId]); - } else if (selectedWorkerId && !descriptorMap[selectedWorkerId]) { - selectedWorkerId = null; - document.getElementById('worker-detail').style.display = 'none'; - } - - container.style.display = 'block'; - } - - // Windows FILETIME: 100ns ticks since 1601-01-01. Convert to JS Date. - const FILETIME_EPOCH_OFFSET_MS = 11644473600000n; - function filetimeToDate(ticks) { - if (!ticks) return null; - const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS; - return new Date(Number(ms)); - } - - function formatTime(date) { - if (!date) return '-'; - return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' }); - } - - function formatDuration(startDate, endDate) { - if (!startDate || !endDate) return '-'; - const ms = endDate - startDate; - if (ms < 0) return '-'; - if (ms < 1000) return ms + ' ms'; - if (ms < 60000) return (ms / 1000).toFixed(2) + ' s'; - const m = Math.floor(ms / 60000); - const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, '0'); - return `${m}m ${s}s`; - } - - async function fetchQueues() { - const data = await fetchJSON('/compute/queues'); - const queues = data.queues || []; - - const empty = document.getElementById('queue-list-empty'); - const container = document.getElementById('queue-list-container'); - const tbody = document.getElementById('queue-list-body'); - - if (queues.length === 0) { - empty.style.display = ''; - container.style.display = 'none'; - return; - } - - empty.style.display = 'none'; - tbody.innerHTML = ''; - - for (const q of queues) { - const id = q.queue_id ?? '-'; - const badge = q.state === 'cancelled' - ? '<span class="status-badge failure">cancelled</span>' - : q.state === 'draining' - ? '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_warn) 15%, transparent);color:var(--theme_warn);">draining</span>' - : q.is_complete - ? '<span class="status-badge success">complete</span>' - : '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_p0) 15%, transparent);color:var(--theme_p0);">active</span>'; - const token = q.queue_token - ? `<span class="detail-mono">${escapeHtml(q.queue_token)}</span>` - : '<span style="color:var(--theme_faint);">-</span>'; - - const tr = document.createElement('tr'); - tr.innerHTML = ` - <td style="text-align: right; font-family: monospace; color: var(--theme_bright);">${escapeHtml(String(id))}</td> - <td style="text-align: center;">${badge}</td> - <td style="text-align: right;">${q.active_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_ok);">${q.completed_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_fail);">${q.failed_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_warn);">${q.abandoned_count ?? 0}</td> - <td style="text-align: right; color: var(--theme_warn);">${q.cancelled_count ?? 0}</td> - <td>${token}</td> - `; - tbody.appendChild(tr); - } - - container.style.display = 'block'; - } - - async function fetchActionHistory() { - const data = await fetchJSON('/compute/jobs/history?limit=50'); - const entries = data.history || []; - - const empty = document.getElementById('action-history-empty'); - const container = document.getElementById('action-history-container'); - const tbody = document.getElementById('action-history-body'); - - if (entries.length === 0) { - empty.style.display = ''; - container.style.display = 'none'; - return; - } - - empty.style.display = 'none'; - tbody.innerHTML = ''; - - // Entries arrive oldest-first; reverse to show newest at top - for (const entry of [...entries].reverse()) { - const lsn = entry.lsn ?? '-'; - const succeeded = entry.succeeded; - const badge = succeeded == null - ? '<span class="status-badge" style="background:var(--theme_border_subtle);color:var(--theme_g1);">unknown</span>' - : succeeded - ? '<span class="status-badge success">ok</span>' - : '<span class="status-badge failure">failed</span>'; - const desc = entry.actionDescriptor || {}; - const fn = desc.Function || '-'; - const workerId = entry.workerId || '-'; - const actionId = entry.actionId || '-'; - - const startDate = filetimeToDate(entry.time_Running); - const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed); - - const queueId = entry.queueId || 0; - const queueCell = queueId - ? `<a href="/compute/queues/${queueId}" style="color: var(--theme_ln); text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>` - : '<span style="color: var(--theme_faint);">-</span>'; - - const tr = document.createElement('tr'); - tr.innerHTML = ` - <td style="text-align: right; font-family: monospace; color: var(--theme_g1);">${escapeHtml(String(lsn))}</td> - <td style="text-align: right;">${queueCell}</td> - <td style="text-align: center;">${badge}</td> - <td style="color: var(--theme_bright);">${escapeHtml(fn)}</td> - <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(startDate)}</td> - <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(endDate)}</td> - <td style="text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td> - <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(workerId)}</td> - <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(actionId)}</td> - `; - tbody.appendChild(tr); - } - - container.style.display = 'block'; - } - - async function updateDashboard() { - try { - await Promise.all([ - fetchHealth(), - fetchStats(), - fetchSysInfo(), - fetchWorkers(), - fetchQueues(), - fetchActionHistory() - ]); - - clearError(); - updateTimestamp(); - } catch (error) { - console.error('Error updating dashboard:', error); - showError(error.message); - } - } - - // Start updating - updateDashboard(); - setInterval(updateDashboard, REFRESH_INTERVAL); - </script> -</body> -</html> diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html index 9597fd7f3..aaa09aec0 100644 --- a/src/zenserver/frontend/html/compute/index.html +++ b/src/zenserver/frontend/html/compute/index.html @@ -1 +1 @@ -<meta http-equiv="refresh" content="0; url=compute.html" />
\ No newline at end of file +<meta http-equiv="refresh" content="0; url=/dashboard/?page=compute" />
\ No newline at end of file diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html deleted file mode 100644 index d1a2bb015..000000000 --- a/src/zenserver/frontend/html/compute/orchestrator.html +++ /dev/null @@ -1,669 +0,0 @@ -<!DOCTYPE html> -<html lang="en"> -<head> - <meta charset="UTF-8"> - <meta name="viewport" content="width=device-width, initial-scale=1.0"> - <link rel="stylesheet" type="text/css" href="../zen.css" /> - <script src="../util/sanitize.js"></script> - <script src="../theme.js"></script> - <script src="../banner.js" defer></script> - <script src="../nav.js" defer></script> - <title>Zen Orchestrator Dashboard</title> - <style> - .agent-count { - display: flex; - align-items: center; - gap: 8px; - font-size: 14px; - padding: 8px 16px; - border-radius: 6px; - background: var(--theme_g3); - border: 1px solid var(--theme_g2); - } - - .agent-count .count { - font-size: 20px; - font-weight: 600; - color: var(--theme_bright); - } - </style> -</head> -<body> - <div class="container" style="max-width: 1400px; margin: 0 auto;"> - <zen-banner cluster-status="nominal" load="0" logo-src="../favicon.ico"></zen-banner> - <zen-nav> - <a href="/dashboard/">Home</a> - <a href="compute.html">Node</a> - <a href="orchestrator.html">Orchestrator</a> - </zen-nav> - <div class="header"> - <div> - <div class="timestamp">Last updated: <span id="last-update">Never</span></div> - </div> - <div class="agent-count"> - <span>Agents:</span> - <span class="count" id="agent-count">-</span> - </div> - </div> - - <div id="error-container"></div> - - <div class="card"> - <div class="card-title">Compute Agents</div> - <div id="empty-state" class="empty-state">No agents registered.</div> - <table id="agent-table" style="display: none;"> - <thead> - <tr> - <th style="width: 40px; text-align: center;">Health</th> - <th>Hostname</th> - <th style="text-align: right;">CPUs</th> - <th style="text-align: right;">CPU Usage</th> - <th style="text-align: right;">Memory</th> - <th style="text-align: right;">Queues</th> - <th style="text-align: right;">Pending</th> - <th style="text-align: right;">Running</th> - <th style="text-align: right;">Completed</th> - <th style="text-align: right;">Traffic</th> - <th style="text-align: right;">Last Seen</th> - </tr> - </thead> - <tbody id="agent-table-body"></tbody> - </table> - </div> - <div class="card" style="margin-top: 20px;"> - <div class="card-title">Connected Clients</div> - <div id="clients-empty" class="empty-state">No clients connected.</div> - <table id="clients-table" style="display: none;"> - <thead> - <tr> - <th style="width: 40px; text-align: center;">Health</th> - <th>Client ID</th> - <th>Hostname</th> - <th>Address</th> - <th style="text-align: right;">Last Seen</th> - </tr> - </thead> - <tbody id="clients-table-body"></tbody> - </table> - </div> - <div class="card" style="margin-top: 20px;"> - <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 12px;"> - <div class="card-title" style="margin-bottom: 0;">Event History</div> - <div class="history-tabs"> - <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button> - <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button> - </div> - </div> - <div id="history-panel-workers"> - <div id="history-empty" class="empty-state">No provisioning events recorded.</div> - <table id="history-table" style="display: none;"> - <thead> - <tr> - <th>Time</th> - <th>Event</th> - <th>Worker</th> - <th>Hostname</th> - </tr> - </thead> - <tbody id="history-table-body"></tbody> - </table> - </div> - <div id="history-panel-clients" style="display: none;"> - <div id="client-history-empty" class="empty-state">No client events recorded.</div> - <table id="client-history-table" style="display: none;"> - <thead> - <tr> - <th>Time</th> - <th>Event</th> - <th>Client</th> - <th>Hostname</th> - </tr> - </thead> - <tbody id="client-history-table-body"></tbody> - </table> - </div> - </div> - </div> - - <script> - const BASE_URL = window.location.origin; - const REFRESH_INTERVAL = 2000; - - function showError(message) { - document.getElementById('error-container').innerHTML = - '<div class="error">Error: ' + escapeHtml(message) + '</div>'; - } - - function clearError() { - document.getElementById('error-container').innerHTML = ''; - } - - function formatLastSeen(dtMs) { - if (dtMs == null) return '-'; - var seconds = Math.floor(dtMs / 1000); - if (seconds < 60) return seconds + 's ago'; - var minutes = Math.floor(seconds / 60); - if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago'; - var hours = Math.floor(minutes / 60); - return hours + 'h ' + (minutes % 60) + 'm ago'; - } - - function healthClass(dtMs, reachable) { - if (reachable === false) return 'health-red'; - if (dtMs == null) return 'health-red'; - var seconds = dtMs / 1000; - if (seconds < 30 && reachable === true) return 'health-green'; - if (seconds < 120) return 'health-yellow'; - return 'health-red'; - } - - function healthTitle(dtMs, reachable) { - var seenStr = dtMs != null ? 'Last seen ' + formatLastSeen(dtMs) : 'Never seen'; - if (reachable === true) return seenStr + ' · Reachable'; - if (reachable === false) return seenStr + ' · Unreachable'; - return seenStr + ' · Reachability unknown'; - } - - function formatCpuUsage(percent) { - if (percent == null || percent === 0) return '-'; - return percent.toFixed(1) + '%'; - } - - function formatMemory(usedBytes, totalBytes) { - if (!totalBytes) return '-'; - var usedGiB = usedBytes / (1024 * 1024 * 1024); - var totalGiB = totalBytes / (1024 * 1024 * 1024); - return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB'; - } - - function formatBytes(bytes) { - if (!bytes) return '-'; - if (bytes < 1024) return bytes + ' B'; - if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB'; - if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB'; - if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB'; - return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB'; - } - - function formatTraffic(recv, sent) { - if (!recv && !sent) return '-'; - return formatBytes(recv) + ' / ' + formatBytes(sent); - } - - function parseIpFromUri(uri) { - try { - var url = new URL(uri); - var host = url.hostname; - // Strip IPv6 brackets - if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1); - // Only handle IPv4 - var parts = host.split('.'); - if (parts.length !== 4) return null; - var octets = parts.map(Number); - if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null; - return octets; - } catch (e) { - return null; - } - } - - function computeCidr(ips) { - if (ips.length === 0) return null; - if (ips.length === 1) return ips[0].join('.') + '/32'; - - // Convert each IP to a 32-bit integer - var ints = ips.map(function(o) { - return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0; - }); - - // Find common prefix length by ANDing all identical high bits - var common = ~0 >>> 0; - for (var i = 1; i < ints.length; i++) { - // XOR to find differing bits, then mask away everything from the first difference down - var diff = (ints[0] ^ ints[i]) >>> 0; - if (diff !== 0) { - var bit = 31 - Math.floor(Math.log2(diff)); - var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0; - common = (common & mask) >>> 0; - } - } - - // Count leading ones in the common mask - var prefix = 0; - for (var b = 31; b >= 0; b--) { - if ((common >>> b) & 1) prefix++; - else break; - } - - // Network address - var net = (ints[0] & common) >>> 0; - var a = (net >>> 24) & 0xff; - var bv = (net >>> 16) & 0xff; - var c = (net >>> 8) & 0xff; - var d = net & 0xff; - return a + '.' + bv + '.' + c + '.' + d + '/' + prefix; - } - - function renderDashboard(data) { - var banner = document.querySelector('zen-banner'); - if (data.hostname) { - banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname); - } - var workers = data.workers || []; - - document.getElementById('agent-count').textContent = workers.length; - - if (workers.length === 0) { - banner.setAttribute('cluster-status', 'degraded'); - banner.setAttribute('load', '0'); - } else { - banner.setAttribute('cluster-status', 'nominal'); - } - - var emptyState = document.getElementById('empty-state'); - var table = document.getElementById('agent-table'); - var tbody = document.getElementById('agent-table-body'); - - if (workers.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - } else { - emptyState.style.display = 'none'; - table.style.display = ''; - - tbody.innerHTML = ''; - var totalCpus = 0; - var totalWeightedCpuUsage = 0; - var totalMemUsed = 0; - var totalMemTotal = 0; - var totalQueues = 0; - var totalPending = 0; - var totalRunning = 0; - var totalCompleted = 0; - var totalBytesRecv = 0; - var totalBytesSent = 0; - var allIps = []; - for (var i = 0; i < workers.length; i++) { - var w = workers[i]; - var uri = w.uri || ''; - var dt = w.dt; - var dashboardUrl = uri + '/dashboard/compute/'; - - var id = w.id || ''; - - var hostname = w.hostname || ''; - var cpus = w.cpus || 0; - totalCpus += cpus; - if (cpus > 0 && typeof w.cpu_usage === 'number') { - totalWeightedCpuUsage += w.cpu_usage * cpus; - } - - var memTotal = w.memory_total || 0; - var memUsed = w.memory_used || 0; - totalMemTotal += memTotal; - totalMemUsed += memUsed; - - var activeQueues = w.active_queues || 0; - totalQueues += activeQueues; - - var actionsPending = w.actions_pending || 0; - var actionsRunning = w.actions_running || 0; - var actionsCompleted = w.actions_completed || 0; - totalPending += actionsPending; - totalRunning += actionsRunning; - totalCompleted += actionsCompleted; - - var bytesRecv = w.bytes_received || 0; - var bytesSent = w.bytes_sent || 0; - totalBytesRecv += bytesRecv; - totalBytesSent += bytesSent; - - var ip = parseIpFromUri(uri); - if (ip) allIps.push(ip); - - var reachable = w.reachable; - var hClass = healthClass(dt, reachable); - var hTitle = healthTitle(dt, reachable); - - var platform = w.platform || ''; - var badges = ''; - if (platform) { - var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' }; - var platColor = platColors[platform] || '#8b949e'; - badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>'; - } - var provisioner = w.provisioner || ''; - if (provisioner) { - var provColors = { horde: '#8957e5', nomad: '#3fb950' }; - var provColor = provColors[provisioner] || '#8b949e'; - badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>'; - } - - var tr = document.createElement('tr'); - tr.title = id; - tr.innerHTML = - '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + - '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' + - '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' + - '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' + - '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' + - '<td style="text-align: right;">' + (activeQueues > 0 ? activeQueues : '-') + '</td>' + - '<td style="text-align: right;">' + actionsPending + '</td>' + - '<td style="text-align: right;">' + actionsRunning + '</td>' + - '<td style="text-align: right;">' + actionsCompleted + '</td>' + - '<td style="text-align: right; font-size: 11px; color: var(--theme_g1);">' + formatTraffic(bytesRecv, bytesSent) + '</td>' + - '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>'; - tbody.appendChild(tr); - } - - var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0; - banner.setAttribute('load', clusterLoad.toFixed(1)); - - // Total row - var cidr = computeCidr(allIps); - var totalTr = document.createElement('tr'); - totalTr.className = 'total-row'; - totalTr.innerHTML = - '<td></td>' + - '<td style="text-align: right; color: var(--theme_g1); text-transform: uppercase; font-size: 11px;">Total' + (cidr ? ' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' + - '<td style="text-align: right;">' + totalCpus + '</td>' + - '<td></td>' + - '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' + - '<td style="text-align: right;">' + totalQueues + '</td>' + - '<td style="text-align: right;">' + totalPending + '</td>' + - '<td style="text-align: right;">' + totalRunning + '</td>' + - '<td style="text-align: right;">' + totalCompleted + '</td>' + - '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' + - '<td></td>'; - tbody.appendChild(totalTr); - } - - clearError(); - document.getElementById('last-update').textContent = new Date().toLocaleTimeString(); - - // Render provisioning history if present in WebSocket payload - if (data.events) { - renderProvisioningHistory(data.events); - } - - // Render connected clients if present - if (data.clients) { - renderClients(data.clients); - } - - // Render client history if present - if (data.client_events) { - renderClientHistory(data.client_events); - } - } - - function eventBadge(type) { - var colors = { joined: 'var(--theme_ok)', left: 'var(--theme_fail)', returned: 'var(--theme_warn)' }; - var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' }; - var color = colors[type] || 'var(--theme_g1)'; - var label = labels[type] || type; - return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>'; - } - - function formatTimestamp(ts) { - if (!ts) return '-'; - // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string - var date; - if (typeof ts === 'number') { - // .NET-style ticks: convert to Unix ms - var unixMs = (ts - 621355968000000000) / 10000; - date = new Date(unixMs); - } else { - date = new Date(ts); - } - if (isNaN(date.getTime())) return '-'; - return date.toLocaleTimeString(); - } - - var activeHistoryTab = 'workers'; - - function switchHistoryTab(tab) { - activeHistoryTab = tab; - var tabs = document.querySelectorAll('.history-tab'); - for (var i = 0; i < tabs.length; i++) { - tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab); - } - document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none'; - document.getElementById('history-panel-clients').style.display = tab === 'clients' ? '' : 'none'; - } - - function renderProvisioningHistory(events) { - var emptyState = document.getElementById('history-empty'); - var table = document.getElementById('history-table'); - var tbody = document.getElementById('history-table-body'); - - if (!events || events.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - return; - } - - emptyState.style.display = 'none'; - table.style.display = ''; - tbody.innerHTML = ''; - - for (var i = 0; i < events.length; i++) { - var evt = events[i]; - var tr = document.createElement('tr'); - tr.innerHTML = - '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' + - '<td>' + eventBadge(evt.type) + '</td>' + - '<td>' + escapeHtml(evt.worker_id || '') + '</td>' + - '<td>' + escapeHtml(evt.hostname || '') + '</td>'; - tbody.appendChild(tr); - } - } - - function clientHealthClass(dtMs) { - if (dtMs == null) return 'health-red'; - var seconds = dtMs / 1000; - if (seconds < 30) return 'health-green'; - if (seconds < 120) return 'health-yellow'; - return 'health-red'; - } - - function renderClients(clients) { - var emptyState = document.getElementById('clients-empty'); - var table = document.getElementById('clients-table'); - var tbody = document.getElementById('clients-table-body'); - - if (!clients || clients.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - return; - } - - emptyState.style.display = 'none'; - table.style.display = ''; - tbody.innerHTML = ''; - - for (var i = 0; i < clients.length; i++) { - var c = clients[i]; - var dt = c.dt; - var hClass = clientHealthClass(dt); - var hTitle = dt != null ? 'Last seen ' + formatLastSeen(dt) : 'Never seen'; - - var sessionBadge = ''; - if (c.session_id) { - sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:var(--theme_faint);" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>'; - } - - var tr = document.createElement('tr'); - tr.innerHTML = - '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' + - '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' + - '<td>' + escapeHtml(c.hostname || '') + '</td>' + - '<td style="font-family: monospace; font-size: 12px; color: var(--theme_g1);">' + escapeHtml(c.address || '') + '</td>' + - '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>'; - tbody.appendChild(tr); - } - } - - function clientEventBadge(type) { - var colors = { connected: 'var(--theme_ok)', disconnected: 'var(--theme_fail)', updated: 'var(--theme_warn)' }; - var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' }; - var color = colors[type] || 'var(--theme_g1)'; - var label = labels[type] || type; - return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>'; - } - - function renderClientHistory(events) { - var emptyState = document.getElementById('client-history-empty'); - var table = document.getElementById('client-history-table'); - var tbody = document.getElementById('client-history-table-body'); - - if (!events || events.length === 0) { - emptyState.style.display = ''; - table.style.display = 'none'; - return; - } - - emptyState.style.display = 'none'; - table.style.display = ''; - tbody.innerHTML = ''; - - for (var i = 0; i < events.length; i++) { - var evt = events[i]; - var tr = document.createElement('tr'); - tr.innerHTML = - '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' + - '<td>' + clientEventBadge(evt.type) + '</td>' + - '<td>' + escapeHtml(evt.client_id || '') + '</td>' + - '<td>' + escapeHtml(evt.hostname || '') + '</td>'; - tbody.appendChild(tr); - } - } - - // Fetch-based polling fallback - var pollTimer = null; - - async function fetchProvisioningHistory() { - try { - var response = await fetch(BASE_URL + '/orch/history?limit=50', { - headers: { 'Accept': 'application/json' } - }); - if (response.ok) { - var data = await response.json(); - renderProvisioningHistory(data.events || []); - } - } catch (e) { - console.error('Error fetching provisioning history:', e); - } - } - - async function fetchClients() { - try { - var response = await fetch(BASE_URL + '/orch/clients', { - headers: { 'Accept': 'application/json' } - }); - if (response.ok) { - var data = await response.json(); - renderClients(data.clients || []); - } - } catch (e) { - console.error('Error fetching clients:', e); - } - } - - async function fetchClientHistory() { - try { - var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', { - headers: { 'Accept': 'application/json' } - }); - if (response.ok) { - var data = await response.json(); - renderClientHistory(data.client_events || []); - } - } catch (e) { - console.error('Error fetching client history:', e); - } - } - - async function fetchDashboard() { - var banner = document.querySelector('zen-banner'); - try { - var response = await fetch(BASE_URL + '/orch/agents', { - headers: { 'Accept': 'application/json' } - }); - - if (!response.ok) { - banner.setAttribute('cluster-status', 'degraded'); - throw new Error('HTTP ' + response.status + ': ' + response.statusText); - } - - renderDashboard(await response.json()); - fetchProvisioningHistory(); - fetchClients(); - fetchClientHistory(); - } catch (error) { - console.error('Error updating dashboard:', error); - showError(error.message); - banner.setAttribute('cluster-status', 'offline'); - } - } - - function startPolling() { - if (pollTimer) return; - fetchDashboard(); - pollTimer = setInterval(fetchDashboard, REFRESH_INTERVAL); - } - - function stopPolling() { - if (pollTimer) { - clearInterval(pollTimer); - pollTimer = null; - } - } - - // WebSocket connection with automatic reconnect and polling fallback - var ws = null; - - function connectWebSocket() { - var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; - ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws'); - - ws.onopen = function() { - stopPolling(); - clearError(); - }; - - ws.onmessage = function(event) { - try { - renderDashboard(JSON.parse(event.data)); - } catch (e) { - console.error('WebSocket message parse error:', e); - } - }; - - ws.onclose = function() { - ws = null; - startPolling(); - setTimeout(connectWebSocket, 3000); - }; - - ws.onerror = function() { - // onclose will fire after onerror - }; - } - - // Fetch orchestrator hostname for the banner - fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } }) - .then(function(r) { return r.ok ? r.json() : null; }) - .then(function(d) { - if (d && d.hostname) { - document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname); - } - }) - .catch(function() {}); - - // Initial load via fetch, then try WebSocket - fetchDashboard(); - connectWebSocket(); - </script> -</body> -</html> diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js index a280fabdb..30f6a8122 100644 --- a/src/zenserver/frontend/html/pages/orchestrator.js +++ b/src/zenserver/frontend/html/pages/orchestrator.js @@ -14,6 +14,14 @@ export class Page extends ZenPage { this.set_title("orchestrator"); + // Provisioner section (hidden until data arrives) + this._prov_section = this._collapsible_section("Provisioner"); + this._prov_section._parent.inner().style.display = "none"; + this._prov_grid = null; + this._prov_target_dirty = false; + this._prov_commit_timer = null; + this._prov_last_target = null; + // Agents section const agents_section = this._collapsible_section("Compute Agents"); this._agents_host = agents_section; @@ -50,11 +58,12 @@ export class Page extends ZenPage { try { - const [agents, history, clients, client_history] = await Promise.all([ + const [agents, history, clients, client_history, prov] = await Promise.all([ new Fetcher().resource("/orch/agents").json(), new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null), new Fetcher().resource("/orch/clients").json().catch(() => null), new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null), + new Fetcher().resource("/orch/provisioner/status").json().catch(() => null), ]); this._render_agents(agents); @@ -70,6 +79,7 @@ export class Page extends ZenPage { this._render_client_history(client_history.client_events || []); } + this._render_provisioner(prov); } catch (e) { /* service unavailable */ } } @@ -109,6 +119,7 @@ export class Page extends ZenPage { this._render_client_history(data.client_events); } + this._render_provisioner(data.provisioner); } catch (e) { /* ignore parse errors */ } }; @@ -156,7 +167,7 @@ export class Page extends ZenPage return; } - let totalCpus = 0, totalWeightedCpu = 0; + let totalCpus = 0, activeCpus = 0, totalWeightedCpu = 0; let totalMemUsed = 0, totalMemTotal = 0; let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0; let totalRecv = 0, totalSent = 0; @@ -173,8 +184,14 @@ export class Page extends ZenPage const completed = w.actions_completed || 0; const recv = w.bytes_received || 0; const sent = w.bytes_sent || 0; + const provisioner = w.provisioner || ""; + const isProvisioned = provisioner !== ""; totalCpus += cpus; + if (w.provisioner_status === "active") + { + activeCpus += cpus; + } if (cpus > 0 && typeof cpuUsage === "number") { totalWeightedCpu += cpuUsage * cpus; @@ -209,12 +226,49 @@ export class Page extends ZenPage cell.inner().textContent = ""; cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank"); } + + // Visual treatment based on provisioner status + const provStatus = w.provisioner_status || ""; + if (!isProvisioned) + { + row.inner().style.opacity = "0.45"; + } + else + { + const hostCell = row.get_cell(0); + const el = hostCell.inner(); + const badge = document.createElement("span"); + const badgeBase = "display:inline-block;margin-left:6px;padding:1px 5px;border-radius:8px;" + + "font-size:9px;font-weight:600;color:#fff;vertical-align:middle;"; + + if (provStatus === "draining") + { + badge.textContent = "draining"; + badge.style.cssText = badgeBase + "background:var(--theme_warn);"; + row.inner().style.opacity = "0.6"; + } + else if (provStatus === "active") + { + badge.textContent = provisioner; + badge.style.cssText = badgeBase + "background:#8957e5;"; + } + else + { + badge.textContent = "deallocated"; + badge.style.cssText = badgeBase + "background:var(--theme_fail);"; + row.inner().style.opacity = "0.45"; + } + el.appendChild(badge); + } } - // Total row + // Total row — show active / total in CPUs column + const cpuLabel = activeCpus < totalCpus + ? Friendly.sep(activeCpus) + " / " + Friendly.sep(totalCpus) + : Friendly.sep(totalCpus); const total = this._agents_table.add_row( "TOTAL", - Friendly.sep(totalCpus), + cpuLabel, "", totalMemTotal > 0 ? Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-", Friendly.sep(totalQueues), @@ -305,6 +359,154 @@ export class Page extends ZenPage } } + _render_provisioner(prov) + { + const container = this._prov_section._parent.inner(); + + if (!prov || !prov.name) + { + container.style.display = "none"; + return; + } + container.style.display = ""; + + if (!this._prov_grid) + { + this._prov_grid = this._prov_section.tag().classify("grid").classify("stats-tiles"); + this._prov_tiles = {}; + + // Target cores tile with editable input + const target_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + target_tile.tag().classify("card-title").text("Target Cores"); + const target_body = target_tile.tag().classify("tile-metrics"); + const target_m = target_body.tag().classify("tile-metric").classify("tile-metric-hero"); + const input = document.createElement("input"); + input.type = "number"; + input.min = "0"; + input.style.cssText = "width:100px;padding:4px 8px;border:1px solid var(--theme_g2);border-radius:4px;" + + "background:var(--theme_g4);color:var(--theme_bright);font-size:20px;font-weight:600;text-align:right;"; + target_m.inner().appendChild(input); + target_m.tag().classify("metric-label").text("target"); + this._prov_tiles.target_input = input; + + input.addEventListener("focus", () => { this._prov_target_dirty = true; }); + input.addEventListener("input", () => { + this._prov_target_dirty = true; + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._prov_commit_timer = setTimeout(() => this._commit_provisioner_target(), 800); + }); + input.addEventListener("keydown", (e) => { + if (e.key === "Enter") + { + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._commit_provisioner_target(); + input.blur(); + } + }); + input.addEventListener("blur", () => { + if (this._prov_commit_timer) + { + clearTimeout(this._prov_commit_timer); + } + this._commit_provisioner_target(); + }); + + // Active cores + const active_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + active_tile.tag().classify("card-title").text("Active Cores"); + const active_body = active_tile.tag().classify("tile-metrics"); + this._prov_tiles.active = active_body; + + // Estimated cores + const est_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + est_tile.tag().classify("card-title").text("Estimated Cores"); + const est_body = est_tile.tag().classify("tile-metrics"); + this._prov_tiles.estimated = est_body; + + // Agents + const agents_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + agents_tile.tag().classify("card-title").text("Agents"); + const agents_body = agents_tile.tag().classify("tile-metrics"); + this._prov_tiles.agents = agents_body; + + // Draining + const drain_tile = this._prov_grid.tag().classify("card").classify("stats-tile"); + drain_tile.tag().classify("card-title").text("Draining"); + const drain_body = drain_tile.tag().classify("tile-metrics"); + this._prov_tiles.draining = drain_body; + } + + // Update values + const input = this._prov_tiles.target_input; + if (!this._prov_target_dirty && document.activeElement !== input) + { + input.value = prov.target_cores; + } + this._prov_last_target = prov.target_cores; + + // Re-render metric tiles (clear and recreate content) + for (const key of ["active", "estimated", "agents", "draining"]) + { + this._prov_tiles[key].inner().innerHTML = ""; + } + this._metric(this._prov_tiles.active, Friendly.sep(prov.active_cores), "cores", true); + this._metric(this._prov_tiles.estimated, Friendly.sep(prov.estimated_cores), "cores", true); + this._metric(this._prov_tiles.agents, Friendly.sep(prov.agents), "active", true); + this._metric(this._prov_tiles.draining, Friendly.sep(prov.agents_draining || 0), "agents", true); + } + + async _commit_provisioner_target() + { + const input = this._prov_tiles?.target_input; + if (!input || this._prov_committing) + { + return; + } + const value = parseInt(input.value, 10); + if (isNaN(value) || value < 0) + { + return; + } + if (value === this._prov_last_target) + { + this._prov_target_dirty = false; + return; + } + this._prov_committing = true; + try + { + const resp = await fetch("/orch/provisioner/target", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ target_cores: value }), + }); + if (resp.ok) + { + this._prov_target_dirty = false; + console.log("Target cores set to", value); + } + else + { + const text = await resp.text(); + console.error("Failed to set target cores: HTTP", resp.status, text); + } + } + catch (e) + { + console.error("Failed to set target cores:", e); + } + finally + { + this._prov_committing = false; + } + } + _metric(parent, value, label, hero = false) { const m = parent.tag().classify("tile-metric"); diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index c5f8724ca..108685eb9 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -14,7 +14,6 @@ #include <zencore/memory/memorytrace.h> #include <zencore/memory/newdelete.h> #include <zencore/scopeguard.h> -#include <zencore/sentryintegration.h> #include <zencore/session.h> #include <zencore/string.h> #include <zencore/thread.h> @@ -169,7 +168,12 @@ AppMain(int argc, char* argv[]) if (IsDir(ServerOptions.DataDir)) { ZEN_CONSOLE_INFO("Deleting files from '{}' ({})", ServerOptions.DataDir, DeleteReason); - DeleteDirectories(ServerOptions.DataDir); + std::error_code Ec; + DeleteDirectories(ServerOptions.DataDir, Ec); + if (Ec) + { + ZEN_WARN("could not fully clean '{}': {} (continuing anyway)", ServerOptions.DataDir, Ec.message()); + } } } |