aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-04-13 16:38:58 +0200
committerGitHub Enterprise <[email protected]>2026-04-13 16:38:58 +0200
commitf387a069967e960305cc189827093111eb5b82e7 (patch)
tree1bb9b5c79d87ba64a8a10c23958dfa98769950ba /src
parentMerge branch 'main' into sb/tourist (diff)
parentCompute OIDC auth, async Horde agents, and orchestrator improvements (#913) (diff)
downloadzen-sb/tourist.tar.xz
zen-sb/tourist.zip
Merge branch 'main' into sb/touristsb/tourist
Diffstat (limited to 'src')
-rw-r--r--src/zen/cmds/compute_cmd.cpp96
-rw-r--r--src/zen/cmds/compute_cmd.h53
-rw-r--r--src/zen/cmds/exec_cmd.cpp80
-rw-r--r--src/zen/zen.cpp5
-rw-r--r--src/zencompute/CLAUDE.md7
-rw-r--r--src/zencompute/computeservice.cpp164
-rw-r--r--src/zencompute/httpcomputeservice.cpp95
-rw-r--r--src/zencompute/httporchestrator.cpp135
-rw-r--r--src/zencompute/include/zencompute/computeservice.h7
-rw-r--r--src/zencompute/include/zencompute/httpcomputeservice.h4
-rw-r--r--src/zencompute/include/zencompute/httporchestrator.h17
-rw-r--r--src/zencompute/include/zencompute/orchestratorservice.h12
-rw-r--r--src/zencompute/include/zencompute/provisionerstate.h38
-rw-r--r--src/zencompute/orchestratorservice.cpp29
-rw-r--r--src/zencompute/runners/functionrunner.cpp120
-rw-r--r--src/zencompute/runners/functionrunner.h27
-rw-r--r--src/zencompute/runners/linuxrunner.cpp6
-rw-r--r--src/zencompute/runners/localrunner.cpp19
-rw-r--r--src/zencompute/runners/macrunner.cpp6
-rw-r--r--src/zencompute/runners/managedrunner.cpp2
-rw-r--r--src/zencompute/runners/remotehttprunner.cpp360
-rw-r--r--src/zencompute/runners/remotehttprunner.h12
-rw-r--r--src/zencompute/runners/windowsrunner.cpp4
-rw-r--r--src/zencompute/runners/winerunner.cpp4
-rw-r--r--src/zenhorde/README.md17
-rw-r--r--src/zenhorde/hordeagent.cpp551
-rw-r--r--src/zenhorde/hordeagent.h128
-rw-r--r--src/zenhorde/hordeagentmessage.cpp502
-rw-r--r--src/zenhorde/hordeagentmessage.h123
-rw-r--r--src/zenhorde/hordebundle.cpp2
-rw-r--r--src/zenhorde/hordeclient.cpp65
-rw-r--r--src/zenhorde/hordecomputebuffer.cpp454
-rw-r--r--src/zenhorde/hordecomputebuffer.h136
-rw-r--r--src/zenhorde/hordecomputechannel.cpp37
-rw-r--r--src/zenhorde/hordecomputechannel.h32
-rw-r--r--src/zenhorde/hordecomputesocket.cpp410
-rw-r--r--src/zenhorde/hordecomputesocket.h104
-rw-r--r--src/zenhorde/hordeconfig.cpp16
-rw-r--r--src/zenhorde/hordeprovisioner.cpp664
-rw-r--r--src/zenhorde/hordetransport.cpp153
-rw-r--r--src/zenhorde/hordetransport.h67
-rw-r--r--src/zenhorde/hordetransportaes.cpp609
-rw-r--r--src/zenhorde/hordetransportaes.h50
-rw-r--r--src/zenhorde/include/zenhorde/hordeclient.h32
-rw-r--r--src/zenhorde/include/zenhorde/hordeconfig.h37
-rw-r--r--src/zenhorde/include/zenhorde/hordeprovisioner.h80
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp14
-rw-r--r--src/zenhttp/clients/httpclientcurl.h1
-rw-r--r--src/zenhttp/httpclientauth.cpp18
-rw-r--r--src/zenhttp/include/zenhttp/httpclientauth.h3
-rw-r--r--src/zennomad/include/zennomad/nomadclient.h6
-rw-r--r--src/zennomad/include/zennomad/nomadprovisioner.h9
-rw-r--r--src/zennomad/nomadclient.cpp38
-rw-r--r--src/zennomad/nomadprovisioner.cpp11
-rw-r--r--src/zenremotestore/builds/buildstorageoperations.cpp521
-rw-r--r--src/zenserver/compute/computeserver.cpp108
-rw-r--r--src/zenserver/compute/computeserver.h7
-rw-r--r--src/zenserver/config/config.cpp23
-rw-r--r--src/zenserver/frontend/html/compute/compute.html925
-rw-r--r--src/zenserver/frontend/html/compute/index.html2
-rw-r--r--src/zenserver/frontend/html/compute/orchestrator.html669
-rw-r--r--src/zenserver/frontend/html/pages/orchestrator.js210
-rw-r--r--src/zenserver/main.cpp8
63 files changed, 3985 insertions, 4159 deletions
diff --git a/src/zen/cmds/compute_cmd.cpp b/src/zen/cmds/compute_cmd.cpp
new file mode 100644
index 000000000..01166cb0e
--- /dev/null
+++ b/src/zen/cmds/compute_cmd.cpp
@@ -0,0 +1,96 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "compute_cmd.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/logging.h>
+# include <zenhttp/httpclient.h>
+
+using namespace std::literals;
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+// ComputeRecordStartSubCmd
+
+ComputeRecordStartSubCmd::ComputeRecordStartSubCmd() : ZenSubCmdBase("record-start", "Start recording compute actions")
+{
+ SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
+}
+
+void
+ComputeRecordStartSubCmd::Run(const ZenCliOptions& GlobalOptions)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName);
+ if (m_HostName.empty())
+ {
+ throw OptionParseException("Unable to resolve server specification", SubOptions().help());
+ }
+
+ HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName);
+ if (HttpClient::Response Response = Http.Post("/compute/record/start"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}))
+ {
+ CbObject Obj = Response.AsObject();
+ std::string_view Path = Obj["path"sv].AsString();
+ ZEN_CONSOLE("recording started: " ZEN_BRIGHT_GREEN("{}"), Path);
+ }
+ else
+ {
+ Response.ThrowError("Failed to start recording");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// ComputeRecordStopSubCmd
+
+ComputeRecordStopSubCmd::ComputeRecordStopSubCmd() : ZenSubCmdBase("record-stop", "Stop recording compute actions")
+{
+ SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
+}
+
+void
+ComputeRecordStopSubCmd::Run(const ZenCliOptions& GlobalOptions)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName);
+ if (m_HostName.empty())
+ {
+ throw OptionParseException("Unable to resolve server specification", SubOptions().help());
+ }
+
+ HttpClient Http = ZenCmdBase::CreateHttpClient(m_HostName);
+ if (HttpClient::Response Response = Http.Post("/compute/record/stop"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}))
+ {
+ CbObject Obj = Response.AsObject();
+ std::string_view Path = Obj["path"sv].AsString();
+ ZEN_CONSOLE("recording stopped: " ZEN_BRIGHT_GREEN("{}"), Path);
+ }
+ else
+ {
+ Response.ThrowError("Failed to stop recording");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// ComputeCommand
+
+ComputeCommand::ComputeCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), "");
+ m_Options.parse_positional({"subcommand"});
+
+ AddSubCommand(m_RecordStartSubCmd);
+ AddSubCommand(m_RecordStopSubCmd);
+}
+
+ComputeCommand::~ComputeCommand() = default;
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zen/cmds/compute_cmd.h b/src/zen/cmds/compute_cmd.h
new file mode 100644
index 000000000..b26f639c4
--- /dev/null
+++ b/src/zen/cmds/compute_cmd.h
@@ -0,0 +1,53 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+#include <string>
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen {
+
+class ComputeRecordStartSubCmd : public ZenSubCmdBase
+{
+public:
+ ComputeRecordStartSubCmd();
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ std::string m_HostName;
+};
+
+class ComputeRecordStopSubCmd : public ZenSubCmdBase
+{
+public:
+ ComputeRecordStopSubCmd();
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ std::string m_HostName;
+};
+
+class ComputeCommand : public ZenCmdWithSubCommands
+{
+public:
+ static constexpr char Name[] = "compute";
+ static constexpr char Description[] = "Compute service operations";
+
+ ComputeCommand();
+ ~ComputeCommand();
+
+ cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{Name, Description};
+ std::string m_SubCommand;
+ ComputeRecordStartSubCmd m_RecordStartSubCmd;
+ ComputeRecordStopSubCmd m_RecordStopSubCmd;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp
index 9719fce77..89bbf3638 100644
--- a/src/zen/cmds/exec_cmd.cpp
+++ b/src/zen/cmds/exec_cmd.cpp
@@ -23,6 +23,8 @@
#include <zenhttp/httpclient.h>
#include <zenhttp/packageformat.h>
+#include "../progressbar.h"
+
#include <EASTL/hash_map.h>
#include <EASTL/hash_set.h>
#include <EASTL/map.h>
@@ -124,13 +126,14 @@ struct ExecSessionConfig
std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce
std::string_view OrchestratorUrl;
const std::filesystem::path& OutputPath;
- int Offset = 0;
- int Stride = 1;
- int Limit = 0;
- bool Verbose = false;
- bool Quiet = false;
- bool DumpActions = false;
- bool Binary = false;
+ int Offset = 0;
+ int Stride = 1;
+ int Limit = 0;
+ bool Verbose = false;
+ bool Quiet = false;
+ bool DumpActions = false;
+ bool Binary = false;
+ ProgressBar::Mode ProgressMode = ProgressBar::Mode::PrettyScroll;
};
//////////////////////////////////////////////////////////////////////////
@@ -345,8 +348,6 @@ ExecSessionRunner::DrainCompletedJobs()
}
m_PendingJobs.Remove(CompleteLsn);
-
- ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize());
}
}
}
@@ -897,17 +898,20 @@ ExecSessionRunner::Run()
// Then submit work items
- int FailedWorkCounter = 0;
- size_t RemainingWorkItems = m_Config.RecordingReader.GetActionCount();
- int SubmittedWorkItems = 0;
+ std::atomic<int> FailedWorkCounter{0};
+ std::atomic<size_t> RemainingWorkItems{m_Config.RecordingReader.GetActionCount()};
+ std::atomic<int> SubmittedWorkItems{0};
+ size_t TotalWorkItems = RemainingWorkItems.load();
- ZEN_CONSOLE("submitting {} work items", RemainingWorkItems);
+ ProgressBar SubmitProgress(m_Config.ProgressMode, "Submit");
+ SubmitProgress.UpdateState({.Task = "Submitting work items", .TotalCount = TotalWorkItems, .RemainingCount = RemainingWorkItems.load()},
+ false);
int OffsetCounter = m_Config.Offset;
int StrideCounter = m_Config.Stride;
auto ShouldSchedule = [&]() -> bool {
- if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit)
+ if (m_Config.Limit && SubmittedWorkItems.load() >= m_Config.Limit)
{
// Limit reached, ignore
@@ -1005,17 +1009,14 @@ ExecSessionRunner::Run()
{
const int32_t LsnField = EnqueueResult.Lsn;
- --RemainingWorkItems;
- ++SubmittedWorkItems;
+ size_t Remaining = --RemainingWorkItems;
+ int Submitted = ++SubmittedWorkItems;
- if (!m_Config.Quiet)
- {
- ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining",
- SubmittedWorkItems,
- LsnField,
- NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
- RemainingWorkItems);
- }
+ SubmitProgress.UpdateState({.Task = "Submitting work items",
+ .Details = fmt::format("#{} LSN {}", Submitted, LsnField),
+ .TotalCount = TotalWorkItems,
+ .RemainingCount = Remaining},
+ false);
if (!m_Config.OutputPath.empty())
{
@@ -1055,22 +1056,36 @@ ExecSessionRunner::Run()
},
TargetParallelism);
+ SubmitProgress.Finish();
+
// Wait until all pending work is complete
+ size_t TotalPendingJobs = m_PendingJobs.GetSize();
+
+ ProgressBar CompletionProgress(m_Config.ProgressMode, "Execute");
+
while (!m_PendingJobs.IsEmpty())
{
- // TODO: improve this logic
- zen::Sleep(500);
+ size_t PendingCount = m_PendingJobs.GetSize();
+ CompletionProgress.UpdateState({.Task = "Executing work items",
+ .Details = fmt::format("{} completed, {} remaining", TotalPendingJobs - PendingCount, PendingCount),
+ .TotalCount = TotalPendingJobs,
+ .RemainingCount = PendingCount},
+ false);
+
+ zen::Sleep(GetUpdateDelayMS(m_Config.ProgressMode));
DrainCompletedJobs();
SendOrchestratorHeartbeat();
}
+ CompletionProgress.Finish();
+
// Write summary files
WriteSummaryFiles();
- if (FailedWorkCounter)
+ if (FailedWorkCounter.load())
{
return 1;
}
@@ -1423,6 +1438,16 @@ ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions)
int
ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl)
{
+ ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty;
+ if (m_VerboseLogging)
+ {
+ ProgressMode = ProgressBar::Mode::Plain;
+ }
+ else if (m_QuietLogging)
+ {
+ ProgressMode = ProgressBar::Mode::Quiet;
+ }
+
ExecSessionConfig Config{
.Resolver = *m_ChunkResolver,
.RecordingReader = *m_RecordingReader,
@@ -1437,6 +1462,7 @@ ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std
.Quiet = m_QuietLogging,
.DumpActions = m_DumpActions,
.Binary = m_Binary,
+ .ProgressMode = ProgressMode,
};
ExecSessionRunner Runner(ComputeSession, Config);
diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp
index 3277eb856..bbf6b4f8a 100644
--- a/src/zen/zen.cpp
+++ b/src/zen/zen.cpp
@@ -9,6 +9,7 @@
#include "cmds/bench_cmd.h"
#include "cmds/builds_cmd.h"
#include "cmds/cache_cmd.h"
+#include "cmds/compute_cmd.h"
#include "cmds/copy_cmd.h"
#include "cmds/dedup_cmd.h"
#include "cmds/exec_cmd.h"
@@ -588,7 +589,8 @@ main(int argc, char** argv)
DropCommand DropCmd;
DropProjectCommand ProjectDropCmd;
#if ZEN_WITH_COMPUTE_SERVICES
- ExecCommand ExecCmd;
+ ComputeCommand ComputeCmd;
+ ExecCommand ExecCmd;
#endif // ZEN_WITH_COMPUTE_SERVICES
ExportOplogCommand ExportOplogCmd;
FlushCommand FlushCmd;
@@ -649,6 +651,7 @@ main(int argc, char** argv)
{DownCommand::Name, &DownCmd, DownCommand::Description},
{DropCommand::Name, &DropCmd, DropCommand::Description},
#if ZEN_WITH_COMPUTE_SERVICES
+ {ComputeCommand::Name, &ComputeCmd, ComputeCommand::Description},
{ExecCommand::Name, &ExecCmd, ExecCommand::Description},
#endif
{GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description},
diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md
index 750879d5a..bb574edc2 100644
--- a/src/zencompute/CLAUDE.md
+++ b/src/zencompute/CLAUDE.md
@@ -218,6 +218,13 @@ Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `Han
**Locking discipline:** The three action maps (`m_PendingActions`, `m_RunningMap`, `m_ResultsMap`) are guarded by a single `m_ActionMapLock`. This eliminates lock-ordering concerns between maps and prevents actions from being temporarily absent from all maps during state transitions. Runner-level `m_RunningLock` in `LocalProcessRunner` / `RemoteHttpRunner` is a separate lock on a different class — unrelated to the session-level action map lock.
+**Lock ordering:** When acquiring multiple session-level locks, always acquire in this order to avoid deadlocks:
+1. `m_ActionMapLock` (session action maps)
+2. `QueueEntry::m_Lock` (per-queue state)
+3. `m_ActionHistoryLock` (action history ring)
+
+Never acquire an earlier lock while holding a later one (e.g. never acquire `m_ActionMapLock` while holding `QueueEntry::m_Lock`).
+
**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`.
**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates.
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
index aaf34cbe2..852e93fa0 100644
--- a/src/zencompute/computeservice.cpp
+++ b/src/zencompute/computeservice.cpp
@@ -121,6 +121,8 @@ struct ComputeServiceSession::Impl
, m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
, m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst))
{
+ m_RemoteRunnerGroup.SetWorkerPool(&m_RemoteSubmitPool);
+
// Create a non-expiring, non-deletable implicit queue for legacy endpoints
auto Result = CreateQueue("implicit"sv, {}, {});
m_ImplicitQueueId = Result.QueueId;
@@ -240,8 +242,9 @@ struct ComputeServiceSession::Impl
// Recording
- void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
- void StopRecording();
+ bool StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
+ bool StopRecording();
+ bool IsRecording() const;
std::unique_ptr<ActionRecorder> m_Recorder;
@@ -615,6 +618,7 @@ ComputeServiceSession::Impl::UpdateCoordinatorState()
m_KnownWorkerUris.insert(UriStr);
auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool);
+ NewRunner->SetRemoteHostname(Hostname);
SyncWorkersToRunner(*NewRunner);
m_RemoteRunnerGroup.AddRunner(NewRunner);
}
@@ -716,24 +720,44 @@ ComputeServiceSession::Impl::ShutdownRunners()
m_RemoteRunnerGroup.Shutdown();
}
-void
+bool
ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath)
{
+ if (m_Recorder)
+ {
+ ZEN_WARN("recording is already active");
+ return false;
+ }
+
ZEN_INFO("starting recording to '{}'", RecordingPath);
m_Recorder = std::make_unique<ActionRecorder>(InCidStore, RecordingPath);
ZEN_INFO("started recording to '{}'", RecordingPath);
+ return true;
}
-void
+bool
ComputeServiceSession::Impl::StopRecording()
{
+ if (!m_Recorder)
+ {
+ ZEN_WARN("no recording is active");
+ return false;
+ }
+
ZEN_INFO("stopping recording");
m_Recorder = nullptr;
ZEN_INFO("stopped recording");
+ return true;
+}
+
+bool
+ComputeServiceSession::Impl::IsRecording() const
+{
+ return m_Recorder != nullptr;
}
std::vector<ComputeServiceSession::RunningActionInfo>
@@ -1128,6 +1152,10 @@ ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo)
Cbo.BeginObject();
Cbo << "lsn"sv << Lsn;
Cbo << "state"sv << RunnerAction::ToString(Action->ActionState());
+ if (!Action->FailureReason.empty())
+ {
+ Cbo << "reason"sv << Action->FailureReason;
+ }
Cbo.EndObject();
}
});
@@ -1416,8 +1444,8 @@ ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo)
if (Queue)
{
- Queue->m_Lock.WithSharedLock([&] {
- m_ActionMapLock.WithSharedLock([&] {
+ m_ActionMapLock.WithSharedLock([&] {
+ Queue->m_Lock.WithSharedLock([&] {
for (int Lsn : Queue->FinishedLsns)
{
if (m_ResultsMap.contains(Lsn))
@@ -1530,12 +1558,12 @@ ComputeServiceSession::Impl::SchedulePendingActions()
static Stopwatch DumpRunningTimer;
auto _ = MakeGuard([&] {
- ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
- ScheduledCount,
- RunningCount,
- m_RetiredCount.load(),
- PendingCount,
- ResultCount);
+ ZEN_DEBUG("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results",
+ ScheduledCount,
+ RunningCount,
+ m_RetiredCount.load(),
+ PendingCount,
+ ResultCount);
if (DumpRunningTimer.GetElapsedTimeMs() > 30000)
{
@@ -1584,13 +1612,13 @@ ComputeServiceSession::Impl::SchedulePendingActions()
// Also note that the m_PendingActions list is not maintained
// here, that's done periodically in SchedulePendingActions()
- m_ActionMapLock.WithExclusiveLock([&] {
- if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
- {
- return;
- }
+ // Extract pending actions under a shared lock — we only need to read
+ // the map and take Ref copies. ActionState() is atomic so this is safe.
+ // Sorting and capacity trimming happen outside the lock to avoid
+ // blocking HTTP handlers on O(N log N) work with large pending queues.
- if (m_PendingActions.empty())
+ m_ActionMapLock.WithSharedLock([&] {
+ if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused)
{
return;
}
@@ -1610,6 +1638,7 @@ ComputeServiceSession::Impl::SchedulePendingActions()
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
case RunnerAction::State::Abandoned:
+ case RunnerAction::State::Rejected:
case RunnerAction::State::Cancelled:
break;
@@ -1620,30 +1649,30 @@ ComputeServiceSession::Impl::SchedulePendingActions()
}
}
- // Sort by priority descending, then by LSN ascending (FIFO within same priority)
- std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
- if (A->Priority != B->Priority)
- {
- return A->Priority > B->Priority;
- }
- return A->ActionLsn < B->ActionLsn;
- });
+ PendingCount = m_PendingActions.size();
+ });
- if (ActionsToSchedule.size() > Capacity)
+ // Sort by priority descending, then by LSN ascending (FIFO within same priority)
+ std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref<RunnerAction>& A, const Ref<RunnerAction>& B) {
+ if (A->Priority != B->Priority)
{
- ActionsToSchedule.resize(Capacity);
+ return A->Priority > B->Priority;
}
-
- PendingCount = m_PendingActions.size();
+ return A->ActionLsn < B->ActionLsn;
});
+ if (ActionsToSchedule.size() > Capacity)
+ {
+ ActionsToSchedule.resize(Capacity);
+ }
+
if (ActionsToSchedule.empty())
{
_.Dismiss();
return;
}
- ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size());
+ ZEN_DEBUG("attempting schedule of {} pending actions", ActionsToSchedule.size());
Stopwatch SubmitTimer;
std::vector<SubmitResult> SubmitResults = SubmitActions(ActionsToSchedule);
@@ -1663,10 +1692,10 @@ ComputeServiceSession::Impl::SchedulePendingActions()
}
}
- ZEN_INFO("scheduled {} pending actions in {} ({} rejected)",
- ScheduledActionCount,
- NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
- NotAcceptedCount);
+ ZEN_DEBUG("scheduled {} pending actions in {} ({} rejected)",
+ ScheduledActionCount,
+ NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()),
+ NotAcceptedCount);
ScheduledCount += ScheduledActionCount;
PendingCount -= ScheduledActionCount;
@@ -1975,6 +2004,14 @@ ComputeServiceSession::Impl::HandleActionUpdates()
break;
}
+ // Rejected — runner was at capacity, reschedule without retry cost
+ case RunnerAction::State::Rejected:
+ {
+ Action->ResetActionStateToPending();
+ ZEN_DEBUG("action {} ({}) rescheduled after runner rejection", Action->ActionId, ActionLsn);
+ break;
+ }
+
// Terminal states — move to results, record history, notify queue
case RunnerAction::State::Completed:
case RunnerAction::State::Failed:
@@ -2009,6 +2046,14 @@ ComputeServiceSession::Impl::HandleActionUpdates()
MaxRetries);
break;
}
+ else
+ {
+ ZEN_WARN("action {} ({}) {} after {} retries, not rescheduling",
+ Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(TerminalState),
+ Action->RetryCount.load(std::memory_order_relaxed));
+ }
}
m_ActionMapLock.WithExclusiveLock([&] {
@@ -2101,10 +2146,9 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>&
ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions");
std::vector<SubmitResult> Results(Actions.size());
- // First try submitting the batch to local runners in parallel
+ // First try submitting the batch to local runners
std::vector<SubmitResult> LocalResults = m_LocalRunnerGroup.SubmitActions(Actions);
- std::vector<size_t> RemoteIndices;
std::vector<Ref<RunnerAction>> RemoteActions;
for (size_t i = 0; i < Actions.size(); ++i)
@@ -2115,20 +2159,40 @@ ComputeServiceSession::Impl::SubmitActions(const std::vector<Ref<RunnerAction>>&
}
else
{
- RemoteIndices.push_back(i);
RemoteActions.push_back(Actions[i]);
+ Results[i] = SubmitResult{.IsAccepted = true, .Reason = "dispatched to remote"};
}
}
- // Submit remaining actions to remote runners in parallel
+ // Dispatch remaining actions to remote runners asynchronously.
+ // Mark actions as Submitting so the scheduler won't re-pick them.
+ // The remote runner will transition them to Running on success, or
+	// we mark them Rejected on rejection so HandleActionUpdates reschedules them.
if (!RemoteActions.empty())
{
- std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
-
- for (size_t j = 0; j < RemoteIndices.size(); ++j)
+ for (const Ref<RunnerAction>& Action : RemoteActions)
{
- Results[RemoteIndices[j]] = std::move(RemoteResults[j]);
+ Action->SetActionState(RunnerAction::State::Submitting);
}
+
+ m_RemoteSubmitPool.ScheduleWork(
+ [this, RemoteActions = std::move(RemoteActions)]() {
+ std::vector<SubmitResult> RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions);
+
+ for (size_t j = 0; j < RemoteResults.size(); ++j)
+ {
+ if (!RemoteResults[j].IsAccepted)
+ {
+ ZEN_DEBUG("remote submission rejected for action {} ({}): {}",
+ RemoteActions[j]->ActionId,
+ RemoteActions[j]->ActionLsn,
+ RemoteResults[j].Reason);
+
+ RemoteActions[j]->SetActionState(RunnerAction::State::Rejected);
+ }
+ }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
}
return Results;
@@ -2194,16 +2258,22 @@ ComputeServiceSession::NotifyOrchestratorChanged()
m_Impl->NotifyOrchestratorChanged();
}
-void
+bool
ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath)
{
- m_Impl->StartRecording(InResolver, RecordingPath);
+ return m_Impl->StartRecording(InResolver, RecordingPath);
}
-void
+bool
ComputeServiceSession::StopRecording()
{
- m_Impl->StopRecording();
+ return m_Impl->StopRecording();
+}
+
+bool
+ComputeServiceSession::IsRecording() const
+{
+ return m_Impl->IsRecording();
}
ComputeServiceSession::ActionCounts
diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp
index 6cb975dd3..8cbb25afd 100644
--- a/src/zencompute/httpcomputeservice.cpp
+++ b/src/zencompute/httpcomputeservice.cpp
@@ -62,6 +62,8 @@ struct HttpComputeService::Impl
RwLock m_WsConnectionsLock;
std::vector<Ref<WebSocketConnection>> m_WsConnections;
+ std::function<void()> m_ShutdownCallback;
+
// Metrics
metrics::OperationTiming m_HttpRequests;
@@ -190,6 +192,65 @@ HttpComputeService::Impl::RegisterRoutes()
HttpVerb::kPost);
m_Router.RegisterRoute(
+ "session/drain",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Draining))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Draining from current state"sv;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "session/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ auto Counts = m_ComputeService.GetActionCounts();
+ Cbo << "actions_pending"sv << Counts.Pending;
+ Cbo << "actions_running"sv << Counts.Running;
+ Cbo << "actions_completed"sv << Counts.Completed;
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "session/sunset",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ if (m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Sunset))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "state"sv << ToString(m_ComputeService.GetSessionState());
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+
+ if (m_ShutdownCallback)
+ {
+ m_ShutdownCallback();
+ }
+ return;
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "error"sv
+ << "Cannot transition to Sunset from current state"sv;
+ HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
"workers",
[this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); },
HttpVerb::kGet);
@@ -506,9 +567,19 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- m_ComputeService.StartRecording(m_CombinedResolver, m_BaseDir / "recording");
+ std::filesystem::path RecordingPath = m_BaseDir / "recording";
- return HttpReq.WriteResponse(HttpResponseCode::OK);
+ if (!m_ComputeService.StartRecording(m_CombinedResolver, RecordingPath))
+ {
+ CbObjectWriter Cbo;
+ Cbo << "error"
+ << "recording is already active";
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
+
+ CbObjectWriter Cbo;
+ Cbo << "path" << RecordingPath.string();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kPost);
@@ -522,9 +593,19 @@ HttpComputeService::Impl::RegisterRoutes()
return HttpReq.WriteResponse(HttpResponseCode::Forbidden);
}
- m_ComputeService.StopRecording();
+ std::filesystem::path RecordingPath = m_BaseDir / "recording";
+
+ if (!m_ComputeService.StopRecording())
+ {
+ CbObjectWriter Cbo;
+ Cbo << "error"
+ << "no recording is active";
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save());
+ }
- return HttpReq.WriteResponse(HttpResponseCode::OK);
+ CbObjectWriter Cbo;
+ Cbo << "path" << RecordingPath.string();
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kPost);
@@ -1066,6 +1147,12 @@ HttpComputeService::GetActionCounts()
return m_Impl->m_ComputeService.GetActionCounts();
}
+void
+HttpComputeService::SetShutdownCallback(std::function<void()> Callback)
+{
+ m_Impl->m_ShutdownCallback = std::move(Callback);
+}
+
const char*
HttpComputeService::BaseUri() const
{
diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp
index d92af8716..1f51e560e 100644
--- a/src/zencompute/httporchestrator.cpp
+++ b/src/zencompute/httporchestrator.cpp
@@ -7,6 +7,7 @@
# include <zencompute/orchestratorservice.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/logging.h>
+# include <zencore/session.h>
# include <zencore/string.h>
# include <zencore/system.h>
@@ -77,10 +78,47 @@ ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnn
return Ann.Id;
}
+static OrchestratorService::WorkerAnnotator
+MakeWorkerAnnotator(IProvisionerStateProvider* Prov)
+{
+ if (!Prov)
+ {
+ return {};
+ }
+ return [Prov](std::string_view WorkerId, CbObjectWriter& Cbo) {
+ AgentProvisioningStatus Status = Prov->GetAgentStatus(WorkerId);
+ if (Status != AgentProvisioningStatus::Unknown)
+ {
+ const char* StatusStr = (Status == AgentProvisioningStatus::Draining) ? "draining" : "active";
+ Cbo << "provisioner_status" << std::string_view(StatusStr);
+ }
+ };
+}
+
+bool
+HttpOrchestratorService::ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId)
+{
+ std::string_view SessionStr = Data["coordinator_session"].AsString("");
+ if (SessionStr.empty())
+ {
+ return true; // backwards compatibility: accept announcements without a session
+ }
+ Oid Session = Oid::TryFromHexString(SessionStr);
+ if (Session == m_SessionId)
+ {
+ return true;
+ }
+ ZEN_WARN("rejecting stale announcement from '{}' (session {} != {})", WorkerId, SessionStr, m_SessionId.ToString());
+ return false;
+}
+
HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket)
: m_Service(std::make_unique<OrchestratorService>(std::move(DataDir), EnableWorkerWebSocket))
, m_Hostname(GetMachineName())
{
+ m_SessionId = zen::GetSessionId();
+ ZEN_INFO("orchestrator session id: {}", m_SessionId.ToString());
+
m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); });
@@ -95,13 +133,17 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
[this](HttpRouterRequest& Req) {
CbObjectWriter Cbo;
Cbo << "hostname" << std::string_view(m_Hostname);
+ Cbo << "session_id" << m_SessionId.ToString();
Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
},
HttpVerb::kGet);
m_Router.RegisterRoute(
"provision",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kPost);
m_Router.RegisterRoute(
@@ -122,6 +164,11 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
"characters and uri must start with http:// or https://");
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::Conflict, HttpContentType::kText, "Stale coordinator session");
+ }
+
m_Service->AnnounceWorker(Ann);
HttpReq.WriteResponse(HttpResponseCode::OK);
@@ -135,7 +182,10 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
m_Router.RegisterRoute(
"agents",
- [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); },
+ [this](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK,
+ m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire))));
+ },
HttpVerb::kGet);
m_Router.RegisterRoute(
@@ -241,6 +291,59 @@ HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir,
},
HttpVerb::kGet);
+ // Provisioner endpoints
+
+ m_Router.RegisterRoute(
+ "provisioner/status",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObjectWriter Cbo;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ Cbo << "name" << Prov->GetName();
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ Cbo << "estimated_cores" << Prov->GetEstimatedCoreCount();
+ Cbo << "active_cores" << Prov->GetActiveCoreCount();
+ Cbo << "agents" << Prov->GetAgentCount();
+ Cbo << "agents_draining" << Prov->GetDrainingAgentCount();
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "provisioner/target",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ CbObject Data = HttpReq.ReadPayloadObject();
+ int32_t Cores = Data["target_cores"].AsInt32(-1);
+
+ ZEN_INFO("provisioner/target: received request (target_cores={}, payload_valid={})", Cores, Data ? true : false);
+
+ if (Cores < 0)
+ {
+ ZEN_WARN("provisioner/target: bad request (target_cores={})", Cores);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing or invalid target_cores field");
+ }
+
+ IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire);
+ if (!Prov)
+ {
+ ZEN_WARN("provisioner/target: no provisioner configured");
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "No provisioner configured");
+ }
+
+ ZEN_INFO("provisioner/target: setting target to {} cores", Cores);
+ Prov->SetTargetCoreCount(static_cast<uint32_t>(Cores));
+
+ CbObjectWriter Cbo;
+ Cbo << "target_cores" << Prov->GetTargetCoreCount();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kPost);
+
// Client tracking endpoints
m_Router.RegisterRoute(
@@ -411,6 +514,13 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request)
}
}
+void
+HttpOrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release);
+ m_Service->SetProvisionerStateProvider(Provider);
+}
+
//////////////////////////////////////////////////////////////////////////
//
// IWebSocketHandler
@@ -488,6 +598,11 @@ HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Ms
return {};
}
+ if (!ValidateCoordinatorSession(Data, WorkerId))
+ {
+ return {};
+ }
+
m_Service->AnnounceWorker(Ann);
return std::string(WorkerId);
}
@@ -563,7 +678,7 @@ HttpOrchestratorService::PushThreadFunction()
}
// Build combined JSON with worker list, provisioning history, clients, and client history
- CbObject WorkerList = m_Service->GetWorkerList();
+ CbObject WorkerList = m_Service->GetWorkerList(MakeWorkerAnnotator(m_Provisioner.load(std::memory_order_acquire)));
CbObject History = m_Service->GetProvisioningHistory(50);
CbObject ClientList = m_Service->GetClientList();
CbObject ClientHistory = m_Service->GetClientHistory(50);
@@ -615,6 +730,20 @@ HttpOrchestratorService::PushThreadFunction()
JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2));
}
+ // Emit provisioner stats if available
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ JsonBuilder.Append(
+ fmt::format(",\"provisioner\":{{\"name\":\"{}\",\"target_cores\":{},\"estimated_cores\":{}"
+ ",\"active_cores\":{},\"agents\":{},\"agents_draining\":{}}}",
+ Prov->GetName(),
+ Prov->GetTargetCoreCount(),
+ Prov->GetEstimatedCoreCount(),
+ Prov->GetActiveCoreCount(),
+ Prov->GetAgentCount(),
+ Prov->GetDrainingAgentCount()));
+ }
+
JsonBuilder.Append("}");
std::string_view Json = JsonBuilder.ToView();
diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h
index ad556f546..97de4321a 100644
--- a/src/zencompute/include/zencompute/computeservice.h
+++ b/src/zencompute/include/zencompute/computeservice.h
@@ -279,7 +279,7 @@ public:
// sized to match RunnerAction::State::_Count but we can't use the enum here
// for dependency reasons, so just use a fixed size array and static assert in
// the implementation file
- uint64_t Timestamps[9] = {};
+ uint64_t Timestamps[10] = {};
};
[[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100);
@@ -305,8 +305,9 @@ public:
// Recording
- void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
- void StopRecording();
+ bool StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath);
+ bool StopRecording();
+ bool IsRecording() const;
private:
void PostUpdate(RunnerAction* Action);
diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h
index db3fce3c2..32f54f293 100644
--- a/src/zencompute/include/zencompute/httpcomputeservice.h
+++ b/src/zencompute/include/zencompute/httpcomputeservice.h
@@ -35,6 +35,10 @@ public:
void Shutdown();
+ /** Set a callback to be invoked when the session/sunset endpoint is hit.
+ * Typically wired to HttpServer::RequestExit() to shut down the process. */
+ void SetShutdownCallback(std::function<void()> Callback);
+
[[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts();
const char* BaseUri() const override;
diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h
index 58b2c9152..ef0a1269a 100644
--- a/src/zencompute/include/zencompute/httporchestrator.h
+++ b/src/zencompute/include/zencompute/httporchestrator.h
@@ -2,10 +2,12 @@
#pragma once
+#include <zencompute/provisionerstate.h>
#include <zencompute/zencompute.h>
#include <zencore/logging.h>
#include <zencore/thread.h>
+#include <zencore/uid.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/websocket.h>
@@ -65,6 +67,16 @@ public:
*/
void Shutdown();
+ /** Return the session ID generated at construction time. Provisioners
+ * pass this to spawned workers so the orchestrator can reject stale
+ * announcements from previous sessions. */
+ Oid GetSessionId() const { return m_SessionId; }
+
+ /** Register a provisioner whose target core count can be read and changed
+ * via the orchestrator HTTP API and dashboard. Caller retains ownership;
+ * the provider must outlive this service. */
+ void SetProvisionerStateProvider(IProvisionerStateProvider* Provider);
+
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
@@ -81,6 +93,11 @@ private:
std::unique_ptr<OrchestratorService> m_Service;
std::string m_Hostname;
+ Oid m_SessionId;
+ bool ValidateCoordinatorSession(const CbObjectView& Data, std::string_view WorkerId);
+
+ std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr};
+
// WebSocket push
#if ZEN_WITH_WEBSOCKETS
diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h
index 549ff8e3c..2c49e22df 100644
--- a/src/zencompute/include/zencompute/orchestratorservice.h
+++ b/src/zencompute/include/zencompute/orchestratorservice.h
@@ -6,6 +6,7 @@
#if ZEN_WITH_COMPUTE_SERVICES
+# include <zencompute/provisionerstate.h>
# include <zencore/compactbinary.h>
# include <zencore/compactbinarybuilder.h>
# include <zencore/logbase.h>
@@ -90,9 +91,16 @@ public:
std::string Hostname;
};
- CbObject GetWorkerList();
+ /** Per-worker callback invoked during GetWorkerList serialization.
+ * The callback receives the worker ID and a CbObjectWriter positioned
+ * inside the worker's object, allowing the caller to append extra fields. */
+ using WorkerAnnotator = std::function<void(std::string_view WorkerId, CbObjectWriter& Cbo)>;
+
+ CbObject GetWorkerList(const WorkerAnnotator& Annotate = {});
void AnnounceWorker(const WorkerAnnouncement& Announcement);
+ void SetProvisionerStateProvider(IProvisionerStateProvider* Provider);
+
bool IsWorkerWebSocketEnabled() const;
void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected);
@@ -171,6 +179,8 @@ private:
LoggerRef m_Log{"compute.orchestrator"};
bool m_EnableWorkerWebSocket = false;
+ std::atomic<IProvisionerStateProvider*> m_Provisioner{nullptr};
+
std::thread m_ProbeThread;
std::atomic<bool> m_ProbeThreadEnabled{true};
Event m_ProbeThreadEvent;
diff --git a/src/zencompute/include/zencompute/provisionerstate.h b/src/zencompute/include/zencompute/provisionerstate.h
new file mode 100644
index 000000000..e9af8a635
--- /dev/null
+++ b/src/zencompute/include/zencompute/provisionerstate.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <cstdint>
+#include <string_view>
+
+namespace zen::compute {
+
+/** Per-agent provisioning status as seen by the provisioner. */
+enum class AgentProvisioningStatus
+{
+ Unknown, ///< Not known to the provisioner
+ Active, ///< Running and allocated
+ Draining, ///< Being gracefully deprovisioned
+};
+
+/** Abstract interface for querying and controlling a provisioner from the HTTP layer.
+ * This decouples the orchestrator service from specific provisioner implementations. */
+class IProvisionerStateProvider
+{
+public:
+ virtual ~IProvisionerStateProvider() = default;
+
+ virtual std::string_view GetName() const = 0; ///< e.g. "horde", "nomad"
+ virtual uint32_t GetTargetCoreCount() const = 0;
+ virtual uint32_t GetEstimatedCoreCount() const = 0;
+ virtual uint32_t GetActiveCoreCount() const = 0;
+ virtual uint32_t GetAgentCount() const = 0;
+ virtual uint32_t GetDrainingAgentCount() const { return 0; }
+ virtual void SetTargetCoreCount(uint32_t Count) = 0;
+
+ /** Return the provisioning status for a worker by its orchestrator ID
+ * (e.g. "horde-{LeaseId}"). Returns Unknown if the ID is not recognized. */
+ virtual AgentProvisioningStatus GetAgentStatus(std::string_view /*WorkerId*/) const { return AgentProvisioningStatus::Unknown; }
+};
+
+} // namespace zen::compute
diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp
index 9ea695305..aee8fa63a 100644
--- a/src/zencompute/orchestratorservice.cpp
+++ b/src/zencompute/orchestratorservice.cpp
@@ -31,7 +31,7 @@ OrchestratorService::~OrchestratorService()
}
CbObject
-OrchestratorService::GetWorkerList()
+OrchestratorService::GetWorkerList(const WorkerAnnotator& Annotate)
{
ZEN_TRACE_CPU("OrchestratorService::GetWorkerList");
CbObjectWriter Cbo;
@@ -71,6 +71,10 @@ OrchestratorService::GetWorkerList()
Cbo << "ws_connected" << true;
}
Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs();
+ if (Annotate)
+ {
+ Annotate(WorkerId, Cbo);
+ }
Cbo.EndObject();
}
});
@@ -144,6 +148,12 @@ OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann)
}
}
+void
+OrchestratorService::SetProvisionerStateProvider(IProvisionerStateProvider* Provider)
+{
+ m_Provisioner.store(Provider, std::memory_order_release);
+}
+
bool
OrchestratorService::IsWorkerWebSocketEnabled() const
{
@@ -607,6 +617,14 @@ OrchestratorService::ProbeThreadFunction()
continue;
}
+ // Check if the provisioner knows this worker is draining — if so,
+ // unreachability is expected and should not be logged as a warning.
+ bool IsDraining = false;
+ if (IProvisionerStateProvider* Prov = m_Provisioner.load(std::memory_order_acquire))
+ {
+ IsDraining = Prov->GetAgentStatus(Snap.Id) == AgentProvisioningStatus::Draining;
+ }
+
ReachableState NewState = ReachableState::Unreachable;
try
@@ -621,7 +639,10 @@ OrchestratorService::ProbeThreadFunction()
}
catch (const std::exception& Ex)
{
- ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+ if (!IsDraining)
+ {
+ ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what());
+ }
}
ReachableState PrevState = ReachableState::Unknown;
@@ -646,6 +667,10 @@ OrchestratorService::ProbeThreadFunction()
{
ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri);
}
+ else if (IsDraining)
+ {
+ ZEN_INFO("worker {} ({}) shut down (draining)", Snap.Id, Snap.Uri);
+ }
else if (PrevState == ReachableState::Reachable)
{
ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri);
diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp
index 67e12b84e..ab22c6363 100644
--- a/src/zencompute/runners/functionrunner.cpp
+++ b/src/zencompute/runners/functionrunner.cpp
@@ -6,9 +6,15 @@
# include <zencore/compactbinary.h>
# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/logging.h>
+# include <zencore/string.h>
+# include <zencore/timer.h>
# include <zencore/trace.h>
+# include <zencore/workthreadpool.h>
# include <fmt/format.h>
+# include <future>
# include <vector>
namespace zen::compute {
@@ -118,23 +124,34 @@ std::vector<SubmitResult>
BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions");
- RwLock::SharedLockScope _(m_RunnersLock);
- const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+ // Snapshot runners and query capacity under the lock, then release
+ // before submitting — HTTP submissions to remote runners can take
+ // hundreds of milliseconds and we must not hold m_RunnersLock during I/O.
- if (RunnerCount == 0)
- {
- return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
- }
+ std::vector<Ref<FunctionRunner>> Runners;
+ std::vector<size_t> Capacities;
+ std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions;
+ size_t TotalCapacity = 0;
- // Query capacity per runner and compute total
- std::vector<size_t> Capacities(RunnerCount);
- size_t TotalCapacity = 0;
+ m_RunnersLock.WithSharedLock([&] {
+ const int RunnerCount = gsl::narrow<int>(m_Runners.size());
+ Runners.assign(m_Runners.begin(), m_Runners.end());
+ Capacities.resize(RunnerCount);
+ PerRunnerActions.resize(RunnerCount);
- for (int i = 0; i < RunnerCount; ++i)
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ Capacities[i] = Runners[i]->QueryCapacity();
+ TotalCapacity += Capacities[i];
+ }
+ });
+
+ const int RunnerCount = gsl::narrow<int>(Runners.size());
+
+ if (RunnerCount == 0)
{
- Capacities[i] = m_Runners[i]->QueryCapacity();
- TotalCapacity += Capacities[i];
+ return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"});
}
if (TotalCapacity == 0)
@@ -143,9 +160,8 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
// Distribute actions across runners proportionally to their available capacity
- std::vector<std::vector<Ref<RunnerAction>>> PerRunnerActions(RunnerCount);
- std::vector<size_t> ActionRunnerIndex(Actions.size());
- size_t ActionIdx = 0;
+ std::vector<size_t> ActionRunnerIndex(Actions.size());
+ size_t ActionIdx = 0;
for (int i = 0; i < RunnerCount; ++i)
{
@@ -176,14 +192,74 @@ BaseRunnerGroup::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
}
}
- // Submit batches per runner
+ // Submit batches per runner — in parallel when a worker pool is available
+
std::vector<std::vector<SubmitResult>> PerRunnerResults(RunnerCount);
+ int ActiveRunnerCount = 0;
for (int i = 0; i < RunnerCount; ++i)
{
if (!PerRunnerActions[i].empty())
{
- PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]);
+ ++ActiveRunnerCount;
+ }
+ }
+
+ static constexpr uint64_t SubmitWarnThresholdMs = 500;
+
+ auto SubmitToRunner = [&](int RunnerIndex) {
+ auto& Runner = Runners[RunnerIndex];
+ Runner->m_LastSubmitStats.Reset();
+
+ Stopwatch Timer;
+
+ PerRunnerResults[RunnerIndex] = Runner->SubmitActions(PerRunnerActions[RunnerIndex]);
+
+ uint64_t ElapsedMs = Timer.GetElapsedTimeMs();
+ if (ElapsedMs >= SubmitWarnThresholdMs)
+ {
+ size_t Attachments = Runner->m_LastSubmitStats.TotalAttachments.load(std::memory_order_relaxed);
+ uint64_t AttachmentBytes = Runner->m_LastSubmitStats.TotalAttachmentBytes.load(std::memory_order_relaxed);
+
+ ZEN_WARN("submit of {} actions ({} attachments, {}) to '{}' took {}ms",
+ PerRunnerActions[RunnerIndex].size(),
+ Attachments,
+ NiceBytes(AttachmentBytes),
+ Runner->GetDisplayName(),
+ ElapsedMs);
+ }
+ };
+
+ if (m_WorkerPool && ActiveRunnerCount > 1)
+ {
+ std::vector<std::future<void>> Futures(RunnerCount);
+
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (!PerRunnerActions[i].empty())
+ {
+ std::packaged_task<void()> Task([&SubmitToRunner, i]() { SubmitToRunner(i); });
+
+ Futures[i] = m_WorkerPool->EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog);
+ }
+ }
+
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (Futures[i].valid())
+ {
+ Futures[i].get();
+ }
+ }
+ }
+ else
+ {
+ for (int i = 0; i < RunnerCount; ++i)
+ {
+ if (!PerRunnerActions[i].empty())
+ {
+ SubmitToRunner(i);
+ }
}
}
@@ -309,10 +385,11 @@ RunnerAction::RetractAction()
bool
RunnerAction::ResetActionStateToPending()
{
- // Only allow reset from Failed, Abandoned, or Retracted states
+ // Only allow reset from Failed, Abandoned, Rejected, or Retracted states
State CurrentState = m_ActionState.load();
- if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted)
+ if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Rejected &&
+ CurrentState != State::Retracted)
{
return false;
}
@@ -333,11 +410,12 @@ RunnerAction::ResetActionStateToPending()
// Clear execution fields
ExecutionLocation.clear();
+ FailureReason.clear();
CpuUsagePercent.store(-1.0f, std::memory_order_relaxed);
CpuSeconds.store(0.0f, std::memory_order_relaxed);
- // Increment retry count (skip for Retracted — nothing failed)
- if (CurrentState != State::Retracted)
+ // Increment retry count (skip for Retracted/Rejected — nothing failed)
+ if (CurrentState != State::Retracted && CurrentState != State::Rejected)
{
RetryCount.fetch_add(1, std::memory_order_relaxed);
}
diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h
index 56c3f3af0..449f0e228 100644
--- a/src/zencompute/runners/functionrunner.h
+++ b/src/zencompute/runners/functionrunner.h
@@ -10,6 +10,10 @@
# include <filesystem>
# include <vector>
+namespace zen {
+class WorkerThreadPool;
+}
+
namespace zen::compute {
struct SubmitResult
@@ -37,6 +41,22 @@ public:
[[nodiscard]] virtual bool IsHealthy() = 0;
[[nodiscard]] virtual size_t QueryCapacity();
[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions);
+ [[nodiscard]] virtual std::string_view GetDisplayName() const { return "local"; }
+
+ // Accumulated stats from the most recent SubmitActions call.
+ // Reset before each call, populated by the runner implementation.
+ struct SubmitStats
+ {
+ std::atomic<size_t> TotalAttachments{0};
+ std::atomic<uint64_t> TotalAttachmentBytes{0};
+
+ void Reset()
+ {
+ TotalAttachments.store(0, std::memory_order_relaxed);
+ TotalAttachmentBytes.store(0, std::memory_order_relaxed);
+ }
+ };
+ SubmitStats m_LastSubmitStats;
// Best-effort cancellation of a specific in-flight action. Returns true if the
// cancellation signal was successfully sent. The action will transition to Cancelled
@@ -68,6 +88,8 @@ public:
bool CancelAction(int ActionLsn);
void CancelRemoteQueue(int QueueId);
+ void SetWorkerPool(WorkerThreadPool* Pool) { m_WorkerPool = Pool; }
+
size_t GetRunnerCount()
{
return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); });
@@ -79,6 +101,7 @@ protected:
RwLock m_RunnersLock;
std::vector<Ref<FunctionRunner>> m_Runners;
std::atomic<int> m_NextSubmitIndex{0};
+ WorkerThreadPool* m_WorkerPool = nullptr;
};
/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal.
@@ -151,6 +174,7 @@ struct RunnerAction : public RefCounted
CbObject ActionObj;
int Priority = 0;
std::string ExecutionLocation; // "local" or remote hostname
+ std::string FailureReason; // human-readable reason when action fails (empty on success)
// CPU usage and total CPU time of the running process, sampled periodically by the local runner.
// CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage.
@@ -168,6 +192,7 @@ struct RunnerAction : public RefCounted
Completed, // Finished successfully with results available
Failed, // Execution failed (transient error, eligible for retry)
Abandoned, // Infrastructure termination (e.g. spot eviction, session abandon)
+ Rejected, // Runner declined (e.g. at capacity) — rescheduled without retry cost
Cancelled, // Intentional user cancellation (never retried)
Retracted, // Pulled back for rescheduling on a different runner (no retry cost)
_Count
@@ -194,6 +219,8 @@ struct RunnerAction : public RefCounted
return "Failed";
case State::Abandoned:
return "Abandoned";
+ case State::Rejected:
+ return "Rejected";
case State::Cancelled:
return "Cancelled";
case State::Retracted:
diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp
index 9055005d9..ce5bbdcc8 100644
--- a/src/zencompute/runners/linuxrunner.cpp
+++ b/src/zencompute/runners/linuxrunner.cpp
@@ -430,7 +430,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process — lower priority so workers don't starve the main server
+ nice(5);
if (m_Sandboxed)
{
@@ -481,7 +482,8 @@ LinuxProcessRunner::SubmitAction(Ref<RunnerAction> Action)
// Clean up the sandbox in the background
m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
- ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+ Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf);
+ ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason);
Action->SetActionState(RunnerAction::State::Failed);
return SubmitResult{.IsAccepted = false};
diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp
index 1b748c0e5..96cbdc134 100644
--- a/src/zencompute/runners/localrunner.cpp
+++ b/src/zencompute/runners/localrunner.cpp
@@ -357,14 +357,21 @@ LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker)
std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId);
- if (!std::filesystem::exists(WorkerDir))
+ // worker.zcb is written as the last step of ManifestWorker, so its presence
+ // indicates a complete manifest. If the directory exists but the marker is
+ // missing, a previous manifest was interrupted and we need to start over.
+ bool NeedsManifest = !std::filesystem::exists(WorkerDir / "worker.zcb");
+
+ if (NeedsManifest)
{
_.ReleaseNow();
RwLock::ExclusiveLockScope $(m_WorkerLock);
- if (!std::filesystem::exists(WorkerDir))
+ if (!std::filesystem::exists(WorkerDir / "worker.zcb"))
{
+ std::error_code Ec;
+ std::filesystem::remove_all(WorkerDir, Ec);
ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {});
}
}
@@ -673,9 +680,15 @@ LocalProcessRunner::ProcessCompletedActions(std::vector<Ref<RunningAction>>& Com
}
catch (std::exception& Ex)
{
- ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what());
+ Running->Action->FailureReason = fmt::format("exception gathering outputs: {}", Ex.what());
+ ZEN_ERROR("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason);
}
}
+ else
+ {
+ Running->Action->FailureReason = fmt::format("process exited with code {}", Running->ExitCode);
+ ZEN_WARN("action {} ({}) failed: {}", Running->Action->ActionId, ActionLsn, Running->Action->FailureReason);
+ }
// Failed - clean up the sandbox in the background.
diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp
index c2ccca9a6..13c01d988 100644
--- a/src/zencompute/runners/macrunner.cpp
+++ b/src/zencompute/runners/macrunner.cpp
@@ -211,7 +211,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process — lower priority so workers don't starve the main server
+ nice(5);
if (m_Sandboxed)
{
@@ -281,7 +282,8 @@ MacProcessRunner::SubmitAction(Ref<RunnerAction> Action)
// Clean up the sandbox in the background
m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath));
- ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf);
+ Action->FailureReason = fmt::format("sandbox setup failed: {}", ErrBuf);
+ ZEN_ERROR("action {} ({}): {}", Action->ActionId, Action->ActionLsn, Action->FailureReason);
Action->SetActionState(RunnerAction::State::Failed);
return SubmitResult{.IsAccepted = false};
diff --git a/src/zencompute/runners/managedrunner.cpp b/src/zencompute/runners/managedrunner.cpp
index e4a7ba388..a4f586852 100644
--- a/src/zencompute/runners/managedrunner.cpp
+++ b/src/zencompute/runners/managedrunner.cpp
@@ -128,7 +128,7 @@ ManagedProcessRunner::SubmitAction(Ref<RunnerAction> Action)
CreateProcOptions Options;
Options.WorkingDirectory = &Prepared->SandboxPath;
- Options.Flags = CreateProcOptions::Flag_NoConsole;
+ Options.Flags = CreateProcOptions::Flag_NoConsole | CreateProcOptions::Flag_BelowNormalPriority;
Options.Environment = std::move(EnvPairs);
const int32_t ActionLsn = Prepared->ActionLsn;
diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp
index ce6a81173..55f78fdd6 100644
--- a/src/zencompute/runners/remotehttprunner.cpp
+++ b/src/zencompute/runners/remotehttprunner.cpp
@@ -20,6 +20,7 @@
# include <zenstore/cidstore.h>
# include <span>
+# include <unordered_set>
//////////////////////////////////////////////////////////////////////////
@@ -38,6 +39,7 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
, m_ChunkResolver{InChunkResolver}
, m_WorkerPool{InWorkerPool}
, m_HostName{HostName}
+, m_DisplayName{HostName}
, m_BaseUrl{fmt::format("{}/compute", HostName)}
, m_Http(m_BaseUrl)
, m_InstanceId(Oid::NewOid())
@@ -59,6 +61,15 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver,
m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this};
}
+void
+RemoteHttpRunner::SetRemoteHostname(std::string_view Hostname)
+{
+ if (!Hostname.empty())
+ {
+ m_DisplayName = fmt::format("{} ({})", m_HostName, Hostname);
+ }
+}
+
RemoteHttpRunner::~RemoteHttpRunner()
{
Shutdown();
@@ -108,6 +119,7 @@ RemoteHttpRunner::Shutdown()
for (auto& [RemoteLsn, HttpAction] : Remaining)
{
ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn);
+ HttpAction.Action->FailureReason = "remote runner shutdown";
HttpAction.Action->SetActionState(RunnerAction::State::Failed);
}
}
@@ -213,11 +225,13 @@ RemoteHttpRunner::QueryCapacity()
return 0;
}
- // Estimate how much more work we're ready to accept
+ // Estimate how much more work we're ready to accept.
+ // Include actions currently being submitted over HTTP so we don't
+ // keep queueing new submissions while previous ones are still in flight.
RwLock::SharedLockScope _{m_RunningLock};
- size_t RunningCount = m_RemoteRunningMap.size();
+ size_t RunningCount = m_RemoteRunningMap.size() + m_InFlightSubmissions.load(std::memory_order_relaxed);
if (RunningCount >= size_t(m_MaxRunningActions))
{
@@ -232,6 +246,9 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
{
ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions");
+ m_InFlightSubmissions.fetch_add(Actions.size(), std::memory_order_relaxed);
+ auto InFlightGuard = MakeGuard([&] { m_InFlightSubmissions.fetch_sub(Actions.size(), std::memory_order_relaxed); });
+
if (Actions.size() <= 1)
{
std::vector<SubmitResult> Results;
@@ -359,108 +376,141 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
}
}
- // Enqueue job. If the remote returns FailedDependency (424), it means it
- // cannot resolve the worker/function — re-register the worker and retry once.
+ // Submit the action to the remote. In eager-attach mode we build a
+ // CbPackage with all referenced attachments upfront to avoid the 404
+ // round-trip. In the default mode we POST the bare object first and
+ // only upload missing attachments if the remote requests them.
+ //
+ // In both modes, FailedDependency (424) triggers a worker re-register
+ // and a single retry.
CbObject Result;
HttpClient::Response WorkResponse;
HttpResponseCode WorkResponseCode{};
- for (int Attempt = 0; Attempt < 2; ++Attempt)
- {
- WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
- WorkResponseCode = WorkResponse.StatusCode;
-
- if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
- {
- ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying",
- m_Http.GetBaseUri(),
- ActionId);
-
- (void)RegisterWorker(Action->Worker.Descriptor);
- }
- else
- {
- break;
- }
- }
-
- if (WorkResponseCode == HttpResponseCode::OK)
- {
- Result = WorkResponse.AsObject();
- }
- else if (WorkResponseCode == HttpResponseCode::NotFound)
+ if (m_EagerAttach)
{
- // Not all attachments are present
-
- // Build response package including all required attachments
-
CbPackage Pkg;
Pkg.SetObject(ActionObj);
- CbObject Response = WorkResponse.AsObject();
+ ActionObj.IterateAttachments([&](CbFieldView Field) {
+ const IoHash AttachHash = Field.AsHash();
- for (auto& Item : Response["need"sv])
- {
- const IoHash NeedHash = Item.AsHash();
-
- if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash))
{
uint64_t DataRawSize = 0;
IoHash DataRawHash;
CompressedBuffer Compressed =
CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
- ZEN_ASSERT(DataRawHash == NeedHash);
+ Pkg.AddAttachment(CbAttachment(Compressed, AttachHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ });
+
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, Pkg);
+ WorkResponseCode = WorkResponse.StatusCode;
+
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
- Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ (void)RegisterWorker(Action->Worker.Descriptor);
}
else
{
- // No such attachment
-
- return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ break;
}
}
+ }
+ else
+ {
+ for (int Attempt = 0; Attempt < 2; ++Attempt)
+ {
+ WorkResponse = m_Http.Post(SubmitUrl, ActionObj);
+ WorkResponseCode = WorkResponse.StatusCode;
- // Post resulting package
+ if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0)
+ {
+ ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying",
+ m_Http.GetBaseUri(),
+ ActionId);
- HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+ (void)RegisterWorker(Action->Worker.Descriptor);
+ }
+ else
+ {
+ break;
+ }
+ }
- if (!PayloadResponse)
+ if (WorkResponseCode == HttpResponseCode::NotFound)
{
- ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ // Remote needs attachments — resolve them and retry with a CbPackage
- // TODO: include more information about the failure in the response
+ CbPackage Pkg;
+ Pkg.SetObject(ActionObj);
- return {.IsAccepted = false, .Reason = "HTTP request failed"};
- }
- else if (PayloadResponse.StatusCode == HttpResponseCode::OK)
- {
- Result = PayloadResponse.AsObject();
- }
- else
- {
- // Unexpected response
-
- const int ResponseStatusCode = (int)PayloadResponse.StatusCode;
-
- ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})",
- ActionId,
- m_Http.GetBaseUri(),
- SubmitUrl,
- ResponseStatusCode,
- ToString(ResponseStatusCode));
-
- return {.IsAccepted = false,
- .Reason = fmt::format("unexpected response code {} {} from {}{}",
- ResponseStatusCode,
- ToString(ResponseStatusCode),
- m_Http.GetBaseUri(),
- SubmitUrl)};
+ CbObject Response = WorkResponse.AsObject();
+
+ for (auto& Item : Response["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ else
+ {
+ return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+ }
+ }
+
+ HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (!PayloadResponse)
+ {
+ ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+
+ WorkResponse = std::move(PayloadResponse);
+ WorkResponseCode = WorkResponse.StatusCode;
}
}
+ if (WorkResponseCode == HttpResponseCode::OK)
+ {
+ Result = WorkResponse.AsObject();
+ }
+ else if (!WorkResponse)
+ {
+ ZEN_WARN("submit of action {} to {}{} failed", ActionId, m_Http.GetBaseUri(), SubmitUrl);
+ return {.IsAccepted = false, .Reason = "HTTP request failed"};
+ }
+ else if (!IsHttpSuccessCode(WorkResponseCode))
+ {
+ const int Code = static_cast<int>(WorkResponseCode);
+ ZEN_WARN("submit of action {} to {}{} returned {} {}", ActionId, m_Http.GetBaseUri(), SubmitUrl, Code, ToString(Code));
+ return {.IsAccepted = false,
+ .Reason = fmt::format("unexpected response code {} {} from {}{}", Code, ToString(Code), m_Http.GetBaseUri(), SubmitUrl)};
+ }
+
if (Result)
{
if (const int32_t LsnField = Result["lsn"].AsInt32(0))
@@ -512,82 +562,110 @@ RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vec
CbObjectWriter Body;
Body.BeginArray("actions"sv);
+ std::unordered_set<IoHash, IoHash::Hasher> AttachmentsSeen;
+
for (const Ref<RunnerAction>& Action : Actions)
{
Action->ExecutionLocation = m_HostName;
MaybeDumpAction(Action->ActionLsn, Action->ActionObj);
Body.AddObject(Action->ActionObj);
+
+ if (m_EagerAttach)
+ {
+ Action->ActionObj.IterateAttachments([&](CbFieldView Field) { AttachmentsSeen.insert(Field.AsHash()); });
+ }
}
Body.EndArray();
- // POST the batch
-
- HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
-
- if (Response.StatusCode == HttpResponseCode::OK)
- {
- return ParseBatchResponse(Response, Actions);
- }
+ // In eager-attach mode, build a CbPackage with all referenced attachments
+ // so the remote can accept in a single round-trip. Otherwise POST a bare
+ // CbObject and handle the 404 need-list flow.
- if (Response.StatusCode == HttpResponseCode::NotFound)
+ if (m_EagerAttach)
{
- // Server needs attachments — resolve them and retry with a CbPackage
-
- CbObject NeedObj = Response.AsObject();
-
CbPackage Pkg;
Pkg.SetObject(Body.Save());
- for (auto& Item : NeedObj["need"sv])
+ for (const IoHash& AttachHash : AttachmentsSeen)
{
- const IoHash NeedHash = Item.AsHash();
-
- if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(AttachHash))
{
uint64_t DataRawSize = 0;
IoHash DataRawHash;
CompressedBuffer Compressed =
CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
- ZEN_ASSERT(DataRawHash == NeedHash);
-
- Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
- }
- else
- {
- ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash);
- return FallbackToIndividualSubmit(Actions);
+ Pkg.AddAttachment(CbAttachment(Compressed, AttachHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
}
}
- HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Pkg);
+
+ if (Response.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(Response, Actions);
+ }
+ }
+ else
+ {
+ HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save());
- if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ if (Response.StatusCode == HttpResponseCode::OK)
{
- return ParseBatchResponse(RetryResponse, Actions);
+ return ParseBatchResponse(Response, Actions);
}
- ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit",
- (int)RetryResponse.StatusCode,
- ToString(RetryResponse.StatusCode));
- return FallbackToIndividualSubmit(Actions);
+ if (Response.StatusCode == HttpResponseCode::NotFound)
+ {
+ CbObject NeedObj = Response.AsObject();
+
+ CbPackage Pkg;
+ Pkg.SetObject(Body.Save());
+
+ for (auto& Item : NeedObj["need"sv])
+ {
+ const IoHash NeedHash = Item.AsHash();
+
+ if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+ {
+ uint64_t DataRawSize = 0;
+ IoHash DataRawHash;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+ ZEN_ASSERT(DataRawHash == NeedHash);
+
+ Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+ m_LastSubmitStats.TotalAttachments.fetch_add(1, std::memory_order_relaxed);
+ m_LastSubmitStats.TotalAttachmentBytes.fetch_add(Chunk.GetSize(), std::memory_order_relaxed);
+ }
+ else
+ {
+ ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash);
+ return FallbackToIndividualSubmit(Actions);
+ }
+ }
+
+ HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg);
+
+ if (RetryResponse.StatusCode == HttpResponseCode::OK)
+ {
+ return ParseBatchResponse(RetryResponse, Actions);
+ }
+
+ ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit",
+ (int)RetryResponse.StatusCode,
+ ToString(RetryResponse.StatusCode));
+ return FallbackToIndividualSubmit(Actions);
+ }
}
// Unexpected status or connection error — fall back to individual submission
- if (Response)
- {
- ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit",
- m_Http.GetBaseUri(),
- SubmitUrl,
- (int)Response.StatusCode,
- ToString(Response.StatusCode));
- }
- else
- {
- ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
- }
+ ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl);
return FallbackToIndividualSubmit(Actions);
}
@@ -869,9 +947,10 @@ RemoteHttpRunner::SweepRunningActions()
{
for (auto& FieldIt : Completed["completed"sv])
{
- CbObjectView EntryObj = FieldIt.AsObjectView();
- const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
- std::string_view StateName = EntryObj["state"sv].AsString();
+ CbObjectView EntryObj = FieldIt.AsObjectView();
+ const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32();
+ std::string_view StateName = EntryObj["state"sv].AsString();
+ std::string_view FailureReason = EntryObj["reason"sv].AsString();
RunnerAction::State RemoteState = RunnerAction::FromString(StateName);
@@ -884,6 +963,7 @@ RemoteHttpRunner::SweepRunningActions()
{
HttpRunningAction CompletedAction = std::move(CompleteIt->second);
CompletedAction.RemoteState = RemoteState;
+ CompletedAction.FailureReason = std::string(FailureReason);
if (RemoteState == RunnerAction::State::Completed && ResponseJob)
{
@@ -927,16 +1007,44 @@ RemoteHttpRunner::SweepRunningActions()
{
const int ActionLsn = HttpAction.Action->ActionLsn;
- ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
- HttpAction.Action->ActionId,
- ActionLsn,
- HttpAction.RemoteActionLsn,
- RunnerAction::ToString(HttpAction.RemoteState));
-
if (HttpAction.RemoteState == RunnerAction::State::Completed)
{
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) completed on {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ m_HostName);
HttpAction.Action->SetResult(std::move(HttpAction.ActionResults));
}
+ else if (HttpAction.RemoteState == RunnerAction::State::Failed || HttpAction.RemoteState == RunnerAction::State::Abandoned)
+ {
+ HttpAction.Action->FailureReason = HttpAction.FailureReason;
+ if (HttpAction.FailureReason.empty())
+ {
+ ZEN_WARN("action {} ({}) {} on remote {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState),
+ m_HostName);
+ }
+ else
+ {
+ ZEN_WARN("action {} ({}) {} on remote {}: {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState),
+ m_HostName,
+ HttpAction.FailureReason);
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}",
+ HttpAction.Action->ActionId,
+ ActionLsn,
+ HttpAction.RemoteActionLsn,
+ RunnerAction::ToString(HttpAction.RemoteState));
+ }
HttpAction.Action->SetActionState(HttpAction.RemoteState);
}
diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h
index c17d0cf2a..fdf113c77 100644
--- a/src/zencompute/runners/remotehttprunner.h
+++ b/src/zencompute/runners/remotehttprunner.h
@@ -54,8 +54,10 @@ public:
[[nodiscard]] virtual size_t QueryCapacity() override;
[[nodiscard]] virtual std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) override;
virtual void CancelRemoteQueue(int QueueId) override;
+ [[nodiscard]] virtual std::string_view GetDisplayName() const override { return m_DisplayName; }
std::string_view GetHostName() const { return m_HostName; }
+ void SetRemoteHostname(std::string_view Hostname);
protected:
LoggerRef Log() { return m_Log; }
@@ -65,12 +67,15 @@ private:
ChunkResolver& m_ChunkResolver;
WorkerThreadPool& m_WorkerPool;
std::string m_HostName;
+ std::string m_DisplayName;
std::string m_BaseUrl;
HttpClient m_Http;
- std::atomic<bool> m_AcceptNewActions{true};
- int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
- int32_t m_MaxBatchSize = 50;
+ std::atomic<bool> m_AcceptNewActions{true};
+ int32_t m_MaxRunningActions = 256; // arbitrary limit for testing
+ int32_t m_MaxBatchSize = 50;
+ bool m_EagerAttach = true; ///< Send attachments with every submit instead of the two-step 404 retry
+ std::atomic<size_t> m_InFlightSubmissions{0}; // actions currently being submitted over HTTP
struct HttpRunningAction
{
@@ -78,6 +83,7 @@ private:
int RemoteActionLsn = 0; // Remote LSN
RunnerAction::State RemoteState = RunnerAction::State::Failed;
CbPackage ActionResults;
+ std::string FailureReason;
};
RwLock m_RunningLock;
diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp
index 92ee65c2d..e643c9ce8 100644
--- a/src/zencompute/runners/windowsrunner.cpp
+++ b/src/zencompute/runners/windowsrunner.cpp
@@ -48,7 +48,9 @@ WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver,
if (m_JobObject)
{
JOBOBJECT_EXTENDED_LIMIT_INFORMATION ExtLimits{};
- ExtLimits.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION;
+ ExtLimits.BasicLimitInformation.LimitFlags =
+ JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE | JOB_OBJECT_LIMIT_DIE_ON_UNHANDLED_EXCEPTION | JOB_OBJECT_LIMIT_PRIORITY_CLASS;
+ ExtLimits.BasicLimitInformation.PriorityClass = BELOW_NORMAL_PRIORITY_CLASS;
SetInformationJobObject(m_JobObject, JobObjectExtendedLimitInformation, &ExtLimits, sizeof(ExtLimits));
JOBOBJECT_BASIC_UI_RESTRICTIONS UiRestrictions{};
diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp
index b4fafb467..593b19e55 100644
--- a/src/zencompute/runners/winerunner.cpp
+++ b/src/zencompute/runners/winerunner.cpp
@@ -96,7 +96,9 @@ WineProcessRunner::SubmitAction(Ref<RunnerAction> Action)
if (ChildPid == 0)
{
- // Child process
+ // Child process — lower priority so workers don't starve the main server
+ nice(5);
+
if (chdir(SandboxPathStr.c_str()) != 0)
{
_exit(127);
diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md
new file mode 100644
index 000000000..13beaa968
--- /dev/null
+++ b/src/zenhorde/README.md
@@ -0,0 +1,17 @@
+# Horde Compute integration
+
+Zen compute can use Horde to provision runner nodes.
+
+## Launch a coordinator instance
+
+Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface that allows zenserver instances to discover endpoints to which they can submit actions.
+
+```bash
+zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio
+```
+
+## Use a coordinator
+
+```bash
+zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558
+```
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp
index 819b2d0cb..275f5bd4c 100644
--- a/src/zenhorde/hordeagent.cpp
+++ b/src/zenhorde/hordeagent.cpp
@@ -8,290 +8,457 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <cstring>
-#include <unordered_map>
namespace zen::horde {
-HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
-{
- ZEN_TRACE_CPU("HordeAgent::Connect");
+// --- AsyncHordeAgent ---
- auto Transport = std::make_unique<TcpComputeTransport>(Info);
- if (!Transport->IsValid())
+static const char*
+GetStateName(AsyncHordeAgent::State S)
+{
+ switch (S)
{
- ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
- return;
+ case AsyncHordeAgent::State::Idle:
+ return "idle";
+ case AsyncHordeAgent::State::Connecting:
+ return "connect";
+ case AsyncHordeAgent::State::WaitAgentAttach:
+ return "agent-attach";
+ case AsyncHordeAgent::State::SentFork:
+ return "fork";
+ case AsyncHordeAgent::State::WaitChildAttach:
+ return "child-attach";
+ case AsyncHordeAgent::State::Uploading:
+ return "upload";
+ case AsyncHordeAgent::State::Executing:
+ return "execute";
+ case AsyncHordeAgent::State::Polling:
+ return "poll";
+ case AsyncHordeAgent::State::Done:
+ return "done";
+ default:
+ return "unknown";
}
+}
- // The 64-byte nonce is always sent unencrypted as the first thing on the wire.
- // The Horde agent uses this to identify which lease this connection belongs to.
- Transport->Send(Info.Nonce, sizeof(Info.Nonce));
+AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async"))
+{
+}
- std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport);
- if (Info.EncryptionMode == Encryption::AES)
+AsyncHordeAgent::~AsyncHordeAgent()
+{
+ Cancel();
+}
+
+void
+AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone)
+{
+ m_Config = std::move(Config);
+ m_OnDone = std::move(OnDone);
+ m_State = State::Connecting;
+ DoConnect();
+}
+
+void
+AsyncHordeAgent::Cancel()
+{
+ m_Cancelled = true;
+ if (m_Socket)
{
- FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport));
- if (!FinalTransport->IsValid())
- {
- ZEN_WARN("failed to create AES transport");
- return;
- }
+ m_Socket->Close();
+ }
+ else if (m_Transport)
+ {
+ m_Transport->Close();
}
+}
- // Create multiplexed socket and channels
- m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport));
+void
+AsyncHordeAgent::DoConnect()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect");
- // Channel 0 is the agent control channel (handles Attach/Fork handshake).
- // Channel 100 is the child I/O channel (handles file upload and remote execution).
- Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0);
- Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100);
+ m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext);
+
+ auto Self = shared_from_this();
+ m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); });
+}
- if (!AgentComputeChannel || !ChildComputeChannel)
+void
+AsyncHordeAgent::OnConnected(const std::error_code& Ec)
+{
+ if (Ec || m_Cancelled)
{
- ZEN_WARN("failed to create compute channels");
+ if (Ec)
+ {
+ ZEN_WARN("connect failed: {}", Ec.message());
+ }
+ Finish(false);
return;
}
- m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel));
- m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel));
+ // Optionally wrap with AES encryption
+ std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport);
+ if (m_Config.Machine.EncryptionMode == Encryption::AES)
+ {
+ FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext);
+ }
+ m_Transport = std::move(FinalTransport);
+
+ // Create the multiplexed socket and register channels
+ m_Socket = std::make_shared<AsyncComputeSocket>(std::move(m_Transport), m_IoContext);
+
+ m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext);
+ m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext);
+
+ m_Socket->RegisterChannel(
+ 0,
+ [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); },
+ [this]() { m_AgentChannel->OnDetach(); });
- m_IsValid = true;
+ m_Socket->RegisterChannel(
+ 100,
+ [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); },
+ [this]() { m_ChildChannel->OnDetach(); });
+
+ m_Socket->StartRecvPump();
+
+ m_State = State::WaitAgentAttach;
+ DoWaitAgentAttach();
}
-HordeAgent::~HordeAgent()
+void
+AsyncHordeAgent::DoWaitAgentAttach()
{
- CloseConnection();
+ auto Self = shared_from_this();
+ m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnAgentResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::BeginCommunication()
+void
+AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
{
- ZEN_TRACE_CPU("HordeAgent::BeginCommunication");
-
- if (!m_IsValid)
+ if (m_Cancelled)
{
- return false;
+ Finish(false);
+ return;
}
- // Start the send/recv pump threads
- m_Socket->StartCommunication();
-
- // Wait for Attach on agent channel
- AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on agent channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- // Fork tells the remote agent to create child channel 100 with a 4MB buffer.
- // After this, the agent will send an Attach on the child channel.
+ m_State = State::SentFork;
+ DoSendFork();
+}
+
+void
+AsyncHordeAgent::DoSendFork()
+{
m_AgentChannel->Fork(100, 4 * 1024 * 1024);
- // Wait for Attach on child channel
- Type = m_ChildChannel->ReadResponse(5000);
+ m_State = State::WaitChildAttach;
+ DoWaitChildAttach();
+}
+
+void
+AsyncHordeAgent::DoWaitChildAttach()
+{
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnChildAttachResponse(Type, Data, Size);
+ });
+}
+
+void
+AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on child channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- return true;
+ m_State = State::Uploading;
+ m_CurrentBundleIndex = 0;
+ DoUploadNext();
}
-bool
-HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
+void
+AsyncHordeAgent::DoUploadNext()
{
- ZEN_TRACE_CPU("HordeAgent::UploadBinaries");
-
- m_ChildChannel->UploadFiles("", BundleLocator.c_str());
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;
+ if (m_CurrentBundleIndex >= m_Config.Bundles.size())
+ {
+ // All bundles uploaded — proceed to execute
+ m_State = State::Executing;
+ DoExecute();
+ return;
+ }
- auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
- std::string Key(Locator);
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ m_ChildChannel->UploadFiles("", Locator.c_str());
- if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
- {
- return It->second.get();
- }
+ // Enter the ReadBlob/Blob upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- const std::filesystem::path Path = BundleDir / (Key + ".blob");
- std::error_code Ec;
- auto File = std::make_unique<BasicFile>();
- File->Open(Path, BasicFile::Mode::kRead, Ec);
+void
+AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- if (Ec)
+ if (Type == AgentMessageType::None)
+ {
+ if (m_ChildChannel->IsDetached())
{
- ZEN_ERROR("cannot read blob file: '{}'", Path);
- return nullptr;
+ ZEN_WARN("connection lost during upload");
+ Finish(false);
+ return;
}
+ // Timeout — retry read
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+ return;
+ }
- BasicFile* Ptr = File.get();
- BlobFiles.emplace(std::move(Key), std::move(File));
- return Ptr;
- };
+ if (Type == AgentMessageType::WriteFilesResponse)
+ {
+ // This bundle upload is done — move to next
+ ++m_CurrentBundleIndex;
+ DoUploadNext();
+ return;
+ }
- // The upload protocol is request-driven: we send WriteFiles, then the remote agent
- // sends ReadBlob requests for each blob it needs. We respond with Blob data until
- // the agent sends WriteFilesResponse indicating the upload is complete.
- constexpr int32_t ReadResponseTimeoutMs = 1000;
+ if (Type == AgentMessageType::Exception)
+ {
+ ExceptionInfo Ex;
+ AsyncAgentMessageChannel::ReadException(Data, Size, Ex);
+ ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
+ Finish(false);
+ return;
+ }
- for (;;)
+ if (Type != AgentMessageType::ReadBlob)
{
- bool TimedOut = false;
+ ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
+ Finish(false);
+ return;
+ }
- if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
- {
- if (TimedOut)
- {
- continue;
- }
- // End of stream - check if it was a successful upload
- if (Type == AgentMessageType::WriteFilesResponse)
- {
- return true;
- }
- else if (Type == AgentMessageType::Exception)
- {
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
- }
- else
- {
- ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
- }
- return false;
- }
+ // Handle ReadBlob request
+ BlobRequest Req;
+ AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req);
- BlobRequest Req;
- m_ChildChannel->ReadBlobRequest(Req);
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob");
- BasicFile* File = FindOrOpenBlob(Req.Locator);
- if (!File)
- {
- return false;
- }
+ std::error_code FsEc;
+ BasicFile File;
+ File.Open(BlobPath, BasicFile::Mode::kRead, FsEc);
- // Read from offset to end of file
- const uint64_t TotalSize = File->FileSize();
- const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
- if (Offset >= TotalSize)
- {
- ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
- m_ChildChannel->Blob(nullptr, 0);
- continue;
- }
+ if (FsEc)
+ {
+ ZEN_ERROR("cannot read blob file: '{}'", BlobPath);
+ Finish(false);
+ return;
+ }
+
+ const uint64_t TotalSize = File.FileSize();
+ const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
+ if (Offset >= TotalSize)
+ {
+ ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
+ m_ChildChannel->Blob(nullptr, 0);
+ }
+ else
+ {
+ const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
+ m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize());
+ }
+
+ // Continue the upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
- m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
+void
+AsyncHordeAgent::DoExecute()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute");
+
+ std::vector<const char*> ArgPtrs;
+ ArgPtrs.reserve(m_Config.Args.size());
+ for (const std::string& Arg : m_Config.Args)
+ {
+ ArgPtrs.push_back(Arg.c_str());
}
+
+ m_ChildChannel->Execute(m_Config.Executable.c_str(),
+ ArgPtrs.data(),
+ ArgPtrs.size(),
+ nullptr,
+ nullptr,
+ 0,
+ m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+
+ ZEN_INFO("remote execution started on [{}:{}] lease={}",
+ m_Config.Machine.GetConnectionAddress(),
+ m_Config.Machine.GetConnectionPort(),
+ m_Config.Machine.LeaseId);
+
+ m_State = State::Polling;
+ DoPoll();
}
void
-HordeAgent::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- bool UseWine)
+AsyncHordeAgent::DoPoll()
{
- ZEN_TRACE_CPU("HordeAgent::Execute");
- m_ChildChannel
- ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnPollResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::Poll(bool LogOutput)
+void
+AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
{
- constexpr int32_t ReadResponseTimeoutMs = 100;
- AgentMessageType Type;
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
+ switch (Type)
{
- switch (Type)
- {
- case AgentMessageType::ExecuteOutput:
- {
- if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
- {
- const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
- size_t ResponseSize = m_ChildChannel->GetResponseSize();
-
- // Trim trailing newlines
- while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
- {
- --ResponseSize;
- }
-
- if (ResponseSize > 0)
- {
- const std::string_view Output(ResponseData, ResponseSize);
- ZEN_INFO("[remote] {}", Output);
- }
- }
- break;
- }
+ case AgentMessageType::None:
+ if (m_ChildChannel->IsDetached())
+ {
+ ZEN_WARN("connection lost during execution");
+ Finish(false);
+ }
+ else
+ {
+ // Timeout — poll again
+ DoPoll();
+ }
+ break;
- case AgentMessageType::ExecuteResult:
- {
- if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
- {
- int32_t ExitCode;
- memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
- ZEN_INFO("remote process exited with code {}", ExitCode);
- }
- m_IsValid = false;
- return false;
- }
+ case AgentMessageType::ExecuteOutput:
+ // Silently consume remote stdout (matching LogOutput=false in provisioner)
+ DoPoll();
+ break;
- case AgentMessageType::Exception:
+ case AgentMessageType::ExecuteResult:
+ {
+ int32_t ExitCode = -1;
+ if (Size == sizeof(int32_t))
{
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
- m_HasErrors = true;
- break;
+ memcpy(&ExitCode, Data, sizeof(int32_t));
}
+ ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId);
+ Finish(ExitCode == 0, ExitCode);
+ }
+ break;
- default:
- break;
- }
- }
+ case AgentMessageType::Exception:
+ {
+ ExceptionInfo Ex;
+ AsyncAgentMessageChannel::ReadException(Data, Size, Ex);
+ ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
+ Finish(false);
+ }
+ break;
- return m_IsValid && !m_HasErrors;
+ default:
+ DoPoll();
+ break;
+ }
}
void
-HordeAgent::CloseConnection()
+AsyncHordeAgent::Finish(bool Success, int32_t ExitCode)
{
- if (m_ChildChannel)
+ if (m_State == State::Done)
{
- m_ChildChannel->Close();
+ return; // Already finished
}
- if (m_AgentChannel)
+
+ if (!Success)
{
- m_AgentChannel->Close();
+ ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId);
}
-}
-bool
-HordeAgent::IsValid() const
-{
- return m_IsValid && !m_HasErrors;
+ m_State = State::Done;
+
+ if (m_Socket)
+ {
+ m_Socket->Close();
+ }
+
+ if (m_OnDone)
+ {
+ AsyncAgentResult Result;
+ Result.Success = Success;
+ Result.ExitCode = ExitCode;
+ Result.CoreCount = m_Config.Machine.LogicalCores;
+
+ auto Handler = std::move(m_OnDone);
+ m_OnDone = nullptr;
+ Handler(Result);
+ }
}
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h
index e0ae89ead..b581a8da1 100644
--- a/src/zenhorde/hordeagent.h
+++ b/src/zenhorde/hordeagent.h
@@ -10,68 +10,108 @@
#include <zencore/logbase.h>
#include <filesystem>
+#include <functional>
#include <memory>
#include <string>
+#include <vector>
+
+namespace asio {
+class io_context;
+}
namespace zen::horde {
-/** Manages the lifecycle of a single Horde compute agent.
+class AsyncComputeTransport;
+
+/** Result passed to the completion handler when an async agent finishes. */
+struct AsyncAgentResult
+{
+ bool Success = false;
+ int32_t ExitCode = -1;
+ uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine
+};
+
+/** Completion handler for async agent lifecycle. */
+using AsyncAgentCompletionHandler = std::function<void(const AsyncAgentResult&)>;
+
+/** Configuration for launching a remote zenserver instance via an async agent. */
+struct AsyncAgentConfig
+{
+ MachineInfo Machine;
+ std::vector<std::pair<std::string, std::filesystem::path>> Bundles; ///< (locator, bundleDir) pairs
+ std::string Executable;
+ std::vector<std::string> Args;
+ bool UseWine = false;
+};
+
+/** Async agent that manages the full lifecycle of a single Horde compute connection.
*
- * Handles the full connection sequence for one provisioned machine:
- * 1. Connect via TCP transport (with optional AES encryption wrapping)
- * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100)
- * 3. Perform the Attach/Fork handshake to establish the child channel
- * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol
- * 5. Execute zenserver remotely via ExecuteV2
- * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code)
+ * Driven by a state machine using callbacks on a shared io_context — no dedicated
+ * threads. Call Start() to begin the connection/handshake/upload/execute/poll
+ * sequence. The completion handler is invoked when the remote process exits or
+ * an error occurs.
*/
-class HordeAgent
+class AsyncHordeAgent : public std::enable_shared_from_this<AsyncHordeAgent>
{
public:
- explicit HordeAgent(const MachineInfo& Info);
- ~HordeAgent();
+ AsyncHordeAgent(asio::io_context& IoContext);
+ ~AsyncHordeAgent();
- HordeAgent(const HordeAgent&) = delete;
- HordeAgent& operator=(const HordeAgent&) = delete;
+ AsyncHordeAgent(const AsyncHordeAgent&) = delete;
+ AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete;
- /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel).
- * Returns false if the handshake times out or receives an unexpected message. */
- bool BeginCommunication();
+ /** Start the full agent lifecycle. The completion handler is called exactly once. */
+ void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone);
- /** Upload binary files to the remote agent.
- * @param BundleDir Directory containing .blob files.
- * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */
- bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator);
+ /** Cancel in-flight operations. The completion handler is still called (with Success=false). */
+ void Cancel();
- /** Execute a command on the remote machine. */
- void Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir = nullptr,
- const char* const* EnvVars = nullptr,
- size_t NumEnvVars = 0,
- bool UseWine = false);
+ const MachineInfo& GetMachineInfo() const { return m_Config.Machine; }
- /** Poll for output and results. Returns true if the agent is still running.
- * When LogOutput is true, remote stdout is logged via ZEN_INFO. */
- bool Poll(bool LogOutput = true);
-
- void CloseConnection();
- bool IsValid() const;
-
- const MachineInfo& GetMachineInfo() const { return m_MachineInfo; }
+ enum class State
+ {
+ Idle,
+ Connecting,
+ WaitAgentAttach,
+ SentFork,
+ WaitChildAttach,
+ Uploading,
+ Executing,
+ Polling,
+ Done
+ };
private:
LoggerRef Log() { return m_Log; }
- std::unique_ptr<ComputeSocket> m_Socket;
- std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control
- std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O
-
- LoggerRef m_Log;
- bool m_IsValid = false;
- bool m_HasErrors = false;
- MachineInfo m_MachineInfo;
+ void DoConnect();
+ void OnConnected(const std::error_code& Ec);
+ void DoWaitAgentAttach();
+ void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoSendFork();
+ void DoWaitChildAttach();
+ void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoUploadNext();
+ void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoExecute();
+ void DoPoll();
+ void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void Finish(bool Success, int32_t ExitCode = -1);
+
+ asio::io_context& m_IoContext;
+ LoggerRef m_Log;
+ State m_State = State::Idle;
+ bool m_Cancelled = false;
+
+ AsyncAgentConfig m_Config;
+ AsyncAgentCompletionHandler m_OnDone;
+ size_t m_CurrentBundleIndex = 0;
+
+ std::unique_ptr<AsyncTcpComputeTransport> m_TcpTransport;
+ std::unique_ptr<AsyncComputeTransport> m_Transport;
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ std::unique_ptr<AsyncAgentMessageChannel> m_AgentChannel;
+ std::unique_ptr<AsyncAgentMessageChannel> m_ChildChannel;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
index 998134a96..31498972f 100644
--- a/src/zenhorde/hordeagentmessage.cpp
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -4,337 +4,403 @@
#include <zencore/intmath.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <cassert>
#include <cstring>
namespace zen::horde {
-AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
+// --- AsyncAgentMessageChannel ---
+
+AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext)
+: m_Socket(std::move(Socket))
+, m_ChannelId(ChannelId)
+, m_IoContext(IoContext)
+, m_TimeoutTimer(std::make_unique<asio::steady_timer>(IoContext))
{
}
-AgentMessageChannel::~AgentMessageChannel() = default;
-
-void
-AgentMessageChannel::Close()
+AsyncAgentMessageChannel::~AsyncAgentMessageChannel()
{
- CreateMessage(AgentMessageType::None, 0);
- FlushMessage();
+ if (m_TimeoutTimer)
+ {
+ m_TimeoutTimer->cancel();
+ }
}
-void
-AgentMessageChannel::Ping()
+// --- Message building helpers ---
+
+std::vector<uint8_t>
+AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload)
{
- CreateMessage(AgentMessageType::Ping, 0);
- FlushMessage();
+ std::vector<uint8_t> Buf;
+ Buf.reserve(MessageHeaderLength + ReservePayload);
+ Buf.push_back(static_cast<uint8_t>(Type));
+ Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder
+ return Buf;
}
void
-AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg)
{
- CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
- WriteInt32(ChannelId);
- WriteInt32(BufferSize);
- FlushMessage();
+ const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength);
+ memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t));
+ m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg));
}
void
-AgentMessageChannel::Attach()
+AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value)
{
- CreateMessage(AgentMessageType::Attach, 0);
- FlushMessage();
+ const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value);
+ Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int));
}
-void
-AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+int
+AsyncAgentMessageChannel::ReadInt32(const uint8_t** Pos)
{
- CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
- WriteString(Path);
- WriteString(Locator);
- FlushMessage();
+ int Value;
+ memcpy(&Value, *Pos, sizeof(int));
+ *Pos += sizeof(int);
+ return Value;
}
void
-AgentMessageChannel::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- ExecuteProcessFlags Flags)
+AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length)
{
- size_t RequiredSize = 50 + strlen(Exe);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- RequiredSize += strlen(Args[i]) + 10;
- }
- if (WorkingDir)
- {
- RequiredSize += strlen(WorkingDir) + 10;
- }
- for (size_t i = 0; i < NumEnvVars; ++i)
- {
- RequiredSize += strlen(EnvVars[i]) + 20;
- }
-
- CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
- WriteString(Exe);
-
- WriteUnsignedVarInt(NumArgs);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- WriteString(Args[i]);
- }
-
- WriteOptionalString(WorkingDir);
-
- // ExecuteV2 protocol requires env vars as separate key/value pairs.
- // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
- WriteUnsignedVarInt(NumEnvVars);
- for (size_t i = 0; i < NumEnvVars; ++i)
- {
- const char* Eq = strchr(EnvVars[i], '=');
- assert(Eq != nullptr);
-
- WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
- if (*(Eq + 1) == '\0')
- {
- WriteOptionalString(nullptr);
- }
- else
- {
- WriteOptionalString(Eq + 1);
- }
- }
+ Buf.insert(Buf.end(), Data, Data + Length);
+}
- WriteInt32(static_cast<int>(Flags));
- FlushMessage();
+const uint8_t*
+AsyncAgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+{
+ const uint8_t* Data = *Pos;
+ *Pos += Length;
+ return Data;
}
-void
-AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+size_t
+AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
{
- // Blob responses are chunked to fit within the compute buffer's chunk size.
- // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
- const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
- for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ if (Value == 0)
{
- const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
-
- CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
- WriteInt32(static_cast<int>(ChunkOffset));
- WriteInt32(static_cast<int>(Length));
- WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
- FlushMessage();
-
- ChunkOffset += ChunkLength;
+ return 1;
}
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
-AgentMessageType
-AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+void
+AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value)
{
- // Deferred advance: the previous response's buffer is only released when the next
- // ReadResponse is called. This allows callers to read response data between calls
- // without copying, since the pointer comes directly from the ring buffer.
- if (m_ResponseData)
- {
- m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
- m_ResponseData = nullptr;
- m_ResponseLength = 0;
- }
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ const size_t StartPos = Buf.size();
+ Buf.resize(StartPos + ByteCount);
- const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
- if (!Header)
+ uint8_t* Output = Buf.data() + StartPos;
+ for (size_t i = 1; i < ByteCount; ++i)
{
- return AgentMessageType::None;
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
}
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+}
- uint32_t Length;
- memcpy(&Length, Header + 1, sizeof(uint32_t));
+size_t
+AsyncAgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+{
+ const uint8_t* Data = *Pos;
+ const uint8_t FirstByte = Data[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
- Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
- if (!Header)
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
{
- return AgentMessageType::None;
+ Value <<= 8;
+ Value |= Data[i];
}
- m_ResponseType = static_cast<AgentMessageType>(Header[0]);
- m_ResponseData = Header + MessageHeaderLength;
- m_ResponseLength = Length;
-
- return m_ResponseType;
+ *Pos += NumBytes;
+ return Value;
}
void
-AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text)
{
- assert(m_ResponseType == AgentMessageType::Exception);
- const uint8_t* Pos = m_ResponseData;
- Ex.Message = ReadString(&Pos);
- Ex.Description = ReadString(&Pos);
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
}
-int
-AgentMessageChannel::ReadExecuteResult()
+void
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text)
{
- assert(m_ResponseType == AgentMessageType::ExecuteResult);
- const uint8_t* Pos = m_ResponseData;
- return ReadInt32(&Pos);
+ WriteUnsignedVarInt(Buf, Text.size());
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
}
-void
-AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+std::string_view
+AsyncAgentMessageChannel::ReadString(const uint8_t** Pos)
{
- assert(m_ResponseType == AgentMessageType::ReadBlob);
- const uint8_t* Pos = m_ResponseData;
- Req.Locator = ReadString(&Pos);
- Req.Offset = ReadUnsignedVarInt(&Pos);
- Req.Length = ReadUnsignedVarInt(&Pos);
+ const size_t Length = ReadUnsignedVarInt(Pos);
+ const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
+ return std::string_view(Start, Length);
}
void
-AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text)
{
- m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
- m_RequestData[0] = static_cast<uint8_t>(Type);
- m_MaxRequestSize = MaxLength;
- m_RequestSize = 0;
+ if (!Text)
+ {
+ WriteUnsignedVarInt(Buf, 0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length + 1);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
+ }
}
+// --- Send methods ---
+
void
-AgentMessageChannel::FlushMessage()
+AsyncAgentMessageChannel::Close()
{
- const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
- memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
- m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
- m_RequestSize = 0;
- m_MaxRequestSize = 0;
- m_RequestData = nullptr;
+ auto Msg = BeginMessage(AgentMessageType::None, 0);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteInt32(int Value)
+AsyncAgentMessageChannel::Ping()
{
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+ auto Msg = BeginMessage(AgentMessageType::Ping, 0);
+ FinalizeAndSend(std::move(Msg));
}
-int
-AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize)
{
- int Value;
- memcpy(&Value, *Pos, sizeof(int));
- *Pos += sizeof(int);
- return Value;
+ auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(Msg, ChannelId);
+ WriteInt32(Msg, BufferSize);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::Attach()
{
- assert(m_RequestSize + Length <= m_MaxRequestSize);
- memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
- m_RequestSize += Length;
+ auto Msg = BeginMessage(AgentMessageType::Attach, 0);
+ FinalizeAndSend(std::move(Msg));
}
-const uint8_t*
-AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+void
+AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
{
- const uint8_t* Data = *Pos;
- *Pos += Length;
- return Data;
+ auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Msg, Path);
+ WriteString(Msg, Locator);
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+void
+AsyncAgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
{
- if (Value == 0)
+ size_t ReserveSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- return 1;
+ ReserveSize += strlen(Args[i]) + 10;
}
- return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
+ if (WorkingDir)
+ {
+ ReserveSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ ReserveSize += strlen(EnvVars[i]) + 20;
+ }
+
+ auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize);
+ WriteString(Msg, Exe);
+
+ WriteUnsignedVarInt(Msg, NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
+ {
+ WriteString(Msg, Args[i]);
+ }
+
+ WriteOptionalString(Msg, WorkingDir);
+
+ WriteUnsignedVarInt(Msg, NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ assert(Eq != nullptr);
+
+ WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(Msg, nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Msg, Eq + 1);
+ }
+ }
+
+ WriteInt32(Msg, static_cast<int>(Flags));
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
+AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
{
- const size_t ByteCount = MeasureUnsignedVarInt(Value);
- assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+ static constexpr size_t MaxBlobChunkSize = 512 * 1024;
- uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
- for (size_t i = 1; i < ByteCount; ++i)
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
{
- Output[ByteCount - i] = static_cast<uint8_t>(Value);
- Value >>= 8;
- }
- Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize);
- m_RequestSize += ByteCount;
+ auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(Msg, static_cast<int>(ChunkOffset));
+ WriteInt32(Msg, static_cast<int>(Length));
+ WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength);
+ FinalizeAndSend(std::move(Msg));
+
+ ChunkOffset += ChunkLength;
+ }
}
-size_t
-AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+// --- Async response reading ---
+
+void
+AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler)
{
- const uint8_t* Data = *Pos;
- const uint8_t FirstByte = Data[0];
- const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+ // If frames are already queued, dispatch immediately
+ if (!m_IncomingFrames.empty())
+ {
+ std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front());
+ m_IncomingFrames.pop_front();
- size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
- for (size_t i = 1; i < NumBytes; ++i)
+ if (Frame.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]);
+ const uint8_t* Data = Frame.data() + MessageHeaderLength;
+ size_t Size = Frame.size() - MessageHeaderLength;
+ asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable {
+ // The Frame is captured to keep Data pointer valid
+ Handler(Type, Data, Size);
+ });
+ }
+ else
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ }
+ return;
+ }
+
+ if (m_Detached)
{
- Value <<= 8;
- Value |= Data[i];
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ return;
}
- *Pos += NumBytes;
- return Value;
+ // No frames queued — store pending handler and arm timeout
+ m_PendingHandler = std::move(Handler);
+
+ if (TimeoutMs >= 0)
+ {
+ m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_TimeoutTimer->async_wait([this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return; // Cancelled — frame arrived before timeout
+ }
+
+ if (m_PendingHandler)
+ {
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ });
+ }
}
-size_t
-AgentMessageChannel::MeasureString(const char* Text) const
+void
+AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data)
{
- const size_t Length = strlen(Text);
- return MeasureUnsignedVarInt(Length) + Length;
+ if (m_PendingHandler)
+ {
+ // Cancel the timeout timer
+ m_TimeoutTimer->cancel();
+
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+
+ if (Data.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Data[0]);
+ const uint8_t* Payload = Data.data() + MessageHeaderLength;
+ size_t PayloadSize = Data.size() - MessageHeaderLength;
+ Handler(Type, Payload, PayloadSize);
+ }
+ else
+ {
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }
+ else
+ {
+ m_IncomingFrames.push_back(std::move(Data));
+ }
}
void
-AgentMessageChannel::WriteString(const char* Text)
+AsyncAgentMessageChannel::OnDetach()
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ m_Detached = true;
+
+ if (m_PendingHandler)
+ {
+ m_TimeoutTimer->cancel();
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
}
+// --- Response parsing helpers ---
+
void
-AgentMessageChannel::WriteString(std::string_view Text)
+AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t /*Size*/, ExceptionInfo& Ex)
{
- WriteUnsignedVarInt(Text.size());
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ const uint8_t* Pos = Data;
+ Ex.Message = ReadString(&Pos);
+ Ex.Description = ReadString(&Pos);
}
-std::string_view
-AgentMessageChannel::ReadString(const uint8_t** Pos)
+int
+AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t /*Size*/)
{
- const size_t Length = ReadUnsignedVarInt(Pos);
- const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
- return std::string_view(Start, Length);
+ const uint8_t* Pos = Data;
+ return ReadInt32(&Pos);
}
void
-AgentMessageChannel::WriteOptionalString(const char* Text)
+AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t /*Size*/, BlobRequest& Req)
{
- // Optional strings use length+1 encoding: 0 means null/absent,
- // N>0 means a string of length N-1 follows. This matches the UE
- // FAgentMessageChannel serialization convention.
- if (!Text)
- {
- WriteUnsignedVarInt(0);
- }
- else
- {
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length + 1);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
- }
+ const uint8_t* Pos = Data;
+ Req.Locator = ReadString(&Pos);
+ Req.Offset = ReadUnsignedVarInt(&Pos);
+ Req.Length = ReadUnsignedVarInt(&Pos);
}
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h
index 38c4375fd..0068fb468 100644
--- a/src/zenhorde/hordeagentmessage.h
+++ b/src/zenhorde/hordeagentmessage.h
@@ -4,14 +4,22 @@
#include <zenbase/zenbase.h>
-#include "hordecomputechannel.h"
+#include "hordecomputesocket.h"
#include <cstddef>
#include <cstdint>
+#include <deque>
+#include <functional>
+#include <memory>
#include <string>
#include <string_view>
+#include <system_error>
#include <vector>
+namespace asio {
+class io_context;
+} // namespace asio
+
namespace zen::horde {
/** Agent message types matching the UE EAgentMessageType byte values.
@@ -55,45 +63,34 @@ struct BlobRequest
size_t Length = 0;
};
-/** Channel for sending and receiving agent messages over a ComputeChannel.
+/** Handler for async response reads. Receives the message type and a view of the payload data.
+ * The payload vector is valid until the next AsyncReadResponse call. */
+using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>;
+
+/** Async channel for sending and receiving agent messages over an AsyncComputeSocket.
*
- * Implements the Horde agent message protocol, matching the UE
- * FAgentMessageChannel serialization format exactly. Messages are framed as
- * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8;
- * integers use variable-length encoding.
+ * Send methods build messages into vectors and submit them via AsyncComputeSocket.
+ * Receives are delivered via the socket's FrameHandler callback and queued internally.
+ * AsyncReadResponse checks the queue and invokes the handler, with optional timeout.
*
- * The protocol has two directions:
- * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob
- * - Responses (remote -> initiator): ReadResponse returns the type, then call the
- * appropriate Read* method to parse the payload.
+ * All operations must be externally serialized (e.g. via the socket's strand).
*/
-class AgentMessageChannel
+class AsyncAgentMessageChannel
{
public:
- explicit AgentMessageChannel(Ref<ComputeChannel> Channel);
- ~AgentMessageChannel();
+ AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext);
+ ~AsyncAgentMessageChannel();
- AgentMessageChannel(const AgentMessageChannel&) = delete;
- AgentMessageChannel& operator=(const AgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete;
- // --- Requests (Initiator -> Remote) ---
+ // --- Requests (fire-and-forget sends) ---
- /** Close the channel. */
void Close();
-
- /** Send a keepalive ping. */
void Ping();
-
- /** Fork communication to a new channel with the given ID and buffer size. */
void Fork(int ChannelId, int BufferSize);
-
- /** Send an attach request (used during channel setup handshake). */
void Attach();
-
- /** Request the remote agent to write files from the given bundle locator. */
void UploadFiles(const char* Path, const char* Locator);
-
- /** Execute a process on the remote machine. */
void Execute(const char* Exe,
const char* const* Args,
size_t NumArgs,
@@ -101,61 +98,61 @@ public:
const char* const* EnvVars,
size_t NumEnvVars,
ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
-
- /** Send blob data in response to a ReadBlob request. */
void Blob(const uint8_t* Data, size_t Length);
- // --- Responses (Remote -> Initiator) ---
-
- /** Read the next response message. Returns the message type, or None on timeout.
- * After this returns, use GetResponseData()/GetResponseSize() or the typed
- * Read* methods to access the payload. */
- AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr);
+ // --- Async response reading ---
- const void* GetResponseData() const { return m_ResponseData; }
- size_t GetResponseSize() const { return m_ResponseLength; }
+ /** Read the next response. If a frame is already queued, the handler is posted immediately.
+ * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler
+ * with AgentMessageType::None. */
+ void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler);
- /** Parse an Exception response payload. */
- void ReadException(ExceptionInfo& Ex);
+ /** Called by the socket's FrameHandler when a frame arrives for this channel. */
+ void OnFrame(std::vector<uint8_t> Data);
- /** Parse an ExecuteResult response payload. Returns the exit code. */
- int ReadExecuteResult();
+ /** Called by the socket's DetachHandler. */
+ void OnDetach();
- /** Parse a ReadBlob response payload into a BlobRequest. */
- void ReadBlobRequest(BlobRequest& Req);
-
-private:
- static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)]
+ /** Returns true if the channel has been detached (connection lost). */
+ bool IsDetached() const { return m_Detached; }
- Ref<ComputeChannel> m_Channel;
+ // --- Response parsing helpers ---
- uint8_t* m_RequestData = nullptr;
- size_t m_RequestSize = 0;
- size_t m_MaxRequestSize = 0;
+ static void ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex);
+ static int ReadExecuteResult(const uint8_t* Data, size_t Size);
+ static void ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req);
- AgentMessageType m_ResponseType = AgentMessageType::None;
- const uint8_t* m_ResponseData = nullptr;
- size_t m_ResponseLength = 0;
+private:
+ static constexpr size_t MessageHeaderLength = 5;
- void CreateMessage(AgentMessageType Type, size_t MaxLength);
- void FlushMessage();
+ // Message building helpers
+ std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload);
+ void FinalizeAndSend(std::vector<uint8_t> Msg);
- void WriteInt32(int Value);
- static int ReadInt32(const uint8_t** Pos);
+ static void WriteInt32(std::vector<uint8_t>& Buf, int Value);
+ static int ReadInt32(const uint8_t** Pos);
- void WriteFixedLengthBytes(const uint8_t* Data, size_t Length);
+ static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length);
static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length);
static size_t MeasureUnsignedVarInt(size_t Value);
- void WriteUnsignedVarInt(size_t Value);
+ static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value);
static size_t ReadUnsignedVarInt(const uint8_t** Pos);
- size_t MeasureString(const char* Text) const;
- void WriteString(const char* Text);
- void WriteString(std::string_view Text);
+ static void WriteString(std::vector<uint8_t>& Buf, const char* Text);
+ static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text);
static std::string_view ReadString(const uint8_t** Pos);
- void WriteOptionalString(const char* Text);
+ static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text);
+
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ int m_ChannelId;
+ asio::io_context& m_IoContext;
+
+ std::deque<std::vector<uint8_t>> m_IncomingFrames;
+ AsyncResponseHandler m_PendingHandler;
+ std::unique_ptr<asio::steady_timer> m_TimeoutTimer;
+ bool m_Detached = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp
index d3974bc28..af6b97e59 100644
--- a/src/zenhorde/hordebundle.cpp
+++ b/src/zenhorde/hordebundle.cpp
@@ -57,7 +57,7 @@ MeasureVarInt(size_t Value)
{
return 1;
}
- return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1;
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
static void
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp
index 0eefc57c6..618a85e0e 100644
--- a/src/zenhorde/hordeclient.cpp
+++ b/src/zenhorde/hordeclient.cpp
@@ -4,6 +4,7 @@
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/memoryview.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhorde/hordeclient.h>
#include <zenhttp/httpclient.h>
@@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client"))
+HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client")
{
}
@@ -32,7 +33,11 @@ HordeClient::Initialize()
Settings.RetryCount = 1;
Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests};
- if (!m_Config.AuthToken.empty())
+ if (m_Config.AccessTokenProvider)
+ {
+ Settings.AccessTokenProvider = m_Config.AccessTokenProvider;
+ }
+ else if (!m_Config.AuthToken.empty())
{
Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken {
return HttpClientAccessToken(token, HttpClientAccessToken::Clock::now() + std::chrono::hours{24});
@@ -41,7 +46,7 @@ HordeClient::Initialize()
m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings);
- if (!m_Config.AuthToken.empty())
+ if (Settings.AccessTokenProvider)
{
if (!m_Http->Authenticate())
{
@@ -63,24 +68,21 @@ HordeClient::BuildRequestBody() const
Requirements["pool"] = m_Config.Pool;
}
- std::string Condition;
-#if ZEN_PLATFORM_WINDOWS
ExtendableStringBuilder<256> CondBuf;
+#if ZEN_PLATFORM_WINDOWS
CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')";
- Condition = std::string(CondBuf);
#elif ZEN_PLATFORM_MAC
- Condition = "OSFamily == 'MacOS'";
+ CondBuf << "OSFamily == 'MacOS'";
#else
- Condition = "OSFamily == 'Linux'";
+ CondBuf << "OSFamily == 'Linux'";
#endif
if (!m_Config.Condition.empty())
{
- Condition += " ";
- Condition += m_Config.Condition;
+ CondBuf << " " << m_Config.Condition;
}
- Requirements["condition"] = Condition;
+ Requirements["condition"] = std::string(CondBuf);
Requirements["exclusive"] = true;
json11::Json::object Connection;
@@ -157,37 +159,8 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus
}
OutCluster.ClusterId = ClusterIdVal.string_value();
- return true;
-}
-
-bool
-HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize)
-{
- if (Hex.size() != OutSize * 2)
- {
- return false;
- }
- for (size_t i = 0; i < OutSize; ++i)
- {
- auto HexToByte = [](char c) -> int {
- if (c >= '0' && c <= '9')
- return c - '0';
- if (c >= 'a' && c <= 'f')
- return c - 'a' + 10;
- if (c >= 'A' && c <= 'F')
- return c - 'A' + 10;
- return -1;
- };
-
- const int Hi = HexToByte(Hex[i * 2]);
- const int Lo = HexToByte(Hex[i * 2 + 1]);
- if (Hi < 0 || Lo < 0)
- {
- return false;
- }
- Out[i] = static_cast<uint8_t>((Hi << 4) | Lo);
- }
+ ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId);
return true;
}
@@ -197,8 +170,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
{
ZEN_TRACE_CPU("HordeClient::RequestMachine");
- ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str());
-
ExtendableStringBuilder<128> ResourcePath;
ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str());
@@ -324,6 +295,10 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
{
PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14));
}
+ else if (Prop.starts_with("Pool="))
+ {
+ OutMachine.Pool = Prop.substr(5);
+ }
}
}
@@ -367,10 +342,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
OutMachine.LeaseId = LeaseIdVal.string_value();
}
- ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}",
+ ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}",
OutMachine.GetConnectionAddress(),
OutMachine.GetConnectionPort(),
+ ToString(OutMachine.Mode),
OutMachine.LogicalCores,
+ OutMachine.Pool,
OutMachine.LeaseId);
return true;
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp
deleted file mode 100644
index 0d032b5d5..000000000
--- a/src/zenhorde/hordecomputebuffer.cpp
+++ /dev/null
@@ -1,454 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "hordecomputebuffer.h"
-
-#include <algorithm>
-#include <cassert>
-#include <chrono>
-#include <condition_variable>
-#include <cstring>
-
-namespace zen::horde {
-
-// Simplified ring buffer implementation for in-process use only.
-// Uses a single contiguous buffer with write/read cursors and
-// mutex+condvar for synchronization. This is simpler than the UE version
-// which uses lock-free atomics and shared memory, but sufficient for our
-// use case where we're the initiator side of the compute protocol.
-
-struct ComputeBuffer::Detail : TRefCounted<Detail>
-{
- std::vector<uint8_t> Data;
- size_t NumChunks = 0;
- size_t ChunkLength = 0;
-
- // Current write state
- size_t WriteChunkIdx = 0;
- size_t WriteOffset = 0;
- bool WriteComplete = false;
-
- // Current read state
- size_t ReadChunkIdx = 0;
- size_t ReadOffset = 0;
- bool Detached = false;
-
- // Per-chunk written length
- std::vector<size_t> ChunkWrittenLength;
- std::vector<bool> ChunkFinished; // Writer moved to next chunk
-
- std::mutex Mutex;
- std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes
- std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space
-
- bool HasWriter = false;
- bool HasReader = false;
-
- uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
- const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
-};
-
-// ComputeBuffer
-
-ComputeBuffer::ComputeBuffer()
-{
-}
-ComputeBuffer::~ComputeBuffer()
-{
-}
-
-bool
-ComputeBuffer::CreateNew(const Params& InParams)
-{
- auto* NewDetail = new Detail();
- NewDetail->NumChunks = InParams.NumChunks;
- NewDetail->ChunkLength = InParams.ChunkLength;
- NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
- NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
- NewDetail->ChunkFinished.resize(InParams.NumChunks, false);
-
- m_Detail = NewDetail;
- return true;
-}
-
-void
-ComputeBuffer::Close()
-{
- m_Detail = nullptr;
-}
-
-bool
-ComputeBuffer::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-ComputeBufferReader
-ComputeBuffer::CreateReader()
-{
- assert(m_Detail);
- m_Detail->HasReader = true;
- return ComputeBufferReader(m_Detail);
-}
-
-ComputeBufferWriter
-ComputeBuffer::CreateWriter()
-{
- assert(m_Detail);
- m_Detail->HasWriter = true;
- return ComputeBufferWriter(m_Detail);
-}
-
-// ComputeBufferReader
-
-ComputeBufferReader::ComputeBufferReader()
-{
-}
-ComputeBufferReader::~ComputeBufferReader()
-{
-}
-
-ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default;
-ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default;
-ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default;
-ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;
-
-ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
-{
-}
-
-void
-ComputeBufferReader::Close()
-{
- m_Detail = nullptr;
-}
-
-void
-ComputeBufferReader::Detach()
-{
- if (m_Detail)
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- m_Detail->Detached = true;
- m_Detail->ReadCV.notify_all();
- }
-}
-
-bool
-ComputeBufferReader::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-bool
-ComputeBufferReader::IsComplete() const
-{
- if (!m_Detail)
- {
- return true;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- if (m_Detail->Detached)
- {
- return true;
- }
- return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
- m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
-}
-
-void
-ComputeBufferReader::AdvanceReadPosition(size_t Size)
-{
- if (!m_Detail)
- {
- return;
- }
-
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
-
- m_Detail->ReadOffset += Size;
-
- // Check if we need to move to next chunk
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
- {
- const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
- m_Detail->ReadChunkIdx = NextChunk;
- m_Detail->ReadOffset = 0;
- m_Detail->WriteCV.notify_all();
- }
-
- m_Detail->ReadCV.notify_all();
-}
-
-size_t
-ComputeBufferReader::GetMaxReadSize() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-}
-
-const uint8_t*
-ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut)
-{
- if (!m_Detail)
- {
- return nullptr;
- }
-
- std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
-
- auto Predicate = [&]() -> bool {
- if (m_Detail->Detached)
- {
- return true;
- }
-
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-
- if (Available >= MinSize)
- {
- return true;
- }
-
- // If chunk is finished and we've read everything, try to move to next
- if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
- {
- if (m_Detail->WriteComplete)
- {
- return true; // End of stream
- }
- // Move to next chunk
- const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
- m_Detail->ReadChunkIdx = NextChunk;
- m_Detail->ReadOffset = 0;
- m_Detail->WriteCV.notify_all();
- return false; // Re-check with new chunk
- }
-
- if (m_Detail->WriteComplete)
- {
- return true; // End of stream
- }
-
- return false;
- };
-
- if (TimeoutMs < 0)
- {
- m_Detail->ReadCV.wait(Lock, Predicate);
- }
- else
- {
- if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
- {
- if (OutTimedOut)
- {
- *OutTimedOut = true;
- }
- return nullptr;
- }
- }
-
- if (m_Detail->Detached)
- {
- return nullptr;
- }
-
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-
- if (Available < MinSize)
- {
- return nullptr; // End of stream
- }
-
- return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset;
-}
-
-size_t
-ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut)
-{
- const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut);
- if (!Data)
- {
- return 0;
- }
-
- const size_t Available = GetMaxReadSize();
- const size_t ToCopy = std::min(Available, MaxSize);
- memcpy(Buffer, Data, ToCopy);
- AdvanceReadPosition(ToCopy);
- return ToCopy;
-}
-
-// ComputeBufferWriter
-
-ComputeBufferWriter::ComputeBufferWriter() = default;
-ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default;
-ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default;
-ComputeBufferWriter::~ComputeBufferWriter() = default;
-ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default;
-ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default;
-
-ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
-{
-}
-
-void
-ComputeBufferWriter::Close()
-{
- if (m_Detail)
- {
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- if (!m_Detail->WriteComplete)
- {
- m_Detail->WriteComplete = true;
- m_Detail->ReadCV.notify_all();
- }
- }
- m_Detail = nullptr;
- }
-}
-
-bool
-ComputeBufferWriter::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-void
-ComputeBufferWriter::MarkComplete()
-{
- if (m_Detail)
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- m_Detail->WriteComplete = true;
- m_Detail->ReadCV.notify_all();
- }
-}
-
-void
-ComputeBufferWriter::AdvanceWritePosition(size_t Size)
-{
- if (!m_Detail || Size == 0)
- {
- return;
- }
-
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- m_Detail->ChunkWrittenLength[WriteChunk] += Size;
- m_Detail->WriteOffset += Size;
- m_Detail->ReadCV.notify_all();
-}
-
-size_t
-ComputeBufferWriter::GetMaxWriteSize() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
-}
-
-size_t
-ComputeBufferWriter::GetChunkMaxLength() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- return m_Detail->ChunkLength;
-}
-
-size_t
-ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
-{
- uint8_t* Dest = WaitToWrite(1, TimeoutMs);
- if (!Dest)
- {
- return 0;
- }
-
- const size_t Available = GetMaxWriteSize();
- const size_t ToCopy = std::min(Available, MaxSize);
- memcpy(Dest, Buffer, ToCopy);
- AdvanceWritePosition(ToCopy);
- return ToCopy;
-}
-
-uint8_t*
-ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
-{
- if (!m_Detail)
- {
- return nullptr;
- }
-
- std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
-
- if (m_Detail->WriteComplete)
- {
- return nullptr;
- }
-
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
-
- // If current chunk has enough space, return pointer
- if (Available >= MinSize)
- {
- return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
- }
-
- // Current chunk is full - mark it as finished and move to next.
- // The writer cannot advance until the reader has fully consumed the next chunk,
- // preventing the writer from overwriting data the reader hasn't processed yet.
- m_Detail->ChunkFinished[WriteChunk] = true;
- m_Detail->ReadCV.notify_all();
-
- const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks;
-
- // Wait until reader has consumed the next chunk
- auto Predicate = [&]() -> bool {
- // Check if read has moved past this chunk
- return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached;
- };
-
- if (TimeoutMs < 0)
- {
- m_Detail->WriteCV.wait(Lock, Predicate);
- }
- else
- {
- if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
- {
- return nullptr;
- }
- }
-
- if (m_Detail->Detached)
- {
- return nullptr;
- }
-
- // Reset next chunk
- m_Detail->ChunkWrittenLength[NextChunk] = 0;
- m_Detail->ChunkFinished[NextChunk] = false;
- m_Detail->WriteChunkIdx = NextChunk;
- m_Detail->WriteOffset = 0;
-
- return m_Detail->ChunkPtr(NextChunk);
-}
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h
deleted file mode 100644
index 64ef91b7a..000000000
--- a/src/zenhorde/hordecomputebuffer.h
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zenbase/refcount.h>
-
-#include <cstddef>
-#include <cstdint>
-#include <mutex>
-#include <vector>
-
-namespace zen::horde {
-
-class ComputeBufferReader;
-class ComputeBufferWriter;
-
-/** Simplified in-process ring buffer for the Horde compute protocol.
- *
- * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files,
- * this implementation uses plain heap-allocated memory since we only need in-process
- * communication between channel and transport threads. The buffer is divided into
- * fixed-size chunks; readers and writers block when no space is available.
- */
-class ComputeBuffer
-{
-public:
- struct Params
- {
- size_t NumChunks = 2;
- size_t ChunkLength = 512 * 1024;
- };
-
- ComputeBuffer();
- ~ComputeBuffer();
-
- ComputeBuffer(const ComputeBuffer&) = delete;
- ComputeBuffer& operator=(const ComputeBuffer&) = delete;
-
- bool CreateNew(const Params& InParams);
- void Close();
-
- bool IsValid() const;
-
- ComputeBufferReader CreateReader();
- ComputeBufferWriter CreateWriter();
-
-private:
- struct Detail;
- Ref<Detail> m_Detail;
-
- friend class ComputeBufferReader;
- friend class ComputeBufferWriter;
-};
-
-/** Read endpoint for a ComputeBuffer.
- *
- * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer
- * directly into the buffer memory (zero-copy); the caller must call
- * AdvanceReadPosition() after consuming the data.
- */
-class ComputeBufferReader
-{
-public:
- ComputeBufferReader();
- ComputeBufferReader(const ComputeBufferReader&);
- ComputeBufferReader(ComputeBufferReader&&) noexcept;
- ~ComputeBufferReader();
-
- ComputeBufferReader& operator=(const ComputeBufferReader&);
- ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept;
-
- void Close();
- void Detach();
- bool IsValid() const;
- bool IsComplete() const;
-
- void AdvanceReadPosition(size_t Size);
- size_t GetMaxReadSize() const;
-
- /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */
- size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
-
- /** Wait until at least MinSize bytes are available and return a direct pointer.
- * Returns nullptr on timeout or if the writer has completed. */
- const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
-
-private:
- friend class ComputeBuffer;
- explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail);
-
- Ref<ComputeBuffer::Detail> m_Detail;
-};
-
-/** Write endpoint for a ComputeBuffer.
- *
- * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer
- * directly into the buffer memory (zero-copy); the caller must call
- * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal
- * that no more data will be written.
- */
-class ComputeBufferWriter
-{
-public:
- ComputeBufferWriter();
- ComputeBufferWriter(const ComputeBufferWriter&);
- ComputeBufferWriter(ComputeBufferWriter&&) noexcept;
- ~ComputeBufferWriter();
-
- ComputeBufferWriter& operator=(const ComputeBufferWriter&);
- ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept;
-
- void Close();
- bool IsValid() const;
-
- /** Signal that no more data will be written. Unblocks any waiting readers. */
- void MarkComplete();
-
- void AdvanceWritePosition(size_t Size);
- size_t GetMaxWriteSize() const;
- size_t GetChunkMaxLength() const;
-
- /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */
- size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1);
-
- /** Wait until at least MinSize bytes of write space are available and return a direct pointer.
- * Returns nullptr on timeout. */
- uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1);
-
-private:
- friend class ComputeBuffer;
- explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail);
-
- Ref<ComputeBuffer::Detail> m_Detail;
-};
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp
deleted file mode 100644
index ee2a6f327..000000000
--- a/src/zenhorde/hordecomputechannel.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "hordecomputechannel.h"
-
-namespace zen::horde {
-
-ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter)
-: Reader(std::move(InReader))
-, Writer(std::move(InWriter))
-{
-}
-
-bool
-ComputeChannel::IsValid() const
-{
- return Reader.IsValid() && Writer.IsValid();
-}
-
-size_t
-ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs)
-{
- return Writer.Write(Data, Size, TimeoutMs);
-}
-
-size_t
-ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs)
-{
- return Reader.Read(Data, Size, TimeoutMs);
-}
-
-void
-ComputeChannel::MarkComplete()
-{
- Writer.MarkComplete();
-}
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h
deleted file mode 100644
index c1dff20e4..000000000
--- a/src/zenhorde/hordecomputechannel.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include "hordecomputebuffer.h"
-
-namespace zen::horde {
-
-/** Bidirectional communication channel using a pair of compute buffers.
- *
- * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter
- * (for sending data). Used by ComputeSocket to represent one logical channel
- * within a multiplexed connection.
- */
-class ComputeChannel : public TRefCounted<ComputeChannel>
-{
-public:
- ComputeBufferReader Reader;
- ComputeBufferWriter Writer;
-
- ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter);
-
- bool IsValid() const;
-
- size_t Send(const void* Data, size_t Size, int TimeoutMs = -1);
- size_t Recv(void* Data, size_t Size, int TimeoutMs = -1);
-
- /** Signal that no more data will be sent on this channel. */
- void MarkComplete();
-};
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp
index 6ef67760c..8a6fc40a9 100644
--- a/src/zenhorde/hordecomputesocket.cpp
+++ b/src/zenhorde/hordecomputesocket.cpp
@@ -6,198 +6,326 @@
namespace zen::horde {
-ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport)
-: m_Log(zen::logging::Get("horde.socket"))
+AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext)
+: m_Log(zen::logging::Get("horde.socket.async"))
, m_Transport(std::move(Transport))
+, m_Strand(asio::make_strand(IoContext))
+, m_PingTimer(m_Strand)
{
}
-ComputeSocket::~ComputeSocket()
+AsyncComputeSocket::~AsyncComputeSocket()
{
- // Shutdown order matters: first stop the ping thread, then unblock send threads
- // by detaching readers, then join send threads, and finally close the transport
- // to unblock the recv thread (which is blocked on RecvMessage).
- {
- std::lock_guard<std::mutex> Lock(m_PingMutex);
- m_PingShouldStop = true;
- m_PingCV.notify_all();
- }
-
- for (auto& Reader : m_Readers)
- {
- Reader.Detach();
- }
-
- for (auto& [Id, Thread] : m_SendThreads)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
+ Close();
+}
- m_Transport->Close();
+void
+AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach)
+{
+ m_FrameHandlers[ChannelId] = std::move(OnFrame);
+ m_DetachHandlers[ChannelId] = std::move(OnDetach);
+}
- if (m_RecvThread.joinable())
- {
- m_RecvThread.join();
- }
- if (m_PingThread.joinable())
- {
- m_PingThread.join();
- }
+void
+AsyncComputeSocket::StartRecvPump()
+{
+ StartPingTimer();
+ DoRecvHeader();
}
-Ref<ComputeChannel>
-ComputeSocket::CreateChannel(int ChannelId)
+void
+AsyncComputeSocket::DoRecvHeader()
{
- ComputeBuffer::Params Params;
+ auto Self = shared_from_this();
+ m_Transport->AsyncRead(&m_RecvHeader,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
- ComputeBuffer RecvBuffer;
- if (!RecvBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_Closed)
+ {
+ return;
+ }
- ComputeBuffer SendBuffer;
- if (!SendBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_RecvHeader.Size >= 0)
+ {
+ DoRecvPayload(m_RecvHeader);
+ }
+ else if (m_RecvHeader.Size == ControlDetach)
+ {
+ if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second)
+ {
+ It->second();
+ }
+ DoRecvHeader();
+ }
+ else if (m_RecvHeader.Size == ControlPing)
+ {
+ DoRecvHeader();
+ }
+ else
+ {
+ ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size);
+ }
+ }));
+}
- Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter()));
+void
+AsyncComputeSocket::DoRecvPayload(FrameHeader Header)
+{
+ auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size));
+ auto Self = shared_from_this();
- // Attach recv buffer writer (transport recv thread writes into this)
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter());
- }
+ m_Transport->AsyncRead(PayloadBuf->data(),
+ PayloadBuf->size(),
+ asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message());
+ HandleError();
+ }
+ return;
+ }
- // Attach send buffer reader (send thread reads from this)
- {
- ComputeBufferReader Reader = SendBuffer.CreateReader();
- m_Readers.push_back(Reader);
- m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader)));
- }
+ if (m_Closed)
+ {
+ return;
+ }
+
+ if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second)
+ {
+ It->second(std::move(*PayloadBuf));
+ }
+ else
+ {
+ ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
+ }
- return Channel;
+ DoRecvHeader();
+ }));
}
void
-ComputeSocket::StartCommunication()
+AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler)
{
- m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this);
- m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this);
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
+ {
+ if (Handler)
+ {
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
+ }
+ return;
+ }
+
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = static_cast<int32_t>(Data.size());
+ Write.Data = std::move(Data);
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::PingThreadProc()
+AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler)
{
- while (true)
- {
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
{
- std::unique_lock<std::mutex> Lock(m_PingMutex);
- if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; }))
+ if (Handler)
{
- break;
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
}
+ return;
}
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- FrameHeader Header;
- Header.Channel = 0;
- Header.Size = ControlPing;
- m_Transport->SendMessage(&Header, sizeof(Header));
- }
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = ControlDetach;
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::RecvThreadProc()
+AsyncComputeSocket::FlushNextSend()
{
- // Writers are cached locally to avoid taking m_WritersMutex on every frame.
- // The shared m_Writers map is only accessed when a channel is seen for the first time.
- std::unordered_map<int, ComputeBufferWriter> CachedWriters;
+ if (m_SendQueue.empty() || m_Closed)
+ {
+ return;
+ }
- FrameHeader Header;
- while (m_Transport->RecvMessage(&Header, sizeof(Header)))
+ PendingWrite& Front = m_SendQueue.front();
+
+ if (Front.Data.empty())
{
- if (Header.Size >= 0)
- {
- // Data frame
- auto It = CachedWriters.find(Header.Channel);
- if (It == CachedWriters.end())
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto WIt = m_Writers.find(Header.Channel);
- if (WIt == m_Writers.end())
- {
- ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
- // Skip the data
- std::vector<uint8_t> Discard(Header.Size);
- m_Transport->RecvMessage(Discard.data(), Header.Size);
- continue;
- }
- It = CachedWriters.emplace(Header.Channel, WIt->second).first;
- }
+ // Control frame — header only
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
- ComputeBufferWriter& Writer = It->second;
- uint8_t* Dest = Writer.WaitToWrite(Header.Size);
- if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size))
- {
- ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size);
- return;
- }
- Writer.AdvanceWritePosition(Header.Size);
- }
- else if (Header.Size == ControlDetach)
- {
- // Detach the recv buffer for this channel
- CachedWriters.erase(Header.Channel);
+ if (Handler)
+ {
+ Handler(Ec);
+ }
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto It = m_Writers.find(Header.Channel);
- if (It != m_Writers.end())
- {
- It->second.MarkComplete();
- m_Writers.erase(It);
- }
- }
- else if (Header.Size == ControlPing)
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }
+ else
+ {
+ // Data frame — write header first, then payload
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ PendingWrite& Payload = m_SendQueue.front();
+ m_Transport->AsyncWrite(
+ Payload.Data.data(),
+ Payload.Data.size(),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send payload error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }));
+ }
+}
+
+void
+AsyncComputeSocket::StartPingTimer()
+{
+ if (m_Closed)
+ {
+ return;
+ }
+
+ m_PingTimer.expires_after(std::chrono::seconds(2));
+
+ auto Self = shared_from_this();
+ m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) {
+ if (Ec || m_Closed)
{
- // Ping response - ignore
+ return;
}
- else
+
+ // Enqueue a ping control frame
+ PendingWrite Write;
+ Write.Header.Channel = 0;
+ Write.Header.Size = ControlPing;
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
{
- ZEN_WARN("invalid frame header size: {}", Header.Size);
- return;
+ FlushNextSend();
}
- }
+
+ StartPingTimer();
+ }));
}
void
-ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader)
+AsyncComputeSocket::HandleError()
{
- // Each channel has its own send thread. All send threads share m_SendMutex
- // to serialize writes to the transport, since TCP requires atomic frame writes.
- FrameHeader Header;
- Header.Channel = Channel;
+ if (m_Closed)
+ {
+ return;
+ }
+
+ Close();
- const uint8_t* Data;
- while ((Data = Reader.WaitToRead(1)) != nullptr)
+ // Notify all channels that the connection is gone so agents can clean up
+ for (auto& [ChannelId, Handler] : m_DetachHandlers)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
+ if (Handler)
+ {
+ Handler();
+ }
+ }
+}
- Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize());
- m_Transport->SendMessage(&Header, sizeof(Header));
- m_Transport->SendMessage(Data, Header.Size);
- Reader.AdvanceReadPosition(Header.Size);
+void
+AsyncComputeSocket::Close()
+{
+ if (m_Closed)
+ {
+ return;
}
- if (Reader.IsComplete())
+ m_Closed = true;
+ m_PingTimer.cancel();
+
+ if (m_Transport)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- Header.Size = ControlDetach;
- m_Transport->SendMessage(&Header, sizeof(Header));
+ m_Transport->Close();
}
}
diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h
index 0c3cb4195..45b3418b7 100644
--- a/src/zenhorde/hordecomputesocket.h
+++ b/src/zenhorde/hordecomputesocket.h
@@ -2,45 +2,69 @@
#pragma once
-#include "hordecomputebuffer.h"
-#include "hordecomputechannel.h"
#include "hordetransport.h"
#include <zencore/logbase.h>
-#include <condition_variable>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# undef SendMessage
+#endif
+
+#include <deque>
+#include <functional>
#include <memory>
-#include <mutex>
-#include <thread>
+#include <system_error>
#include <unordered_map>
#include <vector>
namespace zen::horde {
-/** Multiplexed socket that routes data between multiple channels over a single transport.
+class AsyncComputeTransport;
+
+/** Handler called when a data frame arrives for a channel. */
+using FrameHandler = std::function<void(std::vector<uint8_t> Data)>;
+
+/** Handler called when a channel is detached by the remote peer. */
+using DetachHandler = std::function<void()>;
+
+/** Handler for async send completion. */
+using SendHandler = std::function<void(const std::error_code&)>;
+
+/** Async multiplexed socket that routes data between channels over a single transport.
*
- * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers.
- * A recv thread demultiplexes incoming frames to channel-specific buffers, while
- * per-channel send threads multiplex outgoing data onto the shared transport.
+ * Uses an async recv pump, a serialized send queue, and a periodic ping timer —
+ * all running on a shared io_context.
*
- * Wire format per frame: [channelId (4B)][size (4B)][data]
- * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping.
+ * Wire format per frame: [channelId(4B)][size(4B)][data].
+ * Control messages use negative sizes: -2 = detach, -3 = ping.
*/
-class ComputeSocket
+class AsyncComputeSocket : public std::enable_shared_from_this<AsyncComputeSocket>
{
public:
- explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport);
- ~ComputeSocket();
+ AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext);
+ ~AsyncComputeSocket();
+
+ AsyncComputeSocket(const AsyncComputeSocket&) = delete;
+ AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete;
- ComputeSocket(const ComputeSocket&) = delete;
- ComputeSocket& operator=(const ComputeSocket&) = delete;
+ /** Register callbacks for a channel. Must be called before StartRecvPump(). */
+ void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach);
- /** Create a channel with the given ID.
- * Allocates anonymous in-process buffers and spawns a send thread for the channel. */
- Ref<ComputeChannel> CreateChannel(int ChannelId);
+ /** Begin the async recv pump and ping timer. */
+ void StartRecvPump();
- /** Start the recv pump and ping threads. Must be called after all channels are created. */
- void StartCommunication();
+ /** Enqueue a data frame for async transmission. */
+ void AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler = {});
+
+ /** Send a control frame (detach) for a channel. */
+ void AsyncSendDetach(int ChannelId, SendHandler Handler = {});
+
+ /** Close the transport and cancel all pending operations. */
+ void Close();
private:
struct FrameHeader
@@ -49,31 +73,35 @@ private:
int32_t Size = 0;
};
+ struct PendingWrite
+ {
+ FrameHeader Header;
+ std::vector<uint8_t> Data;
+ SendHandler Handler;
+ };
+
static constexpr int32_t ControlDetach = -2;
static constexpr int32_t ControlPing = -3;
LoggerRef Log() { return m_Log; }
- void RecvThreadProc();
- void SendThreadProc(int Channel, ComputeBufferReader Reader);
- void PingThreadProc();
-
- LoggerRef m_Log;
- std::unique_ptr<ComputeTransport> m_Transport;
- std::mutex m_SendMutex; ///< Serializes writes to the transport
-
- std::mutex m_WritersMutex;
- std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID
+ void DoRecvHeader();
+ void DoRecvPayload(FrameHeader Header);
+ void FlushNextSend();
+ void StartPingTimer();
+ void HandleError();
- std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction
- std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel
+ LoggerRef m_Log;
+ std::unique_ptr<AsyncComputeTransport> m_Transport;
+ asio::strand<asio::any_io_executor> m_Strand;
+ asio::steady_timer m_PingTimer;
- std::thread m_RecvThread;
- std::thread m_PingThread;
+ std::unordered_map<int, FrameHandler> m_FrameHandlers;
+ std::unordered_map<int, DetachHandler> m_DetachHandlers;
- bool m_PingShouldStop = false;
- std::mutex m_PingMutex;
- std::condition_variable m_PingCV;
+ FrameHeader m_RecvHeader;
+ std::deque<PendingWrite> m_SendQueue;
+ bool m_Closed = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp
index 2dca228d9..9f6125c64 100644
--- a/src/zenhorde/hordeconfig.cpp
+++ b/src/zenhorde/hordeconfig.cpp
@@ -1,5 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <zencore/logging.h>
+#include <zencore/string.h>
#include <zenhorde/hordeconfig.h>
namespace zen::horde {
@@ -9,12 +11,14 @@ HordeConfig::Validate() const
{
if (ServerUrl.empty())
{
+ ZEN_WARN("Horde server URL is not configured");
return false;
}
// Relay mode implies AES encryption
if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES)
{
+ ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode));
return false;
}
@@ -52,37 +56,39 @@ ToString(Encryption Enc)
bool
FromString(ConnectionMode& OutMode, std::string_view Str)
{
- if (Str == "direct")
+ if (StrCaseCompare(Str, "direct") == 0)
{
OutMode = ConnectionMode::Direct;
return true;
}
- if (Str == "tunnel")
+ if (StrCaseCompare(Str, "tunnel") == 0)
{
OutMode = ConnectionMode::Tunnel;
return true;
}
- if (Str == "relay")
+ if (StrCaseCompare(Str, "relay") == 0)
{
OutMode = ConnectionMode::Relay;
return true;
}
+ ZEN_WARN("unrecognized Horde connection mode: '{}'", Str);
return false;
}
bool
FromString(Encryption& OutEnc, std::string_view Str)
{
- if (Str == "none")
+ if (StrCaseCompare(Str, "none") == 0)
{
OutEnc = Encryption::None;
return true;
}
- if (Str == "aes")
+ if (StrCaseCompare(Str, "aes") == 0)
{
OutEnc = Encryption::AES;
return true;
}
+ ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str);
return false;
}
diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp
index f88c95da2..b08544d1a 100644
--- a/src/zenhorde/hordeprovisioner.cpp
+++ b/src/zenhorde/hordeprovisioner.cpp
@@ -6,49 +6,82 @@
#include "hordeagent.h"
#include "hordebundle.h"
+#include <zencore/compactbinary.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/scopeguard.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
+#include <zenhttp/httpclient.h>
+#include <zenutil/workerpools.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <algorithm>
#include <chrono>
#include <thread>
namespace zen::horde {
-struct HordeProvisioner::AgentWrapper
-{
- std::thread Thread;
- std::atomic<bool> ShouldExit{false};
-};
-
HordeProvisioner::HordeProvisioner(const HordeConfig& Config,
const std::filesystem::path& BinariesPath,
const std::filesystem::path& WorkingDir,
- std::string_view OrchestratorEndpoint)
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession,
+ bool CleanStart,
+ std::string_view TraceHost)
: m_Config(Config)
, m_BinariesPath(BinariesPath)
, m_WorkingDir(WorkingDir)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
+, m_CoordinatorSession(CoordinatorSession)
+, m_CleanStart(CleanStart)
+, m_TraceHost(TraceHost)
, m_Log(zen::logging::Get("horde.provisioner"))
{
+ m_IoContext = std::make_unique<asio::io_context>();
+
+ auto Work = asio::make_work_guard(*m_IoContext);
+ for (int i = 0; i < IoThreadCount; ++i)
+ {
+ m_IoThreads.emplace_back([this, i, Work] {
+ zen::SetCurrentThreadName(fmt::format("horde_io_{}", i));
+ m_IoContext->run();
+ });
+ }
}
HordeProvisioner::~HordeProvisioner()
{
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- for (auto& Agent : m_Agents)
+ m_AskForAgents.store(false);
+
+ // Shut down async agents and io_context
{
- Agent->ShouldExit.store(true);
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (auto& Entry : m_AsyncAgents)
+ {
+ Entry.Agent->Cancel();
+ }
+ m_AsyncAgents.clear();
}
- for (auto& Agent : m_Agents)
+
+ m_IoContext->stop();
+
+ for (auto& Thread : m_IoThreads)
{
- if (Agent->Thread.joinable())
+ if (Thread.joinable())
{
- Agent->Thread.join();
+ Thread.join();
}
}
+
+ // Wait for all pool work items to finish before destroying members they reference
+ if (m_PendingWorkItems.load() > 0)
+ {
+ m_AllWorkDone.Wait();
+ }
}
void
@@ -56,9 +89,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count)
{
ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount");
- m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)));
+ const uint32_t ClampedCount = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores));
+ const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount);
+
+ if (ClampedCount != PreviousTarget)
+ {
+ ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})",
+ PreviousTarget,
+ ClampedCount,
+ m_ActiveCoreCount.load(),
+ m_EstimatedCoreCount.load());
+ }
- while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load())
+ // Only provision if the gap is at least one agent-sized chunk. Without
+ // this, draining a 32-core agent to cover a 28-core excess would leave a
+ // 4-core gap that triggers a 32-core provision, which triggers another
+ // drain, ad infinitum.
+ while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load())
{
if (!m_AskForAgents.load())
{
@@ -67,21 +114,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count)
RequestAgent();
}
- // Clean up finished agent threads
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- for (auto It = m_Agents.begin(); It != m_Agents.end();)
+ // Scale down async agents
{
- if ((*It)->ShouldExit.load())
+ std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock);
+
+ uint32_t AsyncActive = m_ActiveCoreCount.load();
+ uint32_t AsyncTarget = m_TargetCoreCount.load();
+
+ uint32_t AlreadyDrainingCores = 0;
+ for (const auto& Entry : m_AsyncAgents)
{
- if ((*It)->Thread.joinable())
+ if (Entry.Draining)
{
- (*It)->Thread.join();
+ AlreadyDrainingCores += Entry.CoreCount;
}
- It = m_Agents.erase(It);
}
- else
+
+ uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? AsyncActive - AlreadyDrainingCores : 0;
+
+ if (EffectiveAsync > AsyncTarget)
{
- ++It;
+ struct Candidate
+ {
+ AsyncAgentEntry* Entry;
+ int Workload;
+ };
+ std::vector<Candidate> Candidates;
+
+ for (auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.Draining || Entry.RemoteEndpoint.empty())
+ {
+ continue;
+ }
+
+ int Workload = 0;
+ bool Reachable = false;
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.drain";
+ Settings.ConnectTimeout = std::chrono::milliseconds{2000};
+ Settings.Timeout = std::chrono::milliseconds{3000};
+ try
+ {
+ HttpClient Client(Entry.RemoteEndpoint, Settings);
+ HttpClient::Response Resp = Client.Get("/compute/session/status");
+ if (Resp.IsSuccess())
+ {
+ CbObject Status = Resp.AsObject();
+ Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0);
+ Reachable = true;
+ }
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what());
+ }
+
+ if (Reachable)
+ {
+ Candidates.push_back({&Entry, Workload});
+ }
+ }
+
+ const uint32_t ExcessCores = EffectiveAsync - AsyncTarget;
+ uint32_t CoresDrained = 0;
+
+ while (CoresDrained < ExcessCores && !Candidates.empty())
+ {
+ const uint32_t Remaining = ExcessCores - CoresDrained;
+
+ Candidates.erase(std::remove_if(Candidates.begin(),
+ Candidates.end(),
+ [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }),
+ Candidates.end());
+
+ if (Candidates.empty())
+ {
+ break;
+ }
+
+ Candidate* Best = &Candidates[0];
+ for (auto& C : Candidates)
+ {
+ if (C.Entry->CoreCount > Best->Entry->CoreCount ||
+ (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload))
+ {
+ Best = &C;
+ }
+ }
+
+ ZEN_INFO("draining async agent lease={} ({} cores, workload={})",
+ Best->Entry->LeaseId,
+ Best->Entry->CoreCount,
+ Best->Workload);
+
+ DrainAsyncAgent(*Best->Entry);
+ CoresDrained += Best->Entry->CoreCount;
+
+ AsyncAgentEntry* Drained = Best->Entry;
+ Candidates.erase(
+ std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }),
+ Candidates.end());
+ }
}
}
}
@@ -101,266 +235,380 @@ HordeProvisioner::GetStats() const
uint32_t
HordeProvisioner::GetAgentCount() const
{
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- return static_cast<uint32_t>(m_Agents.size());
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ return static_cast<uint32_t>(m_AsyncAgents.size());
}
-void
-HordeProvisioner::RequestAgent()
+compute::AgentProvisioningStatus
+HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const
{
- m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
+ // Worker IDs are "horde-{LeaseId}" — strip the prefix to match lease ID
+ constexpr std::string_view Prefix = "horde-";
+ if (!WorkerId.starts_with(Prefix))
+ {
+ return compute::AgentProvisioningStatus::Unknown;
+ }
+ std::string_view LeaseId = WorkerId.substr(Prefix.size());
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
+ std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock);
+ for (const auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.LeaseId == LeaseId)
+ {
+ if (Entry.Draining)
+ {
+ return compute::AgentProvisioningStatus::Draining;
+ }
+ return compute::AgentProvisioningStatus::Active;
+ }
+ }
- auto Wrapper = std::make_unique<AgentWrapper>();
- AgentWrapper& Ref = *Wrapper;
- Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); });
+ // Check recently-drained agents that have already been cleaned up
+ std::string WorkerIdStr(WorkerId);
+ if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0)
+ {
+ return compute::AgentProvisioningStatus::Draining;
+ }
- m_Agents.push_back(std::move(Wrapper));
+ return compute::AgentProvisioningStatus::Unknown;
}
-void
-HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
+std::vector<std::string>
+HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const
{
- ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent");
+ std::vector<std::string> Args;
+ Args.emplace_back("compute");
+ Args.emplace_back("--http=asio");
+ Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
+ Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen");
- static std::atomic<uint32_t> ThreadIndex{0};
- const uint32_t CurrentIndex = ThreadIndex.fetch_add(1);
+ if (m_CleanStart)
+ {
+ Args.emplace_back("--clean");
+ }
- zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex));
+ if (!m_OrchestratorEndpoint.empty())
+ {
+ ExtendableStringBuilder<256> CoordArg;
+ CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
+ Args.emplace_back(CoordArg.ToView());
+ }
- std::unique_ptr<HordeAgent> Agent;
- uint32_t MachineCoreCount = 0;
+ {
+ ExtendableStringBuilder<128> IdArg;
+ IdArg << "--instance-id=horde-" << Machine.LeaseId;
+ Args.emplace_back(IdArg.ToView());
+ }
- auto _ = MakeGuard([&] {
- if (Agent)
- {
- Agent->CloseConnection();
- }
- Wrapper.ShouldExit.store(true);
- });
+ if (!m_CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << m_CoordinatorSession;
+ Args.emplace_back(SessionArg.ToView());
+ }
+ if (!m_TraceHost.empty())
{
- // EstimatedCoreCount is incremented speculatively when the agent is requested
- // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
- auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << m_TraceHost;
+ Args.emplace_back(TraceArg.ToView());
+ }
+ // In relay mode, the remote zenserver's local address is not reachable from the
+ // orchestrator. Pass the relay-visible endpoint so it announces the correct URL.
+ if (Machine.Mode == ConnectionMode::Relay)
+ {
+ const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (Addr.find(':') != std::string::npos)
+ {
+ Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port));
+ }
+ else
{
- ZEN_TRACE_CPU("HordeProvisioner::CreateBundles");
+ Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port));
+ }
+ }
- std::lock_guard<std::mutex> BundleLock(m_BundleLock);
+ return Args;
+}
- if (!m_BundlesCreated)
- {
- const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+bool
+HordeProvisioner::InitializeHordeClient()
+{
+ ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient");
+
+ std::lock_guard<std::mutex> BundleLock(m_BundleLock);
- std::vector<BundleFile> Files;
+ if (!m_BundlesCreated)
+ {
+ const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+
+ std::vector<BundleFile> Files;
#if ZEN_PLATFORM_WINDOWS
- Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.pdb", true);
#elif ZEN_PLATFORM_LINUX
- Files.emplace_back(m_BinariesPath / "zenserver", false);
- Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
#elif ZEN_PLATFORM_MAC
- Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
#endif
- BundleResult Result;
- if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
- {
- ZEN_WARN("failed to create bundle, cannot provision any agents!");
- m_AskForAgents.store(false);
- return;
- }
-
- m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
- m_BundlesCreated = true;
- }
-
- if (!m_HordeClient)
- {
- m_HordeClient = std::make_unique<HordeClient>(m_Config);
- if (!m_HordeClient->Initialize())
- {
- ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
- m_AskForAgents.store(false);
- return;
- }
- }
+ BundleResult Result;
+ if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
+ {
+ ZEN_WARN("failed to create bundle, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ return false;
}
- if (!m_AskForAgents.load())
+ m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
+ m_BundlesCreated = true;
+ }
+
+ if (!m_HordeClient)
+ {
+ m_HordeClient = std::make_unique<HordeClient>(m_Config);
+ if (!m_HordeClient->Initialize())
{
- return;
+ ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ return false;
}
+ }
- m_AgentsRequesting.fetch_add(1);
- auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
+ return true;
+}
- // Simple backoff: if the last machine request failed, wait up to 5 seconds
- // before trying again.
- //
- // Note however that it's possible that multiple threads enter this code at
- // the same time if multiple agents are requested at once, and they will all
- // see the same last failure time and back off accordingly. We might want to
- // use a semaphore or similar to limit the number of concurrent requests.
+void
+HordeProvisioner::RequestAgent()
+{
+ m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
- if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
- {
- auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
- const uint64_t ElapsedNs = Now - LastFail;
- const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
- if (ElapsedMs < 5000)
- {
- const uint64_t WaitMs = 5000 - ElapsedMs;
- for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100)
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- }
+ if (m_PendingWorkItems.fetch_add(1) == 0)
+ {
+ m_AllWorkDone.Reset();
+ }
- if (Wrapper.ShouldExit.load())
+ GetSmallWorkerPool(EWorkloadType::Background)
+ .ScheduleWork(
+ [this] {
+ ProvisionAgent();
+ if (m_PendingWorkItems.fetch_sub(1) == 1)
{
- return;
+ m_AllWorkDone.Set();
}
- }
- }
-
- if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
- {
- return;
- }
-
- std::string RequestBody = m_HordeClient->BuildRequestBody();
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+}
- // Resolve cluster if needed
- std::string ClusterId = m_Config.Cluster;
- if (ClusterId == HordeConfig::ClusterAuto)
- {
- ClusterInfo Cluster;
- if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
- {
- ZEN_WARN("failed to resolve cluster");
- m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
- return;
- }
- ClusterId = Cluster.ClusterId;
- }
+void
+HordeProvisioner::ProvisionAgent()
+{
+ ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent");
- MachineInfo Machine;
- if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
- {
- m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
- return;
- }
+ // EstimatedCoreCount is incremented speculatively when the agent is requested
+ // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
- m_LastRequestFailTime.store(0);
+ if (!InitializeHordeClient())
+ {
+ return;
+ }
- if (Wrapper.ShouldExit.load())
- {
- return;
- }
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
- // Connect to agent and perform handshake
- Agent = std::make_unique<HordeAgent>(Machine);
- if (!Agent->IsValid())
- {
- ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort());
- return;
- }
+ m_AgentsRequesting.fetch_add(1);
+ auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
- if (!Agent->BeginCommunication())
- {
- ZEN_WARN("BeginCommunication failed");
- return;
- }
+ // Simple backoff: if the last machine request failed, wait up to 5 seconds
+ // before trying again.
+ //
+ // Note however that it's possible that multiple threads enter this code at
+ // the same time if multiple agents are requested at once, and they will all
+ // see the same last failure time and back off accordingly. We might want to
+ // use a semaphore or similar to limit the number of concurrent requests.
- for (auto& [Locator, BundleDir] : m_Bundles)
+ if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
+ {
+ auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
+ const uint64_t ElapsedNs = Now - LastFail;
+ const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
+ if (ElapsedMs < 5000)
{
- if (Wrapper.ShouldExit.load())
+ const uint64_t WaitMs = 5000 - ElapsedMs;
+ for (uint64_t Waited = 0; Waited < WaitMs && !!m_AskForAgents.load(); Waited += 100)
{
- return;
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
- if (!Agent->UploadBinaries(BundleDir, Locator))
+ if (!m_AskForAgents.load())
{
- ZEN_WARN("UploadBinaries failed");
return;
}
}
+ }
- if (Wrapper.ShouldExit.load())
+ if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
+ {
+ return;
+ }
+
+ std::string RequestBody = m_HordeClient->BuildRequestBody();
+
+ // Resolve cluster if needed
+ std::string ClusterId = m_Config.Cluster;
+ if (ClusterId == HordeConfig::ClusterAuto)
+ {
+ ClusterInfo Cluster;
+ if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
{
+ ZEN_WARN("failed to resolve cluster");
+ m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
return;
}
+ ClusterId = Cluster.ClusterId;
+ }
- // Build command line for remote zenserver
- std::vector<std::string> ArgStrings;
- ArgStrings.push_back("compute");
- ArgStrings.push_back("--http=asio");
+ ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})",
+ ClusterId.empty() ? "default" : ClusterId.c_str(),
+ m_ActiveCoreCount.load(),
+ m_TargetCoreCount.load());
- // TEMP HACK - these should be made fully dynamic
- // these are currently here to allow spawning the compute agent locally
- // for debugging purposes (i.e with a local Horde Server+Agent setup)
- ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
- ArgStrings.push_back("--data-dir=c:\\temp\\123");
+ MachineInfo Machine;
+ if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
+ {
+ m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
+ return;
+ }
- if (!m_OrchestratorEndpoint.empty())
- {
- ExtendableStringBuilder<256> CoordArg;
- CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
- ArgStrings.emplace_back(CoordArg.ToView());
- }
+ m_LastRequestFailTime.store(0);
- {
- ExtendableStringBuilder<128> IdArg;
- IdArg << "--instance-id=horde-" << Machine.LeaseId;
- ArgStrings.emplace_back(IdArg.ToView());
- }
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
- std::vector<const char*> Args;
- Args.reserve(ArgStrings.size());
- for (const std::string& Arg : ArgStrings)
- {
- Args.push_back(Arg.c_str());
- }
+ AsyncAgentConfig AgentConfig;
+ AgentConfig.Machine = Machine;
+ AgentConfig.Bundles = m_Bundles;
+ AgentConfig.Args = BuildAgentArgs(Machine);
#if ZEN_PLATFORM_WINDOWS
- const bool UseWine = !Machine.IsWindows;
- const char* AppName = "zenserver.exe";
+ AgentConfig.UseWine = !Machine.IsWindows;
+ AgentConfig.Executable = "zenserver.exe";
#else
- const bool UseWine = false;
- const char* AppName = "zenserver";
+ AgentConfig.UseWine = false;
+ AgentConfig.Executable = "zenserver";
#endif
- Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine);
+ auto AsyncAgent = std::make_shared<AsyncHordeAgent>(*m_IoContext);
+
+ AsyncAgentEntry Entry;
+ Entry.Agent = AsyncAgent;
+ Entry.LeaseId = Machine.LeaseId;
+ Entry.CoreCount = Machine.LogicalCores;
- ZEN_INFO("remote execution started on [{}:{}] lease={}",
- Machine.GetConnectionAddress(),
- Machine.GetConnectionPort(),
- Machine.LeaseId);
+ const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (EndpointAddr.find(':') != std::string::npos)
+ {
+ Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort);
+ }
+ else
+ {
+ Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort);
+ }
- MachineCoreCount = Machine.LogicalCores;
- m_EstimatedCoreCount.fetch_add(MachineCoreCount);
- m_ActiveCoreCount.fetch_add(MachineCoreCount);
- m_AgentsActive.fetch_add(1);
+ {
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ m_AsyncAgents.push_back(std::move(Entry));
}
- // Agent poll loop
+ AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) {
+ if (Result.CoreCount > 0)
+ {
+ // Only subtract estimated cores if not already subtracted by DrainAsyncAgent
+ bool WasDraining = false;
+ {
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (const auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.Agent == AsyncAgent)
+ {
+ WasDraining = Entry.Draining;
+ break;
+ }
+ }
+ }
- auto ActiveGuard = MakeGuard([&]() {
- m_EstimatedCoreCount.fetch_sub(MachineCoreCount);
- m_ActiveCoreCount.fetch_sub(MachineCoreCount);
- m_AgentsActive.fetch_sub(1);
+ if (!WasDraining)
+ {
+ m_EstimatedCoreCount.fetch_sub(Result.CoreCount);
+ }
+ m_ActiveCoreCount.fetch_sub(Result.CoreCount);
+ m_AgentsActive.fetch_sub(1);
+ }
+ OnAsyncAgentDone(AsyncAgent);
});
- while (Agent->IsValid() && !Wrapper.ShouldExit.load())
+ // Track active cores (estimated was already added by RequestAgent)
+ m_EstimatedCoreCount.fetch_add(Machine.LogicalCores);
+ m_ActiveCoreCount.fetch_add(Machine.LogicalCores);
+ m_AgentsActive.fetch_add(1);
+}
+
+void
+HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry)
+{
+ Entry.Draining = true;
+ m_EstimatedCoreCount.fetch_sub(Entry.CoreCount);
+ m_AgentsDraining.fetch_add(1);
+
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.drain";
+ Settings.ConnectTimeout = std::chrono::milliseconds{5000};
+ Settings.Timeout = std::chrono::milliseconds{10000};
+
+ try
{
- const bool LogOutput = false;
- if (!Agent->Poll(LogOutput))
+ HttpClient Client(Entry.RemoteEndpoint, Settings);
+
+ HttpClient::Response Response = Client.Post("/compute/session/drain");
+ if (!Response.IsSuccess())
{
+ ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast<int>(Response.StatusCode));
+ return;
+ }
+
+ ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId);
+ (void)Client.Post("/compute/session/sunset");
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what());
+ }
+}
+
+void
+HordeProvisioner::OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent)
+{
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It)
+ {
+ if (It->Agent == Agent)
+ {
+ if (It->Draining)
+ {
+ m_AgentsDraining.fetch_sub(1);
+ m_RecentlyDrainedWorkerIds.insert("horde-" + It->LeaseId);
+ }
+ m_AsyncAgents.erase(It);
break;
}
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
}
diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp
index 69766e73e..65eaea477 100644
--- a/src/zenhorde/hordetransport.cpp
+++ b/src/zenhorde/hordetransport.cpp
@@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
-#if ZEN_PLATFORM_WINDOWS
-# undef SendMessage
-#endif
-
namespace zen::horde {
-// ComputeTransport base
+// --- AsyncTcpComputeTransport ---
-bool
-ComputeTransport::SendMessage(const void* Data, size_t Size)
+struct AsyncTcpComputeTransport::Impl
{
- const uint8_t* Ptr = static_cast<const uint8_t*>(Data);
- size_t Remaining = Size;
-
- while (Remaining > 0)
- {
- const size_t Sent = Send(Ptr, Remaining);
- if (Sent == 0)
- {
- return false;
- }
- Ptr += Sent;
- Remaining -= Sent;
- }
+ asio::io_context& IoContext;
+ asio::ip::tcp::socket Socket;
- return true;
-}
+ explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {}
+};
-bool
-ComputeTransport::RecvMessage(void* Data, size_t Size)
+AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext)
+: m_Impl(std::make_unique<Impl>(IoContext))
+, m_Log(zen::logging::Get("horde.transport.async"))
{
- uint8_t* Ptr = static_cast<uint8_t*>(Data);
- size_t Remaining = Size;
-
- while (Remaining > 0)
- {
- const size_t Received = Recv(Ptr, Remaining);
- if (Received == 0)
- {
- return false;
- }
- Ptr += Received;
- Remaining -= Received;
- }
-
- return true;
}
-// TcpComputeTransport - ASIO pimpl
-
-struct TcpComputeTransport::Impl
+AsyncTcpComputeTransport::~AsyncTcpComputeTransport()
{
- asio::io_context IoContext;
- asio::ip::tcp::socket Socket;
-
- Impl() : Socket(IoContext) {}
-};
+ Close();
+}
-// Uses ASIO in synchronous mode only — no async operations or io_context::run().
-// The io_context is only needed because ASIO sockets require one to be constructed.
-TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
-: m_Impl(std::make_unique<Impl>())
-, m_Log(zen::logging::Get("horde.transport"))
+void
+AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler)
{
- ZEN_TRACE_CPU("TcpComputeTransport::Connect");
+ ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect");
asio::error_code Ec;
@@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
{
ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message());
m_HasErrors = true;
+ asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); });
return;
}
const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort());
- m_Impl->Socket.connect(Endpoint, Ec);
- if (Ec)
- {
- ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message());
- m_HasErrors = true;
- return;
- }
+ // Copy the nonce so it survives past this scope into the async callback
+ auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize);
- // Disable Nagle's algorithm for lower latency
- m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec);
-}
+ m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("async connect failed: {}", Ec.message());
+ m_HasErrors = true;
+ Handler(Ec);
+ return;
+ }
-TcpComputeTransport::~TcpComputeTransport()
-{
- Close();
+ asio::error_code SetOptEc;
+ m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc);
+
+ // Send the 64-byte nonce as the first thing on the wire
+ asio::async_write(m_Impl->Socket,
+ asio::buffer(*NonceBuf),
+ [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) {
+ if (Ec)
+ {
+ ZEN_WARN("nonce write failed: {}", Ec.message());
+ m_HasErrors = true;
+ }
+ Handler(Ec);
+ });
+ });
}
bool
-TcpComputeTransport::IsValid() const
+AsyncTcpComputeTransport::IsValid() const
{
return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed;
}
-size_t
-TcpComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
- }
-
- asio::error_code Ec;
- const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec);
-
- if (Ec)
- {
- m_HasErrors = true;
- return 0;
+ asio::post(m_Impl->IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- return Sent;
+ asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}
-size_t
-TcpComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
- }
-
- asio::error_code Ec;
- const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec);
-
- if (Ec)
- {
- return 0;
+ asio::post(m_Impl->IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- return Received;
-}
-
-void
-TcpComputeTransport::MarkComplete()
-{
+ asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}
void
-TcpComputeTransport::Close()
+AsyncTcpComputeTransport::Close()
{
if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open())
{
diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h
index 1b178dc0f..b5e841d7a 100644
--- a/src/zenhorde/hordetransport.h
+++ b/src/zenhorde/hordetransport.h
@@ -8,55 +8,60 @@
#include <cstddef>
#include <cstdint>
+#include <functional>
#include <memory>
+#include <system_error>
-#if ZEN_PLATFORM_WINDOWS
-# undef SendMessage
-#endif
+namespace asio {
+class io_context;
+}
namespace zen::horde {
-/** Abstract base interface for compute transports.
+/** Handler types for async transport operations. */
+using AsyncConnectHandler = std::function<void(const std::error_code&)>;
+using AsyncIoHandler = std::function<void(const std::error_code&, size_t)>;
+
+/** Abstract base for asynchronous compute transports.
*
- * Matches the UE FComputeTransport pattern. Concrete implementations handle
- * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides
- * blocking message helpers on top.
+ * All callbacks are invoked on the io_context that was provided at construction.
+ * Callers are responsible for strand serialization if needed.
*/
-class ComputeTransport
+class AsyncComputeTransport
{
public:
- virtual ~ComputeTransport() = default;
+ virtual ~AsyncComputeTransport() = default;
+
+ virtual bool IsValid() const = 0;
- virtual bool IsValid() const = 0;
- virtual size_t Send(const void* Data, size_t Size) = 0;
- virtual size_t Recv(void* Data, size_t Size) = 0;
- virtual void MarkComplete() = 0;
- virtual void Close() = 0;
+ /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */
+ virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0;
- /** Blocking send that loops until all bytes are transferred. Returns false on error. */
- bool SendMessage(const void* Data, size_t Size);
+ /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */
+ virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0;
- /** Blocking receive that loops until all bytes are transferred. Returns false on error. */
- bool RecvMessage(void* Data, size_t Size);
+ virtual void Close() = 0;
};
-/** TCP socket transport using ASIO.
+/** Async TCP transport using ASIO.
*
- * Connects to the Horde compute endpoint specified by MachineInfo and provides
- * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the
- * header clean.
+ * Connects to the Horde compute endpoint and provides async send/receive.
+ * The socket is created on a caller-provided io_context (shared across agents).
*/
-class TcpComputeTransport final : public ComputeTransport
+class AsyncTcpComputeTransport final : public AsyncComputeTransport
{
public:
- explicit TcpComputeTransport(const MachineInfo& Info);
- ~TcpComputeTransport() override;
-
- bool IsValid() const override;
- size_t Send(const void* Data, size_t Size) override;
- size_t Recv(void* Data, size_t Size) override;
- void MarkComplete() override;
- void Close() override;
+ /** Construct a transport on the given io_context. Does not connect yet. */
+ explicit AsyncTcpComputeTransport(asio::io_context& IoContext);
+ ~AsyncTcpComputeTransport() override;
+
+ /** Asynchronously connect to the endpoint and send the nonce. */
+ void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler);
+
+ bool IsValid() const override;
+ void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void Close() override;
private:
LoggerRef Log() { return m_Log; }
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
index 505b6bde7..c71866e8c 100644
--- a/src/zenhorde/hordetransportaes.cpp
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -5,6 +5,10 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <algorithm>
#include <cstring>
#include <random>
@@ -22,274 +26,281 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-struct AesComputeTransport::CryptoContext
-{
- uint8_t Key[KeySize] = {};
- uint8_t EncryptNonce[NonceBytes] = {};
- uint8_t DecryptNonce[NonceBytes] = {};
- bool HasErrors = false;
+namespace {
+
+ static constexpr size_t AesNonceBytes = 12;
+ static constexpr size_t AesTagBytes = 16;
+
+ /** AES-256-GCM crypto context. Not exposed outside this translation unit. */
+ struct AesCryptoContext
+ {
+ static constexpr size_t NonceBytes = AesNonceBytes;
+ static constexpr size_t TagBytes = AesTagBytes;
+
+ uint8_t Key[KeySize] = {};
+ uint8_t EncryptNonce[NonceBytes] = {};
+ uint8_t DecryptNonce[NonceBytes] = {};
+ bool HasErrors = false;
#if !ZEN_PLATFORM_WINDOWS
- EVP_CIPHER_CTX* EncCtx = nullptr;
- EVP_CIPHER_CTX* DecCtx = nullptr;
+ EVP_CIPHER_CTX* EncCtx = nullptr;
+ EVP_CIPHER_CTX* DecCtx = nullptr;
#endif
- CryptoContext(const uint8_t (&InKey)[KeySize])
- {
- memcpy(Key, InKey, KeySize);
-
- // The encrypt nonce is randomly initialized and then deterministically mutated
- // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
- // the wire (each received message carries its own nonce in the header).
- std::random_device Rd;
- std::mt19937 Gen(Rd());
- std::uniform_int_distribution<int> Dist(0, 255);
- for (auto& Byte : EncryptNonce)
+ AesCryptoContext(const uint8_t (&InKey)[KeySize])
{
- Byte = static_cast<uint8_t>(Dist(Gen));
- }
+ memcpy(Key, InKey, KeySize);
+
+ std::random_device Rd;
+ std::mt19937 Gen(Rd());
+ std::uniform_int_distribution<int> Dist(0, 255);
+ for (auto& Byte : EncryptNonce)
+ {
+ Byte = static_cast<uint8_t>(Dist(Gen));
+ }
#if !ZEN_PLATFORM_WINDOWS
- // Drain any stale OpenSSL errors
- while (ERR_get_error() != 0)
- {
- }
+ while (ERR_get_error() != 0)
+ {
+ }
- EncCtx = EVP_CIPHER_CTX_new();
- EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
+ EncCtx = EVP_CIPHER_CTX_new();
+ EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
- DecCtx = EVP_CIPHER_CTX_new();
- EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
+ DecCtx = EVP_CIPHER_CTX_new();
+ EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
#endif
- }
+ }
- ~CryptoContext()
- {
+ ~AesCryptoContext()
+ {
#if ZEN_PLATFORM_WINDOWS
- SecureZeroMemory(Key, sizeof(Key));
- SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
- SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+ SecureZeroMemory(Key, sizeof(Key));
+ SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+ SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
#else
- OPENSSL_cleanse(Key, sizeof(Key));
- OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
- OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
-
- if (EncCtx)
- {
- EVP_CIPHER_CTX_free(EncCtx);
+ OPENSSL_cleanse(Key, sizeof(Key));
+ OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+ OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+ if (EncCtx)
+ {
+ EVP_CIPHER_CTX_free(EncCtx);
+ }
+ if (DecCtx)
+ {
+ EVP_CIPHER_CTX_free(DecCtx);
+ }
+#endif
}
- if (DecCtx)
+
+ void UpdateNonce()
{
- EVP_CIPHER_CTX_free(DecCtx);
+ uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
+ N32[0]++;
+ N32[1]--;
+ N32[2] = N32[0] ^ N32[1];
}
-#endif
- }
-
- void UpdateNonce()
- {
- uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
- N32[0]++;
- N32[1]--;
- N32[2] = N32[0] ^ N32[1];
- }
- // Returns total encrypted message size, or 0 on failure
- // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
- int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
- {
- UpdateNonce();
+ int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
+ {
+ UpdateNonce();
- // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
- // caching but has some overhead. For our use case (relatively large, infrequent messages)
- // this is acceptable.
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = EncryptNonce;
- AuthInfo.cbNonce = NonceBytes;
- uint8_t Tag[TagBytes] = {};
- AuthInfo.pbTag = Tag;
- AuthInfo.cbTag = TagBytes;
-
- ULONG CipherLen = 0;
- NTSTATUS Status =
- BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+
+ BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = EncryptNonce;
+ AuthInfo.cbNonce = NonceBytes;
+ uint8_t Tag[TagBytes] = {};
+ AuthInfo.pbTag = Tag;
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG CipherLen = 0;
+ NTSTATUS Status = BCryptEncrypt(hKey,
+ (PUCHAR)In,
+ (ULONG)InLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ Out + 4 + NonceBytes,
+ (ULONG)InLength,
+ &CipherLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ HasErrors = true;
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+ return 0;
+ }
+
+ memcpy(Out, &InLength, 4);
+ memcpy(Out + 4, EncryptNonce, NonceBytes);
+ memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
BCryptDestroyKey(hKey);
BCryptCloseAlgorithmProvider(hAlg, 0);
- return 0;
- }
-
- // Write header: length + nonce
- memcpy(Out, &InLength, 4);
- memcpy(Out + 4, EncryptNonce, NonceBytes);
- // Write tag after ciphertext
- memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
- return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+ return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
#else
- if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
-
- int32_t Offset = 0;
- // Write length
- memcpy(Out + Offset, &InLength, 4);
- Offset += 4;
- // Write nonce
- memcpy(Out + Offset, EncryptNonce, NonceBytes);
- Offset += NonceBytes;
-
- // Encrypt
- int OutLen = 0;
- if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
- {
- HasErrors = true;
- return 0;
- }
- Offset += OutLen;
-
- // Finalize
- int FinalLen = 0;
- if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
+ if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int32_t Offset = 0;
+ memcpy(Out + Offset, &InLength, 4);
+ Offset += 4;
+ memcpy(Out + Offset, EncryptNonce, NonceBytes);
+ Offset += NonceBytes;
+
+ int OutLen = 0;
+ if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += OutLen;
+
+ int FinalLen = 0;
+ if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += FinalLen;
+
+ if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += TagBytes;
+
+ return Offset;
+#endif
}
- Offset += FinalLen;
- // Get tag
- if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
{
- HasErrors = true;
- return 0;
- }
- Offset += TagBytes;
-
- return Offset;
-#endif
- }
-
- // Decrypt a message. Returns decrypted data length, or 0 on failure.
- // Input must be [ciphertext][tag], with nonce provided separately.
- int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
- {
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
- AuthInfo.cbNonce = NonceBytes;
- AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
- AuthInfo.cbTag = TagBytes;
-
- ULONG PlainLen = 0;
- NTSTATUS Status = BCryptDecrypt(hKey,
- (PUCHAR)CipherAndTag,
- (ULONG)DataLength,
- &AuthInfo,
- nullptr,
- 0,
- (PUCHAR)Out,
- (ULONG)DataLength,
- &PlainLen,
- 0);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
- return 0;
- }
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+
+ BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+ AuthInfo.cbNonce = NonceBytes;
+ AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG PlainLen = 0;
+ NTSTATUS Status = BCryptDecrypt(hKey,
+ (PUCHAR)CipherAndTag,
+ (ULONG)DataLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out,
+ (ULONG)DataLength,
+ &PlainLen,
+ 0);
- return static_cast<int32_t>(PlainLen);
-#else
- if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
- int OutLen = 0;
- if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
- {
- HasErrors = true;
- return 0;
- }
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ HasErrors = true;
+ return 0;
+ }
- // Set the tag for verification
- if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
- {
- HasErrors = true;
- return 0;
+ return static_cast<int32_t>(PlainLen);
+#else
+ if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int OutLen = 0;
+ if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int FinalLen = 0;
+ if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ return OutLen + FinalLen;
+#endif
}
+ };
- int FinalLen = 0;
- if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
- }
+} // anonymous namespace
- return OutLen + FinalLen;
-#endif
- }
+struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext
+{
+ using AesCryptoContext::AesCryptoContext;
};
-AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+// --- AsyncAesComputeTransport ---
+
+AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext)
: m_Crypto(std::make_unique<CryptoContext>(Key))
, m_Inner(std::move(InnerTransport))
+, m_IoContext(IoContext)
{
}
-AesComputeTransport::~AesComputeTransport()
+AsyncAesComputeTransport::~AsyncAesComputeTransport()
{
Close();
}
bool
-AesComputeTransport::IsValid() const
+AsyncAesComputeTransport::IsValid() const
{
return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed;
}
-size_t
-AesComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
- ZEN_TRACE_CPU("AesComputeTransport::Send");
-
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- std::lock_guard<std::mutex> Lock(m_Lock);
-
const int32_t DataLength = static_cast<int32_t>(Size);
- const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
+ const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes;
if (m_EncryptBuffer.size() < MessageLength)
{
@@ -299,38 +310,36 @@ AesComputeTransport::Send(const void* Data, size_t Size)
const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
if (EncryptedLen == 0)
{
- return 0;
+ asio::post(m_IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); });
+ return;
}
- if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
- {
- return 0;
- }
+ auto EncBuf = std::make_shared<std::vector<uint8_t>>(m_EncryptBuffer.begin(), m_EncryptBuffer.begin() + EncryptedLen);
- return Size;
+ m_Inner->AsyncWrite(
+ EncBuf->data(),
+ EncBuf->size(),
+ [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); });
}
-size_t
-AesComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
- // than the decrypted message contains. Excess bytes are buffered in m_RemainingData
- // and returned on subsequent Recv calls without another decryption round-trip.
- ZEN_TRACE_CPU("AesComputeTransport::Recv");
-
- std::lock_guard<std::mutex> Lock(m_Lock);
+ uint8_t* Dest = static_cast<uint8_t*>(Data);
if (!m_RemainingData.empty())
{
const size_t Available = m_RemainingData.size() - m_RemainingOffset;
const size_t ToCopy = std::min(Available, Size);
- memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+ memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy);
m_RemainingOffset += ToCopy;
if (m_RemainingOffset >= m_RemainingData.size())
@@ -339,78 +348,96 @@ AesComputeTransport::Recv(void* Data, size_t Size)
m_RemainingOffset = 0;
}
- return ToCopy;
- }
-
- // Receive packet header: [length(4B)][nonce(12B)]
- struct PacketHeader
- {
- int32_t DataLength = 0;
- uint8_t Nonce[NonceBytes] = {};
- } Header;
-
- if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
- {
- return 0;
- }
-
- // Validate DataLength to prevent OOM from malicious/corrupt peers
- static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
-
- if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
- {
- ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
- return 0;
- }
-
- // Receive ciphertext + tag
- const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
-
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
-
- if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
- {
- return 0;
- }
-
- // Decrypt
- const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size);
-
- // We need a temporary buffer for decryption if we can't decrypt directly into output
- std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
-
- const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
- if (Decrypted == 0)
- {
- return 0;
- }
-
- memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+ if (ToCopy == Size)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); });
+ return;
+ }
- // Store remaining data if we couldn't return everything
- if (static_cast<size_t>(Header.DataLength) > BytesToReturn)
- {
- m_RemainingOffset = 0;
- m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength);
+ DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler));
+ return;
}
- return BytesToReturn;
+ DoRecvMessage(Dest, Size, std::move(Handler));
}
void
-AesComputeTransport::MarkComplete()
+AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler)
{
- if (IsValid())
- {
- m_Inner->MarkComplete();
- }
+ static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes;
+ auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>();
+
+ m_Inner->AsyncRead(HeaderBuf->data(),
+ HeaderSize,
+ [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ int32_t DataLength = 0;
+ memcpy(&DataLength, HeaderBuf->data(), 4);
+
+ static constexpr int32_t MaxDataLength = 64 * 1024 * 1024;
+ if (DataLength <= 0 || DataLength > MaxDataLength)
+ {
+ Handler(asio::error::make_error_code(asio::error::invalid_argument), 0);
+ return;
+ }
+
+ const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes;
+ if (m_DecryptBuffer.size() < MessageLength)
+ {
+ m_DecryptBuffer.resize(MessageLength);
+ }
+
+ auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>();
+ memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes);
+
+ m_Inner->AsyncRead(
+ m_DecryptBuffer.data(),
+ MessageLength,
+ [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec,
+ size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength));
+ const int32_t Decrypted =
+ m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength);
+ if (Decrypted == 0)
+ {
+ Handler(asio::error::make_error_code(asio::error::connection_aborted), 0);
+ return;
+ }
+
+ const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size);
+ memcpy(Dest, PlaintextBuf.data(), BytesToReturn);
+
+ if (static_cast<size_t>(Decrypted) > BytesToReturn)
+ {
+ m_RemainingOffset = 0;
+ m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted);
+ }
+
+ if (BytesToReturn < Size)
+ {
+ DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler));
+ }
+ else
+ {
+ Handler(std::error_code{}, Size);
+ }
+ });
+ });
}
void
-AesComputeTransport::Close()
+AsyncAesComputeTransport::Close()
{
if (!m_IsClosed)
{
diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h
index efcad9835..a1800c684 100644
--- a/src/zenhorde/hordetransportaes.h
+++ b/src/zenhorde/hordetransportaes.h
@@ -6,47 +6,53 @@
#include <cstdint>
#include <memory>
-#include <mutex>
#include <vector>
+namespace asio {
+class io_context;
+}
+
namespace zen::horde {
-/** AES-256-GCM encrypted transport wrapper.
+/** Async AES-256-GCM encrypted transport wrapper.
*
- * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting
- * all incoming data using AES-256-GCM. The nonce is mutated per message using
- * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1].
+ * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming
+ * data using AES-256-GCM. The nonce is mutated per message using the Horde
+ * nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1].
*
* Wire format per encrypted message:
* [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)]
*
* Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time).
+ *
+ * Thread safety: all operations must be serialized by the caller (e.g. via a strand).
*/
-class AesComputeTransport final : public ComputeTransport
+class AsyncAesComputeTransport final : public AsyncComputeTransport
{
public:
- AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport);
- ~AesComputeTransport() override;
+ AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext);
+ ~AsyncAesComputeTransport() override;
- bool IsValid() const override;
- size_t Send(const void* Data, size_t Size) override;
- size_t Recv(void* Data, size_t Size) override;
- void MarkComplete() override;
- void Close() override;
+ bool IsValid() const override;
+ void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void Close() override;
private:
- static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size
- static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size
+ void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler);
struct CryptoContext;
- std::unique_ptr<CryptoContext> m_Crypto;
- std::unique_ptr<ComputeTransport> m_Inner;
- std::vector<uint8_t> m_EncryptBuffer;
- std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv
- size_t m_RemainingOffset = 0;
- std::mutex m_Lock;
- bool m_IsClosed = false;
+ std::unique_ptr<CryptoContext> m_Crypto;
+ std::unique_ptr<AsyncComputeTransport> m_Inner;
+ asio::io_context& m_IoContext;
+ std::vector<uint8_t> m_EncryptBuffer;
+ std::vector<uint8_t> m_DecryptBuffer;
+ std::vector<uint8_t> m_RemainingData;
+ size_t m_RemainingOffset = 0;
+ bool m_IsClosed = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h
index 201d68b83..87caec019 100644
--- a/src/zenhorde/include/zenhorde/hordeclient.h
+++ b/src/zenhorde/include/zenhorde/hordeclient.h
@@ -45,14 +45,15 @@ struct MachineInfo
uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES)
bool IsWindows = false;
std::string LeaseId;
+ std::string Pool;
std::map<std::string, PortInfo> Ports;
/** Return the address to connect to, accounting for connection mode. */
- const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
+ [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
/** Return the port to connect to, accounting for connection mode and port mapping. */
- uint16_t GetConnectionPort() const
+ [[nodiscard]] uint16_t GetConnectionPort() const
{
if (Mode == ConnectionMode::Relay)
{
@@ -65,7 +66,20 @@ struct MachineInfo
return Port;
}
- bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
+ /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */
+ [[nodiscard]] std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const
+ {
+ if (Mode == ConnectionMode::Relay)
+ {
+ if (auto It = Ports.find("ZenPort"); It != Ports.end())
+ {
+ return {ConnectionAddress, It->second.Port};
+ }
+ }
+ return {Ip, DefaultPort};
+ }
+
+ [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
};
/** Result of cluster auto-resolution via the Horde API. */
@@ -83,31 +97,29 @@ struct ClusterInfo
class HordeClient
{
public:
- explicit HordeClient(const HordeConfig& Config);
+ explicit HordeClient(HordeConfig Config);
~HordeClient();
HordeClient(const HordeClient&) = delete;
HordeClient& operator=(const HordeClient&) = delete;
/** Initialize the underlying HTTP client. Must be called before other methods. */
- bool Initialize();
+ [[nodiscard]] bool Initialize();
/** Build the JSON request body for cluster resolution and machine requests.
* Encodes pool, condition, connection mode, encryption, and port requirements. */
- std::string BuildRequestBody() const;
+ [[nodiscard]] std::string BuildRequestBody() const;
/** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */
- bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
+ [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
/** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}.
* On success, populates OutMachine with connection details and credentials. */
- bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
+ [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
LoggerRef Log() { return m_Log; }
private:
- bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize);
-
HordeConfig m_Config;
std::unique_ptr<zen::HttpClient> m_Http;
LoggerRef m_Log;
diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h
index dd70f9832..3a4dfb386 100644
--- a/src/zenhorde/include/zenhorde/hordeconfig.h
+++ b/src/zenhorde/include/zenhorde/hordeconfig.h
@@ -4,6 +4,10 @@
#include <zenhorde/zenhorde.h>
+#include <zenhttp/httpclient.h>
+
+#include <functional>
+#include <optional>
#include <string>
namespace zen::horde {
@@ -33,20 +37,25 @@ struct HordeConfig
static constexpr const char* ClusterDefault = "default";
static constexpr const char* ClusterAuto = "_auto";
- bool Enabled = false; ///< Whether Horde provisioning is active
- std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com")
- std::string AuthToken; ///< Authentication token for the Horde API
- std::string Pool; ///< Pool name to request machines from
- std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
- std::string Condition; ///< Agent filter expression for machine selection
- std::string HostAddress; ///< Address that provisioned agents use to connect back to us
- std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload
- uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
-
- int MaxCores = 2048;
- bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
- ConnectionMode Mode = ConnectionMode::Direct;
- Encryption EncryptionMode = Encryption::None;
+ bool Enabled = false; ///< Whether Horde provisioning is active
+ std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com")
+ std::string AuthToken; ///< Authentication token for the Horde API (static fallback)
+
+ /// Optional token provider with automatic refresh (e.g. from OidcToken executable).
+ /// When set, takes priority over the static AuthToken string.
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+ std::string Pool; ///< Pool name to request machines from
+ std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
+ std::string Condition; ///< Agent filter expression for machine selection
+ std::string HostAddress; ///< Address that provisioned agents use to connect back to us
+ std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload
+ uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
+
+ int MaxCores = 2048;
+ int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min)
+ bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
+ ConnectionMode Mode = ConnectionMode::Direct;
+ Encryption EncryptionMode = Encryption::None;
/** Validate the configuration. Returns false if the configuration is invalid
* (e.g. Relay mode without AES encryption). */
diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h
index 4e2e63bbd..1dd12936b 100644
--- a/src/zenhorde/include/zenhorde/hordeprovisioner.h
+++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h
@@ -2,9 +2,12 @@
#pragma once
+#include <zenhorde/hordeclient.h>
#include <zenhorde/hordeconfig.h>
+#include <zencompute/provisionerstate.h>
#include <zencore/logbase.h>
+#include <zencore/thread.h>
#include <atomic>
#include <cstdint>
@@ -12,11 +15,18 @@
#include <memory>
#include <mutex>
#include <string>
+#include <thread>
+#include <unordered_set>
#include <vector>
+namespace asio {
+class io_context;
+}
+
namespace zen::horde {
class HordeClient;
+class AsyncHordeAgent;
/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */
struct ProvisioningStats
@@ -35,13 +45,12 @@ struct ProvisioningStats
* binary, and executing it remotely. Each provisioned machine runs zenserver
* in compute mode, which announces itself back to the orchestrator.
*
- * Spawns one thread per agent. Each thread handles the full lifecycle:
- * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption ->
- * channel setup -> binary upload -> remote execution -> poll until exit.
+ * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread
+ * pool rather than spawning a dedicated thread per agent.
*
* Thread safety: SetTargetCoreCount and GetStats may be called from any thread.
*/
-class HordeProvisioner
+class HordeProvisioner : public compute::IProvisionerStateProvider
{
public:
/** Construct a provisioner.
@@ -52,38 +61,48 @@ public:
HordeProvisioner(const HordeConfig& Config,
const std::filesystem::path& BinariesPath,
const std::filesystem::path& WorkingDir,
- std::string_view OrchestratorEndpoint);
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession = {},
+ bool CleanStart = false,
+ std::string_view TraceHost = {});
- /** Signals all agent threads to exit and joins them. */
- ~HordeProvisioner();
+ /** Signals all agents to exit and waits for completion. */
+ ~HordeProvisioner() override;
HordeProvisioner(const HordeProvisioner&) = delete;
HordeProvisioner& operator=(const HordeProvisioner&) = delete;
/** Set the target number of cores to provision.
- * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the
- * estimated core count is below the target. Also joins any finished
- * agent threads. */
- void SetTargetCoreCount(uint32_t Count);
+ * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the
+ * estimated core count is below the target. Also removes finished agents. */
+ void SetTargetCoreCount(uint32_t Count) override;
/** Return a snapshot of the current provisioning counters. */
ProvisioningStats GetStats() const;
- uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); }
- uint32_t GetAgentCount() const;
+ // IProvisionerStateProvider
+ std::string_view GetName() const override { return "horde"; }
+ uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); }
+ uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); }
+ uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); }
+ uint32_t GetAgentCount() const override;
+ uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); }
+ compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override;
private:
LoggerRef Log() { return m_Log; }
- struct AgentWrapper;
-
void RequestAgent();
- void ThreadAgent(AgentWrapper& Wrapper);
+ void ProvisionAgent();
+ bool InitializeHordeClient();
HordeConfig m_Config;
std::filesystem::path m_BinariesPath;
std::filesystem::path m_WorkingDir;
std::string m_OrchestratorEndpoint;
+ std::string m_CoordinatorSession;
+ bool m_CleanStart = false;
+ std::string m_TraceHost;
std::unique_ptr<HordeClient> m_HordeClient;
@@ -91,20 +110,43 @@ private:
std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs
bool m_BundlesCreated = false;
- mutable std::mutex m_AgentsLock;
- std::vector<std::unique_ptr<AgentWrapper>> m_Agents;
-
std::atomic<uint64_t> m_LastRequestFailTime{0};
std::atomic<uint32_t> m_TargetCoreCount{0};
std::atomic<uint32_t> m_EstimatedCoreCount{0};
std::atomic<uint32_t> m_ActiveCoreCount{0};
std::atomic<uint32_t> m_AgentsActive{0};
+ std::atomic<uint32_t> m_AgentsDraining{0};
std::atomic<uint32_t> m_AgentsRequesting{0};
std::atomic<bool> m_AskForAgents{true};
+ std::atomic<uint32_t> m_PendingWorkItems{0};
+ Event m_AllWorkDone;
LoggerRef m_Log;
+ // Async I/O
+ std::unique_ptr<asio::io_context> m_IoContext;
+ std::vector<std::thread> m_IoThreads;
+
+ struct AsyncAgentEntry
+ {
+ std::shared_ptr<AsyncHordeAgent> Agent;
+ std::string RemoteEndpoint;
+ std::string LeaseId;
+ uint16_t CoreCount = 0;
+ bool Draining = false;
+ };
+
+ mutable std::mutex m_AsyncAgentsLock;
+ std::vector<AsyncAgentEntry> m_AsyncAgents;
+ mutable std::unordered_set<std::string> m_RecentlyDrainedWorkerIds; ///< Worker IDs of agents that completed after draining
+
+ void OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent);
+ void DrainAsyncAgent(AsyncAgentEntry& Entry);
+
+ std::vector<std::string> BuildAgentArgs(const MachineInfo& Machine) const;
+
static constexpr uint32_t EstimatedCoresPerAgent = 32;
+ static constexpr int IoThreadCount = 3;
};
} // namespace zen::horde
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
index b9af9bd52..56b9c39c5 100644
--- a/src/zenhttp/clients/httpclientcurl.cpp
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -228,6 +228,13 @@ CurlHttpClient::Session::Perform()
curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+ char* EffectiveUrl = nullptr;
+ curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl);
+ if (EffectiveUrl)
+ {
+ Result.Url = EffectiveUrl;
+ }
+
return Result;
}
@@ -294,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId,
if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
{
- ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'",
SessionId,
+ Result.Url,
static_cast<int>(Result.ErrorCode),
Result.ErrorMessage);
}
@@ -443,6 +451,7 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result)
case CURLE_RECV_ERROR:
case CURLE_SEND_ERROR:
case CURLE_OPERATION_TIMEDOUT:
+ case CURLE_PARTIAL_FILE:
return true;
default:
return false;
@@ -489,10 +498,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult
{
if (Result.ErrorCode != CURLE_OK)
{
- ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}",
+ ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}",
SessionId,
static_cast<int>(MapCurlError(Result.ErrorCode)),
Result.ErrorMessage,
+ static_cast<int>(Result.ErrorCode),
Attempt,
m_ConnectionSettings.RetryCount + 1);
}
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
index bdeb46633..ea9193e65 100644
--- a/src/zenhttp/clients/httpclientcurl.h
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -73,6 +73,7 @@ private:
int64_t DownloadedBytes = 0;
CURLcode ErrorCode = CURLE_OK;
std::string ErrorMessage;
+ std::string Url;
};
struct Session
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
index c42841922..0432e50ef 100644
--- a/src/zenhttp/httpclientauth.cpp
+++ b/src/zenhttp/httpclientauth.cpp
@@ -94,7 +94,8 @@ namespace zen { namespace httpclientauth {
std::string_view CloudHost,
bool Unattended,
bool Quiet,
- bool Hidden)
+ bool Hidden,
+ bool IsHordeUrl)
{
Stopwatch Timer;
@@ -117,8 +118,9 @@ namespace zen { namespace httpclientauth {
}
});
- const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}",
+ const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}",
OidcExecutablePath,
+ IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl",
CloudHost,
AuthTokenPath,
Unattended ? "true"sv : "false"sv);
@@ -193,7 +195,7 @@ namespace zen { namespace httpclientauth {
}
else
{
- ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode);
+ ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode);
}
return HttpClientAccessToken{};
}
@@ -202,9 +204,10 @@ namespace zen { namespace httpclientauth {
std::string_view CloudHost,
bool Quiet,
bool Unattended,
- bool Hidden)
+ bool Hidden,
+ bool IsHordeUrl)
{
- HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl);
if (InitialToken.IsValid())
{
return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath),
@@ -212,12 +215,13 @@ namespace zen { namespace httpclientauth {
Token = InitialToken,
Quiet,
Unattended,
- Hidden]() mutable {
+ Hidden,
+ IsHordeUrl]() mutable {
if (!Token.NeedsRefresh())
{
return std::move(Token);
}
- return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl);
};
}
return {};
diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h
index ce646ebd7..9220a50b6 100644
--- a/src/zenhttp/include/zenhttp/httpclientauth.h
+++ b/src/zenhttp/include/zenhttp/httpclientauth.h
@@ -33,7 +33,8 @@ namespace httpclientauth {
std::string_view CloudHost,
bool Quiet,
bool Unattended,
- bool Hidden);
+ bool Hidden,
+ bool IsHordeUrl = false);
} // namespace httpclientauth
} // namespace zen
diff --git a/src/zennomad/include/zennomad/nomadclient.h b/src/zennomad/include/zennomad/nomadclient.h
index 0a3411ace..cebf217e1 100644
--- a/src/zennomad/include/zennomad/nomadclient.h
+++ b/src/zennomad/include/zennomad/nomadclient.h
@@ -52,7 +52,11 @@ public:
/** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint.
* The JSON structure varies based on the configured driver and distribution mode. */
- std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const;
+ std::string BuildJobJson(const std::string& JobId,
+ const std::string& OrchestratorEndpoint,
+ const std::string& CoordinatorSession = {},
+ bool CleanStart = false,
+ const std::string& TraceHost = {}) const;
/** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */
bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob);
diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h
index 750693b3f..a8368e3dc 100644
--- a/src/zennomad/include/zennomad/nomadprovisioner.h
+++ b/src/zennomad/include/zennomad/nomadprovisioner.h
@@ -47,7 +47,11 @@ public:
/** Construct a provisioner.
* @param Config Nomad connection and job configuration.
* @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */
- NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint);
+ NomadProvisioner(const NomadConfig& Config,
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession = {},
+ bool CleanStart = false,
+ std::string_view TraceHost = {});
/** Signals the management thread to exit and stops all tracked jobs. */
~NomadProvisioner();
@@ -83,6 +87,9 @@ private:
NomadConfig m_Config;
std::string m_OrchestratorEndpoint;
+ std::string m_CoordinatorSession;
+ bool m_CleanStart = false;
+ std::string m_TraceHost;
std::unique_ptr<NomadClient> m_Client;
diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp
index 9edcde125..4bb09a930 100644
--- a/src/zennomad/nomadclient.cpp
+++ b/src/zennomad/nomadclient.cpp
@@ -58,7 +58,11 @@ NomadClient::Initialize()
}
std::string
-NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const
+NomadClient::BuildJobJson(const std::string& JobId,
+ const std::string& OrchestratorEndpoint,
+ const std::string& CoordinatorSession,
+ bool CleanStart,
+ const std::string& TraceHost) const
{
ZEN_TRACE_CPU("NomadClient::BuildJobJson");
@@ -94,6 +98,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra
IdArg << "--instance-id=nomad-" << JobId;
Args.push_back(std::string(IdArg.ToView()));
}
+ if (!CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << CoordinatorSession;
+ Args.push_back(std::string(SessionArg.ToView()));
+ }
+ if (CleanStart)
+ {
+ Args.push_back("--clean");
+ }
+ if (!TraceHost.empty())
+ {
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << TraceHost;
+ Args.push_back(std::string(TraceArg.ToView()));
+ }
TaskConfig["args"] = Args;
}
else
@@ -115,6 +135,22 @@ NomadClient::BuildJobJson(const std::string& JobId, const std::string& Orchestra
IdArg << "--instance-id=nomad-" << JobId;
Args.push_back(std::string(IdArg.ToView()));
}
+ if (!CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << CoordinatorSession;
+ Args.push_back(std::string(SessionArg.ToView()));
+ }
+ if (CleanStart)
+ {
+ Args.push_back("--clean");
+ }
+ if (!TraceHost.empty())
+ {
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << TraceHost;
+ Args.push_back(std::string(TraceArg.ToView()));
+ }
TaskConfig["args"] = Args;
}
diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp
index 3fe9c0ac3..e07ce155e 100644
--- a/src/zennomad/nomadprovisioner.cpp
+++ b/src/zennomad/nomadprovisioner.cpp
@@ -14,9 +14,16 @@
namespace zen::nomad {
-NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint)
+NomadProvisioner::NomadProvisioner(const NomadConfig& Config,
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession,
+ bool CleanStart,
+ std::string_view TraceHost)
: m_Config(Config)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
+, m_CoordinatorSession(CoordinatorSession)
+, m_CleanStart(CleanStart)
+, m_TraceHost(TraceHost)
, m_ProcessId(static_cast<uint32_t>(zen::GetCurrentProcessId()))
, m_Log(zen::logging::Get("nomad.provisioner"))
{
@@ -154,7 +161,7 @@ NomadProvisioner::SubmitNewJobs()
ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load());
- const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint);
+ const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint, m_CoordinatorSession, m_CleanStart, m_TraceHost);
NomadJobInfo JobInfo;
JobInfo.Id = JobId;
diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp
index 1f8b96cc4..3a41cd7eb 100644
--- a/src/zenremotestore/builds/buildstorageoperations.cpp
+++ b/src/zenremotestore/builds/buildstorageoperations.cpp
@@ -1664,21 +1664,34 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState)
}
if (!BlockBuffer)
{
- BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
- if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache)
+ try
+ {
+ BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash);
+ }
+ catch (const std::exception&)
+ {
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
+ }
+ if (!m_AbortFlag)
+ {
+ if (!BlockBuffer)
+ {
+ throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash));
+ }
+
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
BlockDescription.BlockHash,
ZenContentType::kCompressedBinary,
CompositeBuffer(SharedBuffer(BlockBuffer)));
}
- }
- if (!BlockBuffer)
- {
- throw std::runtime_error(fmt::format("Block {} is missing", BlockDescription.BlockHash));
- }
- if (!m_AbortFlag)
- {
+
uint64_t BlockSize = BlockBuffer.GetSize();
m_DownloadStats.DownloadedBlockCount++;
m_DownloadStats.DownloadedBlockByteCount += BlockSize;
@@ -3293,31 +3306,45 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde
}
else
{
- BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash);
- if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache)
+ try
{
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- ChunkHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(SharedBuffer(BuildBlob)));
+ BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash);
}
- if (!BuildBlob)
+ catch (const std::exception&)
{
- throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash));
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
- if (!m_Options.PrimeCacheOnly)
+ if (!m_AbortFlag)
{
- if (!m_AbortFlag)
+ if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache)
{
- uint64_t BlobSize = BuildBlob.GetSize();
- m_DownloadStats.DownloadedChunkCount++;
- m_DownloadStats.DownloadedChunkByteCount += BlobSize;
- if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ ChunkHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(SharedBuffer(BuildBlob)));
+ }
+ if (!BuildBlob)
+ {
+ throw std::runtime_error(fmt::format("Chunk {} is missing", ChunkHash));
+ }
+ if (!m_Options.PrimeCacheOnly)
+ {
+ if (!m_AbortFlag)
{
- FilteredDownloadedBytesPerSecond.Stop();
- }
+ uint64_t BlobSize = BuildBlob.GetSize();
+ m_DownloadStats.DownloadedChunkCount++;
+ m_DownloadStats.DownloadedChunkByteCount += BlobSize;
+ if (m_DownloadStats.RequestsCompleteCount.fetch_add(1) + 1 == TotalRequestCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
- OnDownloaded(std::move(BuildBlob));
+ OnDownloaded(std::move(BuildBlob));
+ }
}
}
}
@@ -3519,64 +3546,77 @@ BuildsOperationUpdateFolder::DownloadPartialBlock(
auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount);
- BuildStorageBase::BuildBlobRanges RangeBuffers =
- m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges);
- if (m_AbortFlag)
+ BuildStorageBase::BuildBlobRanges RangeBuffers;
+
+ try
{
- break;
+ RangeBuffers = m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges);
}
- if (RangeBuffers.PayloadBuffer)
+ catch (const std::exception&)
{
- if (RangeBuffers.Ranges.empty())
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
{
- // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3
- // Upload to cache (if enabled) and use the whole payload for the remaining ranges
+ throw;
+ }
+ }
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ if (!m_AbortFlag)
+ {
+ if (RangeBuffers.PayloadBuffer)
+ {
+ if (RangeBuffers.Ranges.empty())
{
- m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
- BlockDescription.BlockHash,
- ZenContentType::kCompressedBinary,
- CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer}));
- if (m_AbortFlag)
+ // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3
+ // Upload to cache (if enabled) and use the whole payload for the remaining ranges
+
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
{
- break;
+ m_Storage.CacheStorage->PutBuildBlob(m_BuildId,
+ BlockDescription.BlockHash,
+ ZenContentType::kCompressedBinary,
+ CompositeBuffer(std::vector<IoBuffer>{RangeBuffers.PayloadBuffer}));
+ if (m_AbortFlag)
+ {
+ break;
+ }
}
- }
- SubRangeCount = Ranges.size() - SubRangeCountComplete;
- ProcessDownload(BlockDescription,
- std::move(RangeBuffers.PayloadBuffer),
- SubRangeStartIndex,
- RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
- TotalRequestCount,
- FilteredDownloadedBytesPerSecond,
- OnDownloaded);
+ SubRangeCount = Ranges.size() - SubRangeCountComplete;
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangesSpan.subspan(SubRangeCountComplete, SubRangeCount),
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
+ OnDownloaded);
+ }
+ else
+ {
+ if (RangeBuffers.Ranges.size() != SubRanges.size())
+ {
+ throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges",
+ SubRanges.size(),
+ BlockDescription.BlockHash,
+ RangeBuffers.Ranges.size()));
+ }
+ ProcessDownload(BlockDescription,
+ std::move(RangeBuffers.PayloadBuffer),
+ SubRangeStartIndex,
+ RangeBuffers.Ranges,
+ TotalRequestCount,
+ FilteredDownloadedBytesPerSecond,
+ OnDownloaded);
+ }
}
else
{
- if (RangeBuffers.Ranges.size() != SubRanges.size())
- {
- throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges",
- SubRanges.size(),
- BlockDescription.BlockHash,
- RangeBuffers.Ranges.size()));
- }
- ProcessDownload(BlockDescription,
- std::move(RangeBuffers.PayloadBuffer),
- SubRangeStartIndex,
- RangeBuffers.Ranges,
- TotalRequestCount,
- FilteredDownloadedBytesPerSecond,
- OnDownloaded);
+ throw std::runtime_error(
+ fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount));
}
- }
- else
- {
- throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount));
- }
- SubRangeCountComplete += SubRangeCount;
+ SubRangeCountComplete += SubRangeCount;
+ }
}
}
@@ -5150,48 +5190,80 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent&
Payload.GetCompressed());
}
- m_Storage.BuildStorage->PutBuildBlob(m_BuildId,
- BlockHash,
- ZenContentType::kCompressedBinary,
- std::move(Payload).GetCompressed());
- UploadStats.BlocksBytes += CompressedBlockSize;
-
- if (m_Options.IsVerbose)
+ try
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} ({}) containing {} chunks",
- BlockHash,
- NiceBytes(CompressedBlockSize),
- OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
+ m_Storage.BuildStorage->PutBuildBlob(m_BuildId,
+ BlockHash,
+ ZenContentType::kCompressedBinary,
+ std::move(Payload).GetCompressed());
}
-
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ catch (const std::exception&)
{
- m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
- bool MetadataSucceeded =
- m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
- if (MetadataSucceeded)
+ if (!m_AbortFlag)
{
+ UploadStats.BlocksBytes += CompressedBlockSize;
+
if (m_Options.IsVerbose)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} metadata ({})",
- BlockHash,
- NiceBytes(BlockMetaData.GetSize()));
+ ZEN_OPERATION_LOG_INFO(
+ m_LogOutput,
+ "Uploaded block {} ({}) containing {} chunks",
+ BlockHash,
+ NiceBytes(CompressedBlockSize),
+ OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
}
- OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
- UploadStats.BlocksBytes += BlockMetaData.GetSize();
- }
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
+ }
- UploadStats.BlockCount++;
- if (UploadStats.BlockCount == NewBlockCount)
- {
- FilteredUploadedBytesPerSecond.Stop();
+ bool MetadataSucceeded = false;
+ try
+ {
+ MetadataSucceeded =
+ m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
+ }
+ catch (const std::exception&)
+ {
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
+
+ if (!m_AbortFlag)
+ {
+ if (MetadataSucceeded)
+ {
+ if (m_Options.IsVerbose)
+ {
+ ZEN_OPERATION_LOG_INFO(m_LogOutput,
+ "Uploaded block {} metadata ({})",
+ BlockHash,
+ NiceBytes(BlockMetaData.GetSize()));
+ }
+
+ OutBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
+ UploadStats.BlocksBytes += BlockMetaData.GetSize();
+ }
+
+ UploadStats.BlockCount++;
+ if (UploadStats.BlockCount == NewBlockCount)
+ {
+ FilteredUploadedBytesPerSecond.Stop();
+ }
+ }
}
}
}
@@ -6215,44 +6287,76 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
{
m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
}
- m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
- if (m_Options.IsVerbose)
+
+ try
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} ({}) containing {} chunks",
- BlockHash,
- NiceBytes(PayloadSize),
- NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
+ m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload);
}
- UploadedBlockSize += PayloadSize;
- TempUploadStats.BlocksBytes += PayloadSize;
-
- if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ catch (const std::exception&)
{
- m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
- std::vector<IoHash>({BlockHash}),
- std::vector<CbObject>({BlockMetaData}));
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
- bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
- if (MetadataSucceeded)
+
+ if (!m_AbortFlag)
{
if (m_Options.IsVerbose)
{
ZEN_OPERATION_LOG_INFO(m_LogOutput,
- "Uploaded block {} metadata ({})",
+ "Uploaded block {} ({}) containing {} chunks",
BlockHash,
- NiceBytes(BlockMetaData.GetSize()));
+ NiceBytes(PayloadSize),
+ NewBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size());
}
+ UploadedBlockSize += PayloadSize;
+ TempUploadStats.BlocksBytes += PayloadSize;
- NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
- TempUploadStats.BlocksBytes += BlockMetaData.GetSize();
- }
+ if (m_Storage.CacheStorage && m_Options.PopulateCache)
+ {
+ m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId,
+ std::vector<IoHash>({BlockHash}),
+ std::vector<CbObject>({BlockMetaData}));
+ }
+
+ bool MetadataSucceeded = false;
+ try
+ {
+ MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData);
+ }
+ catch (const std::exception&)
+ {
+ // Suppress HTTP errors raised while an abort is in progress
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
+ if (!m_AbortFlag)
+ {
+ if (MetadataSucceeded)
+ {
+ if (m_Options.IsVerbose)
+ {
+ ZEN_OPERATION_LOG_INFO(m_LogOutput,
+ "Uploaded block {} metadata ({})",
+ BlockHash,
+ NiceBytes(BlockMetaData.GetSize()));
+ }
+
+ NewBlocks.MetaDataHasBeenUploaded[BlockIndex] = true;
+ TempUploadStats.BlocksBytes += BlockMetaData.GetSize();
+ }
- TempUploadStats.BlockCount++;
+ TempUploadStats.BlockCount++;
- if (UploadedBlockCount.fetch_add(1) + 1 == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
- {
- FilteredUploadedBytesPerSecond.Stop();
+ if (UploadedBlockCount.fetch_add(1) + 1 == UploadBlockCount && UploadedChunkCount == UploadChunkCount)
+ {
+ FilteredUploadedBytesPerSecond.Stop();
+ }
+ }
}
}
});
@@ -6302,72 +6406,100 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co
{
ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart");
TempUploadStats.MultipartAttachmentCount++;
- std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob(
- m_BuildId,
- RawHash,
- ZenContentType::kCompressedBinary,
- PayloadSize,
- [Payload = std::move(Payload), &FilteredUploadedBytesPerSecond](uint64_t Offset,
- uint64_t Size) mutable -> IoBuffer {
- FilteredUploadedBytesPerSecond.Start();
-
- IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer();
- PartPayload.SetContentType(ZenContentType::kBinary);
- return PartPayload;
- },
- [RawSize,
- &TempUploadStats,
- &UploadedCompressedChunkSize,
- &UploadChunkPool,
- &UploadedBlockCount,
- UploadBlockCount,
- &UploadedChunkCount,
- UploadChunkCount,
- &FilteredUploadedBytesPerSecond,
- &UploadedRawChunkSize](uint64_t SentBytes, bool IsComplete) {
- TempUploadStats.ChunksBytes += SentBytes;
- UploadedCompressedChunkSize += SentBytes;
- if (IsComplete)
- {
- TempUploadStats.ChunkCount++;
- if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount &&
- UploadedBlockCount == UploadBlockCount)
+ try
+ {
+ std::vector<std::function<void()>> MultipartWork = m_Storage.BuildStorage->PutLargeBuildBlob(
+ m_BuildId,
+ RawHash,
+ ZenContentType::kCompressedBinary,
+ PayloadSize,
+ [Payload = std::move(Payload), &FilteredUploadedBytesPerSecond](uint64_t Offset,
+ uint64_t Size) mutable -> IoBuffer {
+ FilteredUploadedBytesPerSecond.Start();
+
+ IoBuffer PartPayload = Payload.Mid(Offset, Size).Flatten().AsIoBuffer();
+ PartPayload.SetContentType(ZenContentType::kBinary);
+ return PartPayload;
+ },
+ [RawSize,
+ &TempUploadStats,
+ &UploadedCompressedChunkSize,
+ &UploadChunkPool,
+ &UploadedBlockCount,
+ UploadBlockCount,
+ &UploadedChunkCount,
+ UploadChunkCount,
+ &FilteredUploadedBytesPerSecond,
+ &UploadedRawChunkSize](uint64_t SentBytes, bool IsComplete) {
+ TempUploadStats.ChunksBytes += SentBytes;
+ UploadedCompressedChunkSize += SentBytes;
+ if (IsComplete)
{
- FilteredUploadedBytesPerSecond.Stop();
+ TempUploadStats.ChunkCount++;
+ if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount &&
+ UploadedBlockCount == UploadBlockCount)
+ {
+ FilteredUploadedBytesPerSecond.Stop();
+ }
+ UploadedRawChunkSize += RawSize;
}
- UploadedRawChunkSize += RawSize;
- }
- });
- for (auto& WorkPart : MultipartWork)
- {
- Work.ScheduleWork(UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) {
- ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work");
- if (!AbortFlag)
- {
- Work();
- }
- });
+ });
+ for (auto& WorkPart : MultipartWork)
+ {
+ Work.ScheduleWork(UploadChunkPool, [Work = std::move(WorkPart)](std::atomic<bool>& AbortFlag) {
+ ZEN_TRACE_CPU("AsyncUploadLooseChunk_Multipart_Work");
+ if (!AbortFlag)
+ {
+ Work();
+ }
+ });
+ }
+ if (m_Options.IsVerbose)
+ {
+ ZEN_OPERATION_LOG_INFO(m_LogOutput,
+ "Uploaded multipart chunk {} ({})",
+ RawHash,
+ NiceBytes(PayloadSize));
+ }
}
- if (m_Options.IsVerbose)
+ catch (const std::exception&)
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded multipart chunk {} ({})", RawHash, NiceBytes(PayloadSize));
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
}
else
{
ZEN_TRACE_CPU("AsyncUploadLooseChunk_Singlepart");
- m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
- if (m_Options.IsVerbose)
+ try
{
- ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize));
+ m_Storage.BuildStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload);
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
}
- TempUploadStats.ChunksBytes += Payload.GetSize();
- TempUploadStats.ChunkCount++;
- UploadedCompressedChunkSize += Payload.GetSize();
- UploadedRawChunkSize += RawSize;
- if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && UploadedBlockCount == UploadBlockCount)
+ if (!m_AbortFlag)
{
- FilteredUploadedBytesPerSecond.Stop();
+ if (m_Options.IsVerbose)
+ {
+ ZEN_OPERATION_LOG_INFO(m_LogOutput, "Uploaded chunk {} ({})", RawHash, NiceBytes(PayloadSize));
+ }
+ TempUploadStats.ChunksBytes += Payload.GetSize();
+ TempUploadStats.ChunkCount++;
+ UploadedCompressedChunkSize += Payload.GetSize();
+ UploadedRawChunkSize += RawSize;
+ if (UploadedChunkCount.fetch_add(1) + 1 == UploadChunkCount && UploadedBlockCount == UploadBlockCount)
+ {
+ FilteredUploadedBytesPerSecond.Stop();
+ }
}
}
}
@@ -7147,10 +7279,23 @@ BuildsOperationPrimeCache::Execute()
}
else
{
- IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash);
- m_DownloadStats.DownloadedBlockCount++;
- m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
- m_DownloadStats.RequestsCompleteCount++;
+ IoBuffer Payload;
+ try
+ {
+ Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash);
+
+ m_DownloadStats.DownloadedBlockCount++;
+ m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize();
+ m_DownloadStats.RequestsCompleteCount++;
+ }
+ catch (const std::exception&)
+ {
+ // Silence http errors due to abort
+ if (!m_AbortFlag)
+ {
+ throw;
+ }
+ }
if (!m_AbortFlag)
{
@@ -7161,10 +7306,10 @@ BuildsOperationPrimeCache::Execute()
ZenContentType::kCompressedBinary,
CompositeBuffer(SharedBuffer(std::move(Payload))));
}
- }
- if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount)
- {
- FilteredDownloadedBytesPerSecond.Stop();
+ if (CompletedDownloadCount.fetch_add(1) + 1 == BlobCount)
+ {
+ FilteredDownloadedBytesPerSecond.Stop();
+ }
}
}
}
diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp
index 7296098e0..8cd8b4cfe 100644
--- a/src/zenserver/compute/computeserver.cpp
+++ b/src/zenserver/compute/computeserver.cpp
@@ -22,6 +22,8 @@
# if ZEN_WITH_HORDE
# include <zenhorde/hordeconfig.h>
# include <zenhorde/hordeprovisioner.h>
+# include <zenhttp/httpclientauth.h>
+# include <zenutil/authutils.h>
# endif
# if ZEN_WITH_NOMAD
# include <zennomad/nomadconfig.h>
@@ -67,6 +69,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("compute",
"",
+ "coordinator-session",
+ "Session ID of the orchestrator (for stale-instance rejection)",
+ cxxopts::value<std::string>(m_ServerOptions.CoordinatorSession)->default_value(""),
+ "");
+
+ Options.add_option("compute",
+ "",
+ "announce-url",
+ "Override URL announced to the coordinator (e.g. relay-visible endpoint)",
+ cxxopts::value<std::string>(m_ServerOptions.AnnounceUrl)->default_value(""),
+ "");
+
+ Options.add_option("compute",
+ "",
"idms",
"Enable IDMS cloud detection; optionally specify a custom probe endpoint",
cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"),
@@ -79,6 +95,20 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
cxxopts::value<bool>(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"),
"");
+ Options.add_option("compute",
+ "",
+ "provision-clean",
+ "Pass --clean to provisioned worker instances so they wipe state on startup",
+ cxxopts::value<bool>(m_ServerOptions.ProvisionClean)->default_value("false"),
+ "");
+
+ Options.add_option("compute",
+ "",
+ "provision-tracehost",
+ "Pass --tracehost to provisioned worker instances for remote trace collection",
+ cxxopts::value<std::string>(m_ServerOptions.ProvisionTraceHost)->default_value(""),
+ "");
+
# if ZEN_WITH_HORDE
// Horde provisioning options
Options.add_option("horde",
@@ -139,6 +169,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("horde",
"",
+ "horde-drain-grace-period",
+ "Grace period in seconds for draining agents before force-kill",
+ cxxopts::value<int>(m_ServerOptions.HordeConfig.DrainGracePeriodSeconds)->default_value("300"),
+ "");
+
+ Options.add_option("horde",
+ "",
"horde-host",
"Host address for Horde agents to connect back to",
cxxopts::value<std::string>(m_ServerOptions.HordeConfig.HostAddress)->default_value(""),
@@ -164,6 +201,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
"Port number for Zen service communication",
cxxopts::value<uint16_t>(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"),
"");
+
+ Options.add_option("horde",
+ "",
+ "horde-oidctoken-exe-path",
+ "Path to OidcToken executable for automatic Horde authentication",
+ cxxopts::value<std::string>(m_HordeOidcTokenExePath)->default_value(""),
+ "");
# endif
# if ZEN_WITH_NOMAD
@@ -313,6 +357,30 @@ ZenComputeServerConfigurator::ValidateOptions()
# if ZEN_WITH_HORDE
horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr);
horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr);
+
+ // Set up OidcToken-based authentication if no static token was provided
+ if (m_ServerOptions.HordeConfig.AuthToken.empty() && !m_ServerOptions.HordeConfig.ServerUrl.empty())
+ {
+ std::filesystem::path OidcExePath = FindOidcTokenExePath(m_HordeOidcTokenExePath);
+ if (!OidcExePath.empty())
+ {
+ ZEN_INFO("using OidcToken executable for Horde authentication: {}", OidcExePath);
+ auto Provider = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath,
+ m_ServerOptions.HordeConfig.ServerUrl,
+ /*Quiet=*/true,
+ /*Unattended=*/false,
+ /*Hidden=*/true,
+ /*IsHordeUrl=*/true);
+ if (Provider)
+ {
+ m_ServerOptions.HordeConfig.AccessTokenProvider = std::move(*Provider);
+ }
+ else
+ {
+ ZEN_WARN("OidcToken authentication failed; Horde requests will be unauthenticated");
+ }
+ }
+ }
# endif
# if ZEN_WITH_NOMAD
@@ -347,6 +415,8 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ
}
m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint;
+ m_CoordinatorSession = ServerConfig.CoordinatorSession;
+ m_AnnounceUrl = ServerConfig.AnnounceUrl;
m_InstanceId = ServerConfig.InstanceId;
m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket;
@@ -379,7 +449,14 @@ ZenComputeServer::Cleanup()
m_AnnounceTimer.cancel();
# if ZEN_WITH_HORDE
- // Shut down Horde provisioner first — this signals all agent threads
+ // Disconnect the provisioner state provider before destroying the
+ // provisioner so the orchestrator HTTP layer cannot call into it.
+ if (m_OrchestratorService)
+ {
+ m_OrchestratorService->SetProvisionerStateProvider(nullptr);
+ }
+
+ // Shut down Horde provisioner — this signals all agent threads
// to exit and joins them before we tear down HTTP services.
m_HordeProvisioner.reset();
# endif
@@ -482,6 +559,7 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
m_StatsService,
ServerConfig.DataDir / "functions",
ServerConfig.MaxConcurrentActions);
+ m_ComputeService->SetShutdownCallback([this] { RequestExit(0); });
m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot, m_StatsService, m_StatusService);
@@ -506,7 +584,11 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
OrchestratorEndpoint << '/';
}
- m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg, OrchestratorEndpoint);
+ m_NomadProvisioner = std::make_unique<nomad::NomadProvisioner>(NomadCfg,
+ OrchestratorEndpoint,
+ m_OrchestratorService->GetSessionId().ToString(),
+ ServerConfig.ProvisionClean,
+ ServerConfig.ProvisionTraceHost);
}
}
# endif
@@ -537,7 +619,14 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig)
: std::filesystem::path(HordeConfig.BinariesPath);
std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde";
- m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint);
+ m_HordeProvisioner = std::make_unique<horde::HordeProvisioner>(HordeConfig,
+ BinariesPath,
+ WorkingDir,
+ OrchestratorEndpoint,
+ m_OrchestratorService->GetSessionId().ToString(),
+ ServerConfig.ProvisionClean,
+ ServerConfig.ProvisionTraceHost);
+ m_OrchestratorService->SetProvisionerStateProvider(m_HordeProvisioner.get());
}
}
# endif
@@ -565,6 +654,10 @@ ZenComputeServer::GetInstanceId() const
std::string
ZenComputeServer::GetAnnounceUrl() const
{
+ if (!m_AnnounceUrl.empty())
+ {
+ return m_AnnounceUrl;
+ }
return m_Http->GetServiceUri(nullptr);
}
@@ -635,6 +728,11 @@ ZenComputeServer::BuildAnnounceBody()
<< "nomad";
}
+ if (!m_CoordinatorSession.empty())
+ {
+ AnnounceBody << "coordinator_session" << m_CoordinatorSession;
+ }
+
ResolveCloudMetadata();
if (m_CloudMetadata)
{
@@ -781,8 +879,10 @@ ZenComputeServer::ProvisionerMaintenanceTick()
# if ZEN_WITH_HORDE
if (m_HordeProvisioner)
{
- m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX);
+ // Re-apply current target to spawn agent threads for any that have
+ // exited since the last tick, without overwriting a user-set target.
auto Stats = m_HordeProvisioner->GetStats();
+ m_HordeProvisioner->SetTargetCoreCount(Stats.TargetCoreCount);
ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}",
Stats.TargetCoreCount,
Stats.EstimatedCoreCount,
diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h
index 38f93bc36..63db7e9b3 100644
--- a/src/zenserver/compute/computeserver.h
+++ b/src/zenserver/compute/computeserver.h
@@ -49,9 +49,13 @@ struct ZenComputeServerConfig : public ZenServerConfig
std::string UpstreamNotificationEndpoint;
std::string InstanceId; // For use in notifications
std::string CoordinatorEndpoint;
+ std::string CoordinatorSession; ///< Session ID for stale-instance rejection
+ std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint)
std::string IdmsEndpoint;
int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2)
bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link
+ bool ProvisionClean = false; // Pass --clean to provisioned workers
+ std::string ProvisionTraceHost; // Pass --tracehost to provisioned workers
# if ZEN_WITH_HORDE
horde::HordeConfig HordeConfig;
@@ -84,6 +88,7 @@ private:
# if ZEN_WITH_HORDE
std::string m_HordeModeStr = "direct";
std::string m_HordeEncryptionStr = "none";
+ std::string m_HordeOidcTokenExePath;
# endif
# if ZEN_WITH_NOMAD
@@ -147,6 +152,8 @@ private:
# endif
SystemMetricsTracker m_MetricsTracker;
std::string m_CoordinatorEndpoint;
+ std::string m_CoordinatorSession;
+ std::string m_AnnounceUrl;
std::string m_InstanceId;
asio::steady_timer m_AnnounceTimer{m_IoContext};
diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp
index daad154bc..6449159fd 100644
--- a/src/zenserver/config/config.cpp
+++ b/src/zenserver/config/config.cpp
@@ -12,6 +12,7 @@
#include <zencore/compactbinaryutil.h>
#include <zencore/compactbinaryvalidation.h>
#include <zencore/except.h>
+#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
@@ -478,15 +479,27 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig
throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir));
}
- ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir);
- ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir);
- ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir);
- ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile);
- ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir);
+ SystemRootDir = ExpandEnvironmentVariables(SystemRootDir);
+ ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir);
+
+ DataDir = ExpandEnvironmentVariables(DataDir);
+ ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir);
+
+ ContentDir = ExpandEnvironmentVariables(ContentDir);
+ ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir);
+
+ ConfigFile = ExpandEnvironmentVariables(ConfigFile);
+ ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile);
+
+ BaseSnapshotDir = ExpandEnvironmentVariables(BaseSnapshotDir);
+ ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir);
+
+ SecurityConfigPath = ExpandEnvironmentVariables(SecurityConfigPath);
ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath);
if (!UnixSocketPath.empty())
{
+ UnixSocketPath = ExpandEnvironmentVariables(UnixSocketPath);
ServerOptions.HttpConfig.UnixSocketPath = MakeSafeAbsolutePath(UnixSocketPath);
}
diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html
deleted file mode 100644
index c07bbb692..000000000
--- a/src/zenserver/frontend/html/compute/compute.html
+++ /dev/null
@@ -1,925 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
- <meta charset="UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <title>Zen Compute Dashboard</title>
- <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script>
- <link rel="stylesheet" type="text/css" href="../zen.css" />
- <script src="../util/sanitize.js"></script>
- <script src="../theme.js"></script>
- <script src="../banner.js" defer></script>
- <script src="../nav.js" defer></script>
- <style>
- .grid {
- grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
- }
-
- .chart-container {
- position: relative;
- height: 300px;
- margin-top: 20px;
- }
-
- .stats-row {
- display: flex;
- justify-content: space-between;
- margin-bottom: 12px;
- padding: 8px 0;
- border-bottom: 1px solid var(--theme_border_subtle);
- }
-
- .stats-row:last-child {
- border-bottom: none;
- margin-bottom: 0;
- }
-
- .stats-label {
- color: var(--theme_g1);
- font-size: 13px;
- }
-
- .stats-value {
- color: var(--theme_bright);
- font-weight: 600;
- font-size: 13px;
- }
-
- .rate-stats {
- display: grid;
- grid-template-columns: repeat(3, 1fr);
- gap: 16px;
- margin-top: 16px;
- }
-
- .rate-item {
- text-align: center;
- }
-
- .rate-value {
- font-size: 20px;
- font-weight: 600;
- color: var(--theme_p0);
- }
-
- .rate-label {
- font-size: 11px;
- color: var(--theme_g1);
- margin-top: 4px;
- text-transform: uppercase;
- }
-
- .worker-row {
- cursor: pointer;
- transition: background 0.15s;
- }
-
- .worker-row:hover {
- background: var(--theme_p4);
- }
-
- .worker-row.selected {
- background: var(--theme_p3);
- }
-
- .worker-detail {
- margin-top: 20px;
- border-top: 1px solid var(--theme_g2);
- padding-top: 16px;
- }
-
- .worker-detail-title {
- font-size: 15px;
- font-weight: 600;
- color: var(--theme_bright);
- margin-bottom: 12px;
- }
-
- .detail-section {
- margin-bottom: 16px;
- }
-
- .detail-section-label {
- font-size: 11px;
- font-weight: 600;
- color: var(--theme_g1);
- text-transform: uppercase;
- letter-spacing: 0.5px;
- margin-bottom: 6px;
- }
-
- .detail-table {
- width: 100%;
- border-collapse: collapse;
- font-size: 12px;
- }
-
- .detail-table td {
- padding: 4px 8px;
- color: var(--theme_g0);
- border-bottom: 1px solid var(--theme_border_subtle);
- vertical-align: top;
- }
-
- .detail-table td:first-child {
- color: var(--theme_g1);
- width: 40%;
- font-family: monospace;
- }
-
- .detail-table tr:last-child td {
- border-bottom: none;
- }
-
- .detail-mono {
- font-family: monospace;
- font-size: 11px;
- color: var(--theme_g1);
- }
-
- .detail-tag {
- display: inline-block;
- padding: 2px 8px;
- border-radius: 4px;
- background: var(--theme_border_subtle);
- color: var(--theme_g0);
- font-size: 11px;
- margin: 2px 4px 2px 0;
- }
- </style>
-</head>
-<body>
- <div class="container" style="max-width: 1400px; margin: 0 auto;">
- <zen-banner cluster-status="nominal" load="0" tagline="Node Overview" logo-src="../favicon.ico"></zen-banner>
- <zen-nav>
- <a href="/dashboard/">Home</a>
- <a href="compute.html">Node</a>
- <a href="orchestrator.html">Orchestrator</a>
- </zen-nav>
- <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
-
- <div id="error-container"></div>
-
- <!-- Action Queue Stats -->
- <div class="section-title">Action Queue</div>
- <div class="grid">
- <div class="card">
- <div class="card-title">Pending Actions</div>
- <div class="metric-value" id="actions-pending">-</div>
- <div class="metric-label">Waiting to be scheduled</div>
- </div>
- <div class="card">
- <div class="card-title">Running Actions</div>
- <div class="metric-value" id="actions-running">-</div>
- <div class="metric-label">Currently executing</div>
- </div>
- <div class="card">
- <div class="card-title">Completed Actions</div>
- <div class="metric-value" id="actions-complete">-</div>
- <div class="metric-label">Results available</div>
- </div>
- </div>
-
- <!-- Action Queue Chart -->
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Action Queue History</div>
- <div class="chart-container">
- <canvas id="queue-chart"></canvas>
- </div>
- </div>
-
- <!-- Performance Metrics -->
- <div class="section-title">Performance Metrics</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Completion Rate</div>
- <div class="rate-stats">
- <div class="rate-item">
- <div class="rate-value" id="rate-1">-</div>
- <div class="rate-label">1 min rate</div>
- </div>
- <div class="rate-item">
- <div class="rate-value" id="rate-5">-</div>
- <div class="rate-label">5 min rate</div>
- </div>
- <div class="rate-item">
- <div class="rate-value" id="rate-15">-</div>
- <div class="rate-label">15 min rate</div>
- </div>
- </div>
- <div style="margin-top: 20px;">
- <div class="stats-row">
- <span class="stats-label">Total Retired</span>
- <span class="stats-value" id="retired-count">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Mean Rate</span>
- <span class="stats-value" id="rate-mean">-</span>
- </div>
- </div>
- </div>
-
- <!-- Workers -->
- <div class="section-title">Workers</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Worker Status</div>
- <div class="stats-row">
- <span class="stats-label">Registered Workers</span>
- <span class="stats-value" id="worker-count">-</span>
- </div>
- <div id="worker-table-container" style="margin-top: 16px; display: none;">
- <table id="worker-table">
- <thead>
- <tr>
- <th>Name</th>
- <th>Platform</th>
- <th style="text-align: right;">Cores</th>
- <th style="text-align: right;">Timeout</th>
- <th style="text-align: right;">Functions</th>
- <th>Worker ID</th>
- </tr>
- </thead>
- <tbody id="worker-table-body"></tbody>
- </table>
- <div id="worker-detail" class="worker-detail" style="display: none;"></div>
- </div>
- </div>
-
- <!-- Queues -->
- <div class="section-title">Queues</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Queue Status</div>
- <div id="queue-list-empty" class="empty-state" style="text-align: left;">No queues.</div>
- <div id="queue-list-container" style="display: none;">
- <table id="queue-list-table">
- <thead>
- <tr>
- <th style="text-align: right; width: 60px;">ID</th>
- <th style="text-align: center; width: 80px;">Status</th>
- <th style="text-align: right;">Active</th>
- <th style="text-align: right;">Completed</th>
- <th style="text-align: right;">Failed</th>
- <th style="text-align: right;">Abandoned</th>
- <th style="text-align: right;">Cancelled</th>
- <th>Token</th>
- </tr>
- </thead>
- <tbody id="queue-list-body"></tbody>
- </table>
- </div>
- </div>
-
- <!-- Action History -->
- <div class="section-title">Recent Actions</div>
- <div class="card" style="margin-bottom: 30px;">
- <div class="card-title">Action History</div>
- <div id="action-history-empty" class="empty-state" style="text-align: left;">No actions recorded yet.</div>
- <div id="action-history-container" style="display: none;">
- <table id="action-history-table">
- <thead>
- <tr>
- <th style="text-align: right; width: 60px;">LSN</th>
- <th style="text-align: right; width: 60px;">Queue</th>
- <th style="text-align: center; width: 70px;">Status</th>
- <th>Function</th>
- <th style="text-align: right; width: 80px;">Started</th>
- <th style="text-align: right; width: 80px;">Finished</th>
- <th style="text-align: right; width: 80px;">Duration</th>
- <th>Worker ID</th>
- <th>Action ID</th>
- </tr>
- </thead>
- <tbody id="action-history-body"></tbody>
- </table>
- </div>
- </div>
-
- <!-- System Resources -->
- <div class="section-title">System Resources</div>
- <div class="grid">
- <div class="card">
- <div class="card-title">CPU Usage</div>
- <div class="metric-value" id="cpu-usage">-</div>
- <div class="metric-label">Percent</div>
- <div class="progress-bar">
- <div class="progress-fill" id="cpu-progress" style="width: 0%"></div>
- </div>
- <div style="position: relative; height: 60px; margin-top: 12px;">
- <canvas id="cpu-chart"></canvas>
- </div>
- <div style="margin-top: 12px;">
- <div class="stats-row">
- <span class="stats-label">Packages</span>
- <span class="stats-value" id="cpu-packages">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Physical Cores</span>
- <span class="stats-value" id="cpu-cores">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Logical Processors</span>
- <span class="stats-value" id="cpu-lp">-</span>
- </div>
- </div>
- </div>
- <div class="card">
- <div class="card-title">Memory</div>
- <div class="stats-row">
- <span class="stats-label">Used</span>
- <span class="stats-value" id="memory-used">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Total</span>
- <span class="stats-value" id="memory-total">-</span>
- </div>
- <div class="progress-bar">
- <div class="progress-fill" id="memory-progress" style="width: 0%"></div>
- </div>
- </div>
- <div class="card">
- <div class="card-title">Disk</div>
- <div class="stats-row">
- <span class="stats-label">Used</span>
- <span class="stats-value" id="disk-used">-</span>
- </div>
- <div class="stats-row">
- <span class="stats-label">Total</span>
- <span class="stats-value" id="disk-total">-</span>
- </div>
- <div class="progress-bar">
- <div class="progress-fill" id="disk-progress" style="width: 0%"></div>
- </div>
- </div>
- </div>
- </div>
-
- <script>
- // Configuration
- const BASE_URL = window.location.origin;
- const REFRESH_INTERVAL = 2000; // 2 seconds
- const MAX_HISTORY_POINTS = 60; // Show last 2 minutes
-
- // Data storage
- const history = {
- timestamps: [],
- pending: [],
- running: [],
- completed: [],
- cpu: []
- };
-
- // CPU sparkline chart
- const cpuCtx = document.getElementById('cpu-chart').getContext('2d');
- const cpuChart = new Chart(cpuCtx, {
- type: 'line',
- data: {
- labels: [],
- datasets: [{
- data: [],
- borderColor: '#58a6ff',
- backgroundColor: 'rgba(88, 166, 255, 0.15)',
- borderWidth: 1.5,
- tension: 0.4,
- fill: true,
- pointRadius: 0
- }]
- },
- options: {
- responsive: true,
- maintainAspectRatio: false,
- animation: false,
- plugins: { legend: { display: false }, tooltip: { enabled: false } },
- scales: {
- x: { display: false },
- y: { display: false, min: 0, max: 100 }
- }
- }
- });
-
- // Queue chart setup
- const ctx = document.getElementById('queue-chart').getContext('2d');
- const chart = new Chart(ctx, {
- type: 'line',
- data: {
- labels: [],
- datasets: [
- {
- label: 'Pending',
- data: [],
- borderColor: '#f0883e',
- backgroundColor: 'rgba(240, 136, 62, 0.1)',
- tension: 0.4,
- fill: true
- },
- {
- label: 'Running',
- data: [],
- borderColor: '#58a6ff',
- backgroundColor: 'rgba(88, 166, 255, 0.1)',
- tension: 0.4,
- fill: true
- },
- {
- label: 'Completed',
- data: [],
- borderColor: '#3fb950',
- backgroundColor: 'rgba(63, 185, 80, 0.1)',
- tension: 0.4,
- fill: true
- }
- ]
- },
- options: {
- responsive: true,
- maintainAspectRatio: false,
- plugins: {
- legend: {
- display: true,
- labels: {
- color: '#8b949e'
- }
- }
- },
- scales: {
- x: {
- display: false
- },
- y: {
- beginAtZero: true,
- ticks: {
- color: '#8b949e'
- },
- grid: {
- color: '#21262d'
- }
- }
- }
- }
- });
-
- // Helper functions
-
- function formatBytes(bytes) {
- if (bytes === 0) return '0 B';
- const k = 1024;
- const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
- const i = Math.floor(Math.log(bytes) / Math.log(k));
- return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
- }
-
- function formatRate(rate) {
- return rate.toFixed(2) + '/s';
- }
-
- function showError(message) {
- const container = document.getElementById('error-container');
- container.innerHTML = `<div class="error">Error: ${escapeHtml(message)}</div>`;
- }
-
- function clearError() {
- document.getElementById('error-container').innerHTML = '';
- }
-
- function updateTimestamp() {
- const now = new Date();
- document.getElementById('last-update').textContent = now.toLocaleTimeString();
- }
-
- // Fetch functions
- async function fetchJSON(endpoint) {
- const response = await fetch(`${BASE_URL}${endpoint}`, {
- headers: {
- 'Accept': 'application/json'
- }
- });
- if (!response.ok) {
- throw new Error(`HTTP ${response.status}: ${response.statusText}`);
- }
- return await response.json();
- }
-
- async function fetchHealth() {
- try {
- const response = await fetch(`${BASE_URL}/compute/ready`);
- const isHealthy = response.status === 200;
-
- const banner = document.querySelector('zen-banner');
-
- if (isHealthy) {
- banner.setAttribute('cluster-status', 'nominal');
- banner.setAttribute('load', '0');
- } else {
- banner.setAttribute('cluster-status', 'degraded');
- banner.setAttribute('load', '0');
- }
-
- return isHealthy;
- } catch (error) {
- const banner = document.querySelector('zen-banner');
- banner.setAttribute('cluster-status', 'degraded');
- banner.setAttribute('load', '0');
- throw error;
- }
- }
-
- async function fetchStats() {
- const data = await fetchJSON('/stats/compute');
-
- // Update action counts
- document.getElementById('actions-pending').textContent = data.actions_pending || 0;
- document.getElementById('actions-running').textContent = data.actions_submitted || 0;
- document.getElementById('actions-complete').textContent = data.actions_complete || 0;
-
- // Update completion rates
- if (data.actions_retired) {
- document.getElementById('rate-1').textContent = formatRate(data.actions_retired.rate_1 || 0);
- document.getElementById('rate-5').textContent = formatRate(data.actions_retired.rate_5 || 0);
- document.getElementById('rate-15').textContent = formatRate(data.actions_retired.rate_15 || 0);
- document.getElementById('retired-count').textContent = data.actions_retired.count || 0;
- document.getElementById('rate-mean').textContent = formatRate(data.actions_retired.rate_mean || 0);
- }
-
- // Update chart
- const now = new Date().toLocaleTimeString();
- history.timestamps.push(now);
- history.pending.push(data.actions_pending || 0);
- history.running.push(data.actions_submitted || 0);
- history.completed.push(data.actions_complete || 0);
-
- // Keep only last N points
- if (history.timestamps.length > MAX_HISTORY_POINTS) {
- history.timestamps.shift();
- history.pending.shift();
- history.running.shift();
- history.completed.shift();
- }
-
- chart.data.labels = history.timestamps;
- chart.data.datasets[0].data = history.pending;
- chart.data.datasets[1].data = history.running;
- chart.data.datasets[2].data = history.completed;
- chart.update('none');
- }
-
- async function fetchSysInfo() {
- const data = await fetchJSON('/compute/sysinfo');
-
- // Update CPU
- const cpuUsage = data.cpu_usage || 0;
- document.getElementById('cpu-usage').textContent = cpuUsage.toFixed(1) + '%';
- document.getElementById('cpu-progress').style.width = cpuUsage + '%';
-
- const banner = document.querySelector('zen-banner');
- banner.setAttribute('load', cpuUsage.toFixed(1));
-
- history.cpu.push(cpuUsage);
- if (history.cpu.length > MAX_HISTORY_POINTS) history.cpu.shift();
- cpuChart.data.labels = history.cpu.map(() => '');
- cpuChart.data.datasets[0].data = history.cpu;
- cpuChart.update('none');
-
- document.getElementById('cpu-packages').textContent = data.cpu_count ?? '-';
- document.getElementById('cpu-cores').textContent = data.core_count ?? '-';
- document.getElementById('cpu-lp').textContent = data.lp_count ?? '-';
-
- // Update Memory
- const memUsed = data.memory_used || 0;
- const memTotal = data.memory_total || 1;
- const memPercent = (memUsed / memTotal) * 100;
- document.getElementById('memory-used').textContent = formatBytes(memUsed);
- document.getElementById('memory-total').textContent = formatBytes(memTotal);
- document.getElementById('memory-progress').style.width = memPercent + '%';
-
- // Update Disk
- const diskUsed = data.disk_used || 0;
- const diskTotal = data.disk_total || 1;
- const diskPercent = (diskUsed / diskTotal) * 100;
- document.getElementById('disk-used').textContent = formatBytes(diskUsed);
- document.getElementById('disk-total').textContent = formatBytes(diskTotal);
- document.getElementById('disk-progress').style.width = diskPercent + '%';
- }
-
- // Persists the selected worker ID across refreshes
- let selectedWorkerId = null;
-
- function renderWorkerDetail(id, desc) {
- const panel = document.getElementById('worker-detail');
-
- if (!desc) {
- panel.style.display = 'none';
- return;
- }
-
- function field(label, value) {
- return `<tr><td>${label}</td><td>${value ?? '-'}</td></tr>`;
- }
-
- function monoField(label, value) {
- return `<tr><td>${label}</td><td class="detail-mono">${value ?? '-'}</td></tr>`;
- }
-
- // Functions
- const functions = desc.functions || [];
- const functionsHtml = functions.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- `<table class="detail-table">${functions.map(f =>
- `<tr><td>${escapeHtml(f.name || '-')}</td><td class="detail-mono">${escapeHtml(f.version || '-')}</td></tr>`
- ).join('')}</table>`;
-
- // Executables
- const executables = desc.executables || [];
- const totalExecSize = executables.reduce((sum, e) => sum + (e.size || 0), 0);
- const execHtml = executables.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- `<table class="detail-table">
- <tr style="font-size:11px;">
- <td style="color:var(--theme_faint);padding-bottom:4px;">Path</td>
- <td style="color:var(--theme_faint);padding-bottom:4px;">Hash</td>
- <td style="color:var(--theme_faint);padding-bottom:4px;text-align:right;">Size</td>
- </tr>
- ${executables.map(e =>
- `<tr>
- <td>${escapeHtml(e.name || '-')}</td>
- <td class="detail-mono">${escapeHtml(e.hash || '-')}</td>
- <td style="text-align:right;white-space:nowrap;">${e.size != null ? formatBytes(e.size) : '-'}</td>
- </tr>`
- ).join('')}
- <tr style="border-top:1px solid var(--theme_g2);">
- <td style="color:var(--theme_g1);padding-top:6px;">Total</td>
- <td></td>
- <td style="text-align:right;white-space:nowrap;padding-top:6px;color:var(--theme_bright);font-weight:600;">${formatBytes(totalExecSize)}</td>
- </tr>
- </table>`;
-
- // Files
- const files = desc.files || [];
- const filesHtml = files.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- `<table class="detail-table">${files.map(f =>
- `<tr><td>${escapeHtml(f.name || f)}</td><td class="detail-mono">${escapeHtml(f.hash || '')}</td></tr>`
- ).join('')}</table>`;
-
- // Dirs
- const dirs = desc.dirs || [];
- const dirsHtml = dirs.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- dirs.map(d => `<span class="detail-tag">${escapeHtml(d)}</span>`).join('');
-
- // Environment
- const env = desc.environment || [];
- const envHtml = env.length === 0 ? '<span style="color:var(--theme_faint);font-size:12px;">none</span>' :
- env.map(e => `<span class="detail-tag">${escapeHtml(e)}</span>`).join('');
-
- panel.innerHTML = `
- <div class="worker-detail-title">${escapeHtml(desc.name || id)}</div>
- <div class="detail-section">
- <table class="detail-table">
- ${field('Worker ID', `<span class="detail-mono">${escapeHtml(id)}</span>`)}
- ${field('Path', escapeHtml(desc.path || '-'))}
- ${field('Platform', escapeHtml(desc.host || '-'))}
- ${monoField('Build System', desc.buildsystem_version)}
- ${field('Cores', desc.cores)}
- ${field('Timeout', desc.timeout != null ? desc.timeout + 's' : null)}
- </table>
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Functions</div>
- ${functionsHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Executables</div>
- ${execHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Files</div>
- ${filesHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Directories</div>
- ${dirsHtml}
- </div>
- <div class="detail-section">
- <div class="detail-section-label">Environment</div>
- ${envHtml}
- </div>
- `;
- panel.style.display = 'block';
- }
-
- async function fetchWorkers() {
- const data = await fetchJSON('/compute/workers');
- const workerIds = data.workers || [];
-
- document.getElementById('worker-count').textContent = workerIds.length;
-
- const container = document.getElementById('worker-table-container');
- const tbody = document.getElementById('worker-table-body');
-
- if (workerIds.length === 0) {
- container.style.display = 'none';
- selectedWorkerId = null;
- return;
- }
-
- const descriptors = await Promise.all(
- workerIds.map(id => fetchJSON(`/compute/workers/${id}`).catch(() => null))
- );
-
- // Build a map for quick lookup by ID
- const descriptorMap = {};
- workerIds.forEach((id, i) => { descriptorMap[id] = descriptors[i]; });
-
- tbody.innerHTML = '';
- descriptors.forEach((desc, i) => {
- const id = workerIds[i];
- const name = desc ? (desc.name || '-') : '-';
- const host = desc ? (desc.host || '-') : '-';
- const cores = desc ? (desc.cores != null ? desc.cores : '-') : '-';
- const timeout = desc ? (desc.timeout != null ? desc.timeout + 's' : '-') : '-';
- const functions = desc ? (desc.functions ? desc.functions.length : 0) : '-';
-
- const tr = document.createElement('tr');
- tr.className = 'worker-row' + (id === selectedWorkerId ? ' selected' : '');
- tr.dataset.workerId = id;
- tr.innerHTML = `
- <td style="color: var(--theme_bright);">${escapeHtml(name)}</td>
- <td>${escapeHtml(host)}</td>
- <td style="text-align: right;">${escapeHtml(String(cores))}</td>
- <td style="text-align: right;">${escapeHtml(String(timeout))}</td>
- <td style="text-align: right;">${escapeHtml(String(functions))}</td>
- <td style="color: var(--theme_g1); font-family: monospace; font-size: 11px;">${escapeHtml(id)}</td>
- `;
- tr.addEventListener('click', () => {
- document.querySelectorAll('.worker-row').forEach(r => r.classList.remove('selected'));
- if (selectedWorkerId === id) {
- // Toggle off
- selectedWorkerId = null;
- document.getElementById('worker-detail').style.display = 'none';
- } else {
- selectedWorkerId = id;
- tr.classList.add('selected');
- renderWorkerDetail(id, descriptorMap[id]);
- }
- });
- tbody.appendChild(tr);
- });
-
- // Re-render detail if selected worker is still present
- if (selectedWorkerId && descriptorMap[selectedWorkerId]) {
- renderWorkerDetail(selectedWorkerId, descriptorMap[selectedWorkerId]);
- } else if (selectedWorkerId && !descriptorMap[selectedWorkerId]) {
- selectedWorkerId = null;
- document.getElementById('worker-detail').style.display = 'none';
- }
-
- container.style.display = 'block';
- }
-
- // Windows FILETIME: 100ns ticks since 1601-01-01. Convert to JS Date.
- const FILETIME_EPOCH_OFFSET_MS = 11644473600000n;
- function filetimeToDate(ticks) {
- if (!ticks) return null;
- const ms = BigInt(ticks) / 10000n - FILETIME_EPOCH_OFFSET_MS;
- return new Date(Number(ms));
- }
-
- function formatTime(date) {
- if (!date) return '-';
- return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' });
- }
-
- function formatDuration(startDate, endDate) {
- if (!startDate || !endDate) return '-';
- const ms = endDate - startDate;
- if (ms < 0) return '-';
- if (ms < 1000) return ms + ' ms';
- if (ms < 60000) return (ms / 1000).toFixed(2) + ' s';
- const m = Math.floor(ms / 60000);
- const s = ((ms % 60000) / 1000).toFixed(0).padStart(2, '0');
- return `${m}m ${s}s`;
- }
-
- async function fetchQueues() {
- const data = await fetchJSON('/compute/queues');
- const queues = data.queues || [];
-
- const empty = document.getElementById('queue-list-empty');
- const container = document.getElementById('queue-list-container');
- const tbody = document.getElementById('queue-list-body');
-
- if (queues.length === 0) {
- empty.style.display = '';
- container.style.display = 'none';
- return;
- }
-
- empty.style.display = 'none';
- tbody.innerHTML = '';
-
- for (const q of queues) {
- const id = q.queue_id ?? '-';
- const badge = q.state === 'cancelled'
- ? '<span class="status-badge failure">cancelled</span>'
- : q.state === 'draining'
- ? '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_warn) 15%, transparent);color:var(--theme_warn);">draining</span>'
- : q.is_complete
- ? '<span class="status-badge success">complete</span>'
- : '<span class="status-badge" style="background:color-mix(in srgb, var(--theme_p0) 15%, transparent);color:var(--theme_p0);">active</span>';
- const token = q.queue_token
- ? `<span class="detail-mono">${escapeHtml(q.queue_token)}</span>`
- : '<span style="color:var(--theme_faint);">-</span>';
-
- const tr = document.createElement('tr');
- tr.innerHTML = `
- <td style="text-align: right; font-family: monospace; color: var(--theme_bright);">${escapeHtml(String(id))}</td>
- <td style="text-align: center;">${badge}</td>
- <td style="text-align: right;">${q.active_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_ok);">${q.completed_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_fail);">${q.failed_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_warn);">${q.abandoned_count ?? 0}</td>
- <td style="text-align: right; color: var(--theme_warn);">${q.cancelled_count ?? 0}</td>
- <td>${token}</td>
- `;
- tbody.appendChild(tr);
- }
-
- container.style.display = 'block';
- }
-
- async function fetchActionHistory() {
- const data = await fetchJSON('/compute/jobs/history?limit=50');
- const entries = data.history || [];
-
- const empty = document.getElementById('action-history-empty');
- const container = document.getElementById('action-history-container');
- const tbody = document.getElementById('action-history-body');
-
- if (entries.length === 0) {
- empty.style.display = '';
- container.style.display = 'none';
- return;
- }
-
- empty.style.display = 'none';
- tbody.innerHTML = '';
-
- // Entries arrive oldest-first; reverse to show newest at top
- for (const entry of [...entries].reverse()) {
- const lsn = entry.lsn ?? '-';
- const succeeded = entry.succeeded;
- const badge = succeeded == null
- ? '<span class="status-badge" style="background:var(--theme_border_subtle);color:var(--theme_g1);">unknown</span>'
- : succeeded
- ? '<span class="status-badge success">ok</span>'
- : '<span class="status-badge failure">failed</span>';
- const desc = entry.actionDescriptor || {};
- const fn = desc.Function || '-';
- const workerId = entry.workerId || '-';
- const actionId = entry.actionId || '-';
-
- const startDate = filetimeToDate(entry.time_Running);
- const endDate = filetimeToDate(entry.time_Completed ?? entry.time_Failed);
-
- const queueId = entry.queueId || 0;
- const queueCell = queueId
- ? `<a href="/compute/queues/${queueId}" style="color: var(--theme_ln); text-decoration: none; font-family: monospace;">${escapeHtml(String(queueId))}</a>`
- : '<span style="color: var(--theme_faint);">-</span>';
-
- const tr = document.createElement('tr');
- tr.innerHTML = `
- <td style="text-align: right; font-family: monospace; color: var(--theme_g1);">${escapeHtml(String(lsn))}</td>
- <td style="text-align: right;">${queueCell}</td>
- <td style="text-align: center;">${badge}</td>
- <td style="color: var(--theme_bright);">${escapeHtml(fn)}</td>
- <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(startDate)}</td>
- <td style="text-align: right; font-size: 12px; white-space: nowrap; color: var(--theme_g1);">${formatTime(endDate)}</td>
- <td style="text-align: right; font-size: 12px; white-space: nowrap;">${formatDuration(startDate, endDate)}</td>
- <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(workerId)}</td>
- <td style="font-family: monospace; font-size: 11px; color: var(--theme_g1);">${escapeHtml(actionId)}</td>
- `;
- tbody.appendChild(tr);
- }
-
- container.style.display = 'block';
- }
-
- async function updateDashboard() {
- try {
- await Promise.all([
- fetchHealth(),
- fetchStats(),
- fetchSysInfo(),
- fetchWorkers(),
- fetchQueues(),
- fetchActionHistory()
- ]);
-
- clearError();
- updateTimestamp();
- } catch (error) {
- console.error('Error updating dashboard:', error);
- showError(error.message);
- }
- }
-
- // Start updating
- updateDashboard();
- setInterval(updateDashboard, REFRESH_INTERVAL);
- </script>
-</body>
-</html>
diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html
index 9597fd7f3..aaa09aec0 100644
--- a/src/zenserver/frontend/html/compute/index.html
+++ b/src/zenserver/frontend/html/compute/index.html
@@ -1 +1 @@
-<meta http-equiv="refresh" content="0; url=compute.html" /> \ No newline at end of file
+<meta http-equiv="refresh" content="0; url=/dashboard/?page=compute" /> \ No newline at end of file
diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html
deleted file mode 100644
index d1a2bb015..000000000
--- a/src/zenserver/frontend/html/compute/orchestrator.html
+++ /dev/null
@@ -1,669 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
- <meta charset="UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <link rel="stylesheet" type="text/css" href="../zen.css" />
- <script src="../util/sanitize.js"></script>
- <script src="../theme.js"></script>
- <script src="../banner.js" defer></script>
- <script src="../nav.js" defer></script>
- <title>Zen Orchestrator Dashboard</title>
- <style>
- .agent-count {
- display: flex;
- align-items: center;
- gap: 8px;
- font-size: 14px;
- padding: 8px 16px;
- border-radius: 6px;
- background: var(--theme_g3);
- border: 1px solid var(--theme_g2);
- }
-
- .agent-count .count {
- font-size: 20px;
- font-weight: 600;
- color: var(--theme_bright);
- }
- </style>
-</head>
-<body>
- <div class="container" style="max-width: 1400px; margin: 0 auto;">
- <zen-banner cluster-status="nominal" load="0" logo-src="../favicon.ico"></zen-banner>
- <zen-nav>
- <a href="/dashboard/">Home</a>
- <a href="compute.html">Node</a>
- <a href="orchestrator.html">Orchestrator</a>
- </zen-nav>
- <div class="header">
- <div>
- <div class="timestamp">Last updated: <span id="last-update">Never</span></div>
- </div>
- <div class="agent-count">
- <span>Agents:</span>
- <span class="count" id="agent-count">-</span>
- </div>
- </div>
-
- <div id="error-container"></div>
-
- <div class="card">
- <div class="card-title">Compute Agents</div>
- <div id="empty-state" class="empty-state">No agents registered.</div>
- <table id="agent-table" style="display: none;">
- <thead>
- <tr>
- <th style="width: 40px; text-align: center;">Health</th>
- <th>Hostname</th>
- <th style="text-align: right;">CPUs</th>
- <th style="text-align: right;">CPU Usage</th>
- <th style="text-align: right;">Memory</th>
- <th style="text-align: right;">Queues</th>
- <th style="text-align: right;">Pending</th>
- <th style="text-align: right;">Running</th>
- <th style="text-align: right;">Completed</th>
- <th style="text-align: right;">Traffic</th>
- <th style="text-align: right;">Last Seen</th>
- </tr>
- </thead>
- <tbody id="agent-table-body"></tbody>
- </table>
- </div>
- <div class="card" style="margin-top: 20px;">
- <div class="card-title">Connected Clients</div>
- <div id="clients-empty" class="empty-state">No clients connected.</div>
- <table id="clients-table" style="display: none;">
- <thead>
- <tr>
- <th style="width: 40px; text-align: center;">Health</th>
- <th>Client ID</th>
- <th>Hostname</th>
- <th>Address</th>
- <th style="text-align: right;">Last Seen</th>
- </tr>
- </thead>
- <tbody id="clients-table-body"></tbody>
- </table>
- </div>
- <div class="card" style="margin-top: 20px;">
- <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 12px;">
- <div class="card-title" style="margin-bottom: 0;">Event History</div>
- <div class="history-tabs">
- <button class="history-tab active" data-tab="workers" onclick="switchHistoryTab('workers')">Workers</button>
- <button class="history-tab" data-tab="clients" onclick="switchHistoryTab('clients')">Clients</button>
- </div>
- </div>
- <div id="history-panel-workers">
- <div id="history-empty" class="empty-state">No provisioning events recorded.</div>
- <table id="history-table" style="display: none;">
- <thead>
- <tr>
- <th>Time</th>
- <th>Event</th>
- <th>Worker</th>
- <th>Hostname</th>
- </tr>
- </thead>
- <tbody id="history-table-body"></tbody>
- </table>
- </div>
- <div id="history-panel-clients" style="display: none;">
- <div id="client-history-empty" class="empty-state">No client events recorded.</div>
- <table id="client-history-table" style="display: none;">
- <thead>
- <tr>
- <th>Time</th>
- <th>Event</th>
- <th>Client</th>
- <th>Hostname</th>
- </tr>
- </thead>
- <tbody id="client-history-table-body"></tbody>
- </table>
- </div>
- </div>
- </div>
-
- <script>
- const BASE_URL = window.location.origin;
- const REFRESH_INTERVAL = 2000;
-
- function showError(message) {
- document.getElementById('error-container').innerHTML =
- '<div class="error">Error: ' + escapeHtml(message) + '</div>';
- }
-
- function clearError() {
- document.getElementById('error-container').innerHTML = '';
- }
-
- function formatLastSeen(dtMs) {
- if (dtMs == null) return '-';
- var seconds = Math.floor(dtMs / 1000);
- if (seconds < 60) return seconds + 's ago';
- var minutes = Math.floor(seconds / 60);
- if (minutes < 60) return minutes + 'm ' + (seconds % 60) + 's ago';
- var hours = Math.floor(minutes / 60);
- return hours + 'h ' + (minutes % 60) + 'm ago';
- }
-
- function healthClass(dtMs, reachable) {
- if (reachable === false) return 'health-red';
- if (dtMs == null) return 'health-red';
- var seconds = dtMs / 1000;
- if (seconds < 30 && reachable === true) return 'health-green';
- if (seconds < 120) return 'health-yellow';
- return 'health-red';
- }
-
- function healthTitle(dtMs, reachable) {
- var seenStr = dtMs != null ? 'Last seen ' + formatLastSeen(dtMs) : 'Never seen';
- if (reachable === true) return seenStr + ' · Reachable';
- if (reachable === false) return seenStr + ' · Unreachable';
- return seenStr + ' · Reachability unknown';
- }
-
- function formatCpuUsage(percent) {
- if (percent == null || percent === 0) return '-';
- return percent.toFixed(1) + '%';
- }
-
- function formatMemory(usedBytes, totalBytes) {
- if (!totalBytes) return '-';
- var usedGiB = usedBytes / (1024 * 1024 * 1024);
- var totalGiB = totalBytes / (1024 * 1024 * 1024);
- return usedGiB.toFixed(1) + ' / ' + totalGiB.toFixed(1) + ' GiB';
- }
-
- function formatBytes(bytes) {
- if (!bytes) return '-';
- if (bytes < 1024) return bytes + ' B';
- if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KiB';
- if (bytes < 1024 * 1024 * 1024) return (bytes / (1024 * 1024)).toFixed(1) + ' MiB';
- if (bytes < 1024 * 1024 * 1024 * 1024) return (bytes / (1024 * 1024 * 1024)).toFixed(1) + ' GiB';
- return (bytes / (1024 * 1024 * 1024 * 1024)).toFixed(1) + ' TiB';
- }
-
- function formatTraffic(recv, sent) {
- if (!recv && !sent) return '-';
- return formatBytes(recv) + ' / ' + formatBytes(sent);
- }
-
- function parseIpFromUri(uri) {
- try {
- var url = new URL(uri);
- var host = url.hostname;
- // Strip IPv6 brackets
- if (host.startsWith('[') && host.endsWith(']')) host = host.slice(1, -1);
- // Only handle IPv4
- var parts = host.split('.');
- if (parts.length !== 4) return null;
- var octets = parts.map(Number);
- if (octets.some(function(o) { return isNaN(o) || o < 0 || o > 255; })) return null;
- return octets;
- } catch (e) {
- return null;
- }
- }
-
- function computeCidr(ips) {
- if (ips.length === 0) return null;
- if (ips.length === 1) return ips[0].join('.') + '/32';
-
- // Convert each IP to a 32-bit integer
- var ints = ips.map(function(o) {
- return ((o[0] << 24) | (o[1] << 16) | (o[2] << 8) | o[3]) >>> 0;
- });
-
- // Find common prefix length by ANDing all identical high bits
- var common = ~0 >>> 0;
- for (var i = 1; i < ints.length; i++) {
- // XOR to find differing bits, then mask away everything from the first difference down
- var diff = (ints[0] ^ ints[i]) >>> 0;
- if (diff !== 0) {
- var bit = 31 - Math.floor(Math.log2(diff));
- var mask = bit > 0 ? ((~0 << (32 - bit)) >>> 0) : 0;
- common = (common & mask) >>> 0;
- }
- }
-
- // Count leading ones in the common mask
- var prefix = 0;
- for (var b = 31; b >= 0; b--) {
- if ((common >>> b) & 1) prefix++;
- else break;
- }
-
- // Network address
- var net = (ints[0] & common) >>> 0;
- var a = (net >>> 24) & 0xff;
- var bv = (net >>> 16) & 0xff;
- var c = (net >>> 8) & 0xff;
- var d = net & 0xff;
- return a + '.' + bv + '.' + c + '.' + d + '/' + prefix;
- }
-
- function renderDashboard(data) {
- var banner = document.querySelector('zen-banner');
- if (data.hostname) {
- banner.setAttribute('tagline', 'Orchestrator \u2014 ' + data.hostname);
- }
- var workers = data.workers || [];
-
- document.getElementById('agent-count').textContent = workers.length;
-
- if (workers.length === 0) {
- banner.setAttribute('cluster-status', 'degraded');
- banner.setAttribute('load', '0');
- } else {
- banner.setAttribute('cluster-status', 'nominal');
- }
-
- var emptyState = document.getElementById('empty-state');
- var table = document.getElementById('agent-table');
- var tbody = document.getElementById('agent-table-body');
-
- if (workers.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- } else {
- emptyState.style.display = 'none';
- table.style.display = '';
-
- tbody.innerHTML = '';
- var totalCpus = 0;
- var totalWeightedCpuUsage = 0;
- var totalMemUsed = 0;
- var totalMemTotal = 0;
- var totalQueues = 0;
- var totalPending = 0;
- var totalRunning = 0;
- var totalCompleted = 0;
- var totalBytesRecv = 0;
- var totalBytesSent = 0;
- var allIps = [];
- for (var i = 0; i < workers.length; i++) {
- var w = workers[i];
- var uri = w.uri || '';
- var dt = w.dt;
- var dashboardUrl = uri + '/dashboard/compute/';
-
- var id = w.id || '';
-
- var hostname = w.hostname || '';
- var cpus = w.cpus || 0;
- totalCpus += cpus;
- if (cpus > 0 && typeof w.cpu_usage === 'number') {
- totalWeightedCpuUsage += w.cpu_usage * cpus;
- }
-
- var memTotal = w.memory_total || 0;
- var memUsed = w.memory_used || 0;
- totalMemTotal += memTotal;
- totalMemUsed += memUsed;
-
- var activeQueues = w.active_queues || 0;
- totalQueues += activeQueues;
-
- var actionsPending = w.actions_pending || 0;
- var actionsRunning = w.actions_running || 0;
- var actionsCompleted = w.actions_completed || 0;
- totalPending += actionsPending;
- totalRunning += actionsRunning;
- totalCompleted += actionsCompleted;
-
- var bytesRecv = w.bytes_received || 0;
- var bytesSent = w.bytes_sent || 0;
- totalBytesRecv += bytesRecv;
- totalBytesSent += bytesSent;
-
- var ip = parseIpFromUri(uri);
- if (ip) allIps.push(ip);
-
- var reachable = w.reachable;
- var hClass = healthClass(dt, reachable);
- var hTitle = healthTitle(dt, reachable);
-
- var platform = w.platform || '';
- var badges = '';
- if (platform) {
- var platColors = { windows: '#0078d4', wine: '#722f37', linux: '#e95420', macos: '#a2aaad' };
- var platColor = platColors[platform] || '#8b949e';
- badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + platColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(platform) + '</span>';
- }
- var provisioner = w.provisioner || '';
- if (provisioner) {
- var provColors = { horde: '#8957e5', nomad: '#3fb950' };
- var provColor = provColors[provisioner] || '#8b949e';
- badges += ' <span style="display:inline-block;padding:1px 6px;border-radius:10px;font-size:10px;font-weight:600;color:#fff;background:' + provColor + ';vertical-align:middle;margin-left:4px;">' + escapeHtml(provisioner) + '</span>';
- }
-
- var tr = document.createElement('tr');
- tr.title = id;
- tr.innerHTML =
- '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' +
- '<td><a href="' + escapeHtml(dashboardUrl) + '" target="_blank">' + escapeHtml(hostname) + '</a>' + badges + '</td>' +
- '<td style="text-align: right;">' + (cpus > 0 ? cpus : '-') + '</td>' +
- '<td style="text-align: right;">' + formatCpuUsage(w.cpu_usage) + '</td>' +
- '<td style="text-align: right;">' + formatMemory(memUsed, memTotal) + '</td>' +
- '<td style="text-align: right;">' + (activeQueues > 0 ? activeQueues : '-') + '</td>' +
- '<td style="text-align: right;">' + actionsPending + '</td>' +
- '<td style="text-align: right;">' + actionsRunning + '</td>' +
- '<td style="text-align: right;">' + actionsCompleted + '</td>' +
- '<td style="text-align: right; font-size: 11px; color: var(--theme_g1);">' + formatTraffic(bytesRecv, bytesSent) + '</td>' +
- '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>';
- tbody.appendChild(tr);
- }
-
- var clusterLoad = totalCpus > 0 ? (totalWeightedCpuUsage / totalCpus) : 0;
- banner.setAttribute('load', clusterLoad.toFixed(1));
-
- // Total row
- var cidr = computeCidr(allIps);
- var totalTr = document.createElement('tr');
- totalTr.className = 'total-row';
- totalTr.innerHTML =
- '<td></td>' +
- '<td style="text-align: right; color: var(--theme_g1); text-transform: uppercase; font-size: 11px;">Total' + (cidr ? ' <span style="font-family: monospace; font-weight: normal;">' + escapeHtml(cidr) + '</span>' : '') + '</td>' +
- '<td style="text-align: right;">' + totalCpus + '</td>' +
- '<td></td>' +
- '<td style="text-align: right;">' + formatMemory(totalMemUsed, totalMemTotal) + '</td>' +
- '<td style="text-align: right;">' + totalQueues + '</td>' +
- '<td style="text-align: right;">' + totalPending + '</td>' +
- '<td style="text-align: right;">' + totalRunning + '</td>' +
- '<td style="text-align: right;">' + totalCompleted + '</td>' +
- '<td style="text-align: right; font-size: 11px;">' + formatTraffic(totalBytesRecv, totalBytesSent) + '</td>' +
- '<td></td>';
- tbody.appendChild(totalTr);
- }
-
- clearError();
- document.getElementById('last-update').textContent = new Date().toLocaleTimeString();
-
- // Render provisioning history if present in WebSocket payload
- if (data.events) {
- renderProvisioningHistory(data.events);
- }
-
- // Render connected clients if present
- if (data.clients) {
- renderClients(data.clients);
- }
-
- // Render client history if present
- if (data.client_events) {
- renderClientHistory(data.client_events);
- }
- }
-
- function eventBadge(type) {
- var colors = { joined: 'var(--theme_ok)', left: 'var(--theme_fail)', returned: 'var(--theme_warn)' };
- var labels = { joined: 'Joined', left: 'Left', returned: 'Returned' };
- var color = colors[type] || 'var(--theme_g1)';
- var label = labels[type] || type;
- return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>';
- }
-
- function formatTimestamp(ts) {
- if (!ts) return '-';
- // CbObject DateTime serialized as ticks (100ns since 0001-01-01) or ISO string
- var date;
- if (typeof ts === 'number') {
- // .NET-style ticks: convert to Unix ms
- var unixMs = (ts - 621355968000000000) / 10000;
- date = new Date(unixMs);
- } else {
- date = new Date(ts);
- }
- if (isNaN(date.getTime())) return '-';
- return date.toLocaleTimeString();
- }
-
- var activeHistoryTab = 'workers';
-
- function switchHistoryTab(tab) {
- activeHistoryTab = tab;
- var tabs = document.querySelectorAll('.history-tab');
- for (var i = 0; i < tabs.length; i++) {
- tabs[i].classList.toggle('active', tabs[i].getAttribute('data-tab') === tab);
- }
- document.getElementById('history-panel-workers').style.display = tab === 'workers' ? '' : 'none';
- document.getElementById('history-panel-clients').style.display = tab === 'clients' ? '' : 'none';
- }
-
- function renderProvisioningHistory(events) {
- var emptyState = document.getElementById('history-empty');
- var table = document.getElementById('history-table');
- var tbody = document.getElementById('history-table-body');
-
- if (!events || events.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- return;
- }
-
- emptyState.style.display = 'none';
- table.style.display = '';
- tbody.innerHTML = '';
-
- for (var i = 0; i < events.length; i++) {
- var evt = events[i];
- var tr = document.createElement('tr');
- tr.innerHTML =
- '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' +
- '<td>' + eventBadge(evt.type) + '</td>' +
- '<td>' + escapeHtml(evt.worker_id || '') + '</td>' +
- '<td>' + escapeHtml(evt.hostname || '') + '</td>';
- tbody.appendChild(tr);
- }
- }
-
- function clientHealthClass(dtMs) {
- if (dtMs == null) return 'health-red';
- var seconds = dtMs / 1000;
- if (seconds < 30) return 'health-green';
- if (seconds < 120) return 'health-yellow';
- return 'health-red';
- }
-
- function renderClients(clients) {
- var emptyState = document.getElementById('clients-empty');
- var table = document.getElementById('clients-table');
- var tbody = document.getElementById('clients-table-body');
-
- if (!clients || clients.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- return;
- }
-
- emptyState.style.display = 'none';
- table.style.display = '';
- tbody.innerHTML = '';
-
- for (var i = 0; i < clients.length; i++) {
- var c = clients[i];
- var dt = c.dt;
- var hClass = clientHealthClass(dt);
- var hTitle = dt != null ? 'Last seen ' + formatLastSeen(dt) : 'Never seen';
-
- var sessionBadge = '';
- if (c.session_id) {
- sessionBadge = ' <span style="font-family:monospace;font-size:10px;color:var(--theme_faint);" title="Session ' + escapeHtml(c.session_id) + '">' + escapeHtml(c.session_id.substring(0, 8)) + '</span>';
- }
-
- var tr = document.createElement('tr');
- tr.innerHTML =
- '<td style="text-align: center;"><span class="health-dot ' + hClass + '" title="' + escapeHtml(hTitle) + '"></span></td>' +
- '<td>' + escapeHtml(c.id || '') + sessionBadge + '</td>' +
- '<td>' + escapeHtml(c.hostname || '') + '</td>' +
- '<td style="font-family: monospace; font-size: 12px; color: var(--theme_g1);">' + escapeHtml(c.address || '') + '</td>' +
- '<td style="text-align: right; color: var(--theme_g1);">' + formatLastSeen(dt) + '</td>';
- tbody.appendChild(tr);
- }
- }
-
- function clientEventBadge(type) {
- var colors = { connected: 'var(--theme_ok)', disconnected: 'var(--theme_fail)', updated: 'var(--theme_warn)' };
- var labels = { connected: 'Connected', disconnected: 'Disconnected', updated: 'Updated' };
- var color = colors[type] || 'var(--theme_g1)';
- var label = labels[type] || type;
- return '<span style="display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;color:var(--theme_g4);background:' + color + ';">' + escapeHtml(label) + '</span>';
- }
-
- function renderClientHistory(events) {
- var emptyState = document.getElementById('client-history-empty');
- var table = document.getElementById('client-history-table');
- var tbody = document.getElementById('client-history-table-body');
-
- if (!events || events.length === 0) {
- emptyState.style.display = '';
- table.style.display = 'none';
- return;
- }
-
- emptyState.style.display = 'none';
- table.style.display = '';
- tbody.innerHTML = '';
-
- for (var i = 0; i < events.length; i++) {
- var evt = events[i];
- var tr = document.createElement('tr');
- tr.innerHTML =
- '<td style="color: var(--theme_g1);">' + formatTimestamp(evt.ts) + '</td>' +
- '<td>' + clientEventBadge(evt.type) + '</td>' +
- '<td>' + escapeHtml(evt.client_id || '') + '</td>' +
- '<td>' + escapeHtml(evt.hostname || '') + '</td>';
- tbody.appendChild(tr);
- }
- }
-
- // Fetch-based polling fallback
- var pollTimer = null;
-
- async function fetchProvisioningHistory() {
- try {
- var response = await fetch(BASE_URL + '/orch/history?limit=50', {
- headers: { 'Accept': 'application/json' }
- });
- if (response.ok) {
- var data = await response.json();
- renderProvisioningHistory(data.events || []);
- }
- } catch (e) {
- console.error('Error fetching provisioning history:', e);
- }
- }
-
- async function fetchClients() {
- try {
- var response = await fetch(BASE_URL + '/orch/clients', {
- headers: { 'Accept': 'application/json' }
- });
- if (response.ok) {
- var data = await response.json();
- renderClients(data.clients || []);
- }
- } catch (e) {
- console.error('Error fetching clients:', e);
- }
- }
-
- async function fetchClientHistory() {
- try {
- var response = await fetch(BASE_URL + '/orch/clients/history?limit=50', {
- headers: { 'Accept': 'application/json' }
- });
- if (response.ok) {
- var data = await response.json();
- renderClientHistory(data.client_events || []);
- }
- } catch (e) {
- console.error('Error fetching client history:', e);
- }
- }
-
- async function fetchDashboard() {
- var banner = document.querySelector('zen-banner');
- try {
- var response = await fetch(BASE_URL + '/orch/agents', {
- headers: { 'Accept': 'application/json' }
- });
-
- if (!response.ok) {
- banner.setAttribute('cluster-status', 'degraded');
- throw new Error('HTTP ' + response.status + ': ' + response.statusText);
- }
-
- renderDashboard(await response.json());
- fetchProvisioningHistory();
- fetchClients();
- fetchClientHistory();
- } catch (error) {
- console.error('Error updating dashboard:', error);
- showError(error.message);
- banner.setAttribute('cluster-status', 'offline');
- }
- }
-
- function startPolling() {
- if (pollTimer) return;
- fetchDashboard();
- pollTimer = setInterval(fetchDashboard, REFRESH_INTERVAL);
- }
-
- function stopPolling() {
- if (pollTimer) {
- clearInterval(pollTimer);
- pollTimer = null;
- }
- }
-
- // WebSocket connection with automatic reconnect and polling fallback
- var ws = null;
-
- function connectWebSocket() {
- var proto = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
- ws = new WebSocket(proto + '//' + window.location.host + '/orch/ws');
-
- ws.onopen = function() {
- stopPolling();
- clearError();
- };
-
- ws.onmessage = function(event) {
- try {
- renderDashboard(JSON.parse(event.data));
- } catch (e) {
- console.error('WebSocket message parse error:', e);
- }
- };
-
- ws.onclose = function() {
- ws = null;
- startPolling();
- setTimeout(connectWebSocket, 3000);
- };
-
- ws.onerror = function() {
- // onclose will fire after onerror
- };
- }
-
- // Fetch orchestrator hostname for the banner
- fetch(BASE_URL + '/orch/status', { headers: { 'Accept': 'application/json' } })
- .then(function(r) { return r.ok ? r.json() : null; })
- .then(function(d) {
- if (d && d.hostname) {
- document.querySelector('zen-banner').setAttribute('tagline', 'Orchestrator \u2014 ' + d.hostname);
- }
- })
- .catch(function() {});
-
- // Initial load via fetch, then try WebSocket
- fetchDashboard();
- connectWebSocket();
- </script>
-</body>
-</html>
diff --git a/src/zenserver/frontend/html/pages/orchestrator.js b/src/zenserver/frontend/html/pages/orchestrator.js
index a280fabdb..30f6a8122 100644
--- a/src/zenserver/frontend/html/pages/orchestrator.js
+++ b/src/zenserver/frontend/html/pages/orchestrator.js
@@ -14,6 +14,14 @@ export class Page extends ZenPage
{
this.set_title("orchestrator");
+ // Provisioner section (hidden until data arrives)
+ this._prov_section = this._collapsible_section("Provisioner");
+ this._prov_section._parent.inner().style.display = "none";
+ this._prov_grid = null;
+ this._prov_target_dirty = false;
+ this._prov_commit_timer = null;
+ this._prov_last_target = null;
+
// Agents section
const agents_section = this._collapsible_section("Compute Agents");
this._agents_host = agents_section;
@@ -50,11 +58,12 @@ export class Page extends ZenPage
{
try
{
- const [agents, history, clients, client_history] = await Promise.all([
+ const [agents, history, clients, client_history, prov] = await Promise.all([
new Fetcher().resource("/orch/agents").json(),
new Fetcher().resource("/orch/history").param("limit", "50").json().catch(() => null),
new Fetcher().resource("/orch/clients").json().catch(() => null),
new Fetcher().resource("/orch/clients/history").param("limit", "50").json().catch(() => null),
+ new Fetcher().resource("/orch/provisioner/status").json().catch(() => null),
]);
this._render_agents(agents);
@@ -70,6 +79,7 @@ export class Page extends ZenPage
{
this._render_client_history(client_history.client_events || []);
}
+ this._render_provisioner(prov);
}
catch (e) { /* service unavailable */ }
}
@@ -109,6 +119,7 @@ export class Page extends ZenPage
{
this._render_client_history(data.client_events);
}
+ this._render_provisioner(data.provisioner);
}
catch (e) { /* ignore parse errors */ }
};
@@ -156,7 +167,7 @@ export class Page extends ZenPage
return;
}
- let totalCpus = 0, totalWeightedCpu = 0;
+ let totalCpus = 0, activeCpus = 0, totalWeightedCpu = 0;
let totalMemUsed = 0, totalMemTotal = 0;
let totalQueues = 0, totalPending = 0, totalRunning = 0, totalCompleted = 0;
let totalRecv = 0, totalSent = 0;
@@ -173,8 +184,14 @@ export class Page extends ZenPage
const completed = w.actions_completed || 0;
const recv = w.bytes_received || 0;
const sent = w.bytes_sent || 0;
+ const provisioner = w.provisioner || "";
+ const isProvisioned = provisioner !== "";
totalCpus += cpus;
+ if (w.provisioner_status === "active")
+ {
+ activeCpus += cpus;
+ }
if (cpus > 0 && typeof cpuUsage === "number")
{
totalWeightedCpu += cpuUsage * cpus;
@@ -209,12 +226,49 @@ export class Page extends ZenPage
cell.inner().textContent = "";
cell.tag("a").text(hostname).attr("href", w.uri + "/dashboard/compute/").attr("target", "_blank");
}
+
+ // Visual treatment based on provisioner status
+ const provStatus = w.provisioner_status || "";
+ if (!isProvisioned)
+ {
+ row.inner().style.opacity = "0.45";
+ }
+ else
+ {
+ const hostCell = row.get_cell(0);
+ const el = hostCell.inner();
+ const badge = document.createElement("span");
+ const badgeBase = "display:inline-block;margin-left:6px;padding:1px 5px;border-radius:8px;" +
+ "font-size:9px;font-weight:600;color:#fff;vertical-align:middle;";
+
+ if (provStatus === "draining")
+ {
+ badge.textContent = "draining";
+ badge.style.cssText = badgeBase + "background:var(--theme_warn);";
+ row.inner().style.opacity = "0.6";
+ }
+ else if (provStatus === "active")
+ {
+ badge.textContent = provisioner;
+ badge.style.cssText = badgeBase + "background:#8957e5;";
+ }
+ else
+ {
+ badge.textContent = "deallocated";
+ badge.style.cssText = badgeBase + "background:var(--theme_fail);";
+ row.inner().style.opacity = "0.45";
+ }
+ el.appendChild(badge);
+ }
}
- // Total row
+ // Total row — show active / total in CPUs column
+ const cpuLabel = activeCpus < totalCpus
+ ? Friendly.sep(activeCpus) + " / " + Friendly.sep(totalCpus)
+ : Friendly.sep(totalCpus);
const total = this._agents_table.add_row(
"TOTAL",
- Friendly.sep(totalCpus),
+ cpuLabel,
"",
totalMemTotal > 0 ? Friendly.bytes(totalMemUsed) + " / " + Friendly.bytes(totalMemTotal) : "-",
Friendly.sep(totalQueues),
@@ -305,6 +359,154 @@ export class Page extends ZenPage
}
}
+ _render_provisioner(prov)
+ {
+ const container = this._prov_section._parent.inner();
+
+ if (!prov || !prov.name)
+ {
+ container.style.display = "none";
+ return;
+ }
+ container.style.display = "";
+
+ if (!this._prov_grid)
+ {
+ this._prov_grid = this._prov_section.tag().classify("grid").classify("stats-tiles");
+ this._prov_tiles = {};
+
+ // Target cores tile with editable input
+ const target_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ target_tile.tag().classify("card-title").text("Target Cores");
+ const target_body = target_tile.tag().classify("tile-metrics");
+ const target_m = target_body.tag().classify("tile-metric").classify("tile-metric-hero");
+ const input = document.createElement("input");
+ input.type = "number";
+ input.min = "0";
+ input.style.cssText = "width:100px;padding:4px 8px;border:1px solid var(--theme_g2);border-radius:4px;" +
+ "background:var(--theme_g4);color:var(--theme_bright);font-size:20px;font-weight:600;text-align:right;";
+ target_m.inner().appendChild(input);
+ target_m.tag().classify("metric-label").text("target");
+ this._prov_tiles.target_input = input;
+
+ input.addEventListener("focus", () => { this._prov_target_dirty = true; });
+ input.addEventListener("input", () => {
+ this._prov_target_dirty = true;
+ if (this._prov_commit_timer)
+ {
+ clearTimeout(this._prov_commit_timer);
+ }
+ this._prov_commit_timer = setTimeout(() => this._commit_provisioner_target(), 800);
+ });
+ input.addEventListener("keydown", (e) => {
+ if (e.key === "Enter")
+ {
+ if (this._prov_commit_timer)
+ {
+ clearTimeout(this._prov_commit_timer);
+ }
+ this._commit_provisioner_target();
+ input.blur();
+ }
+ });
+ input.addEventListener("blur", () => {
+ if (this._prov_commit_timer)
+ {
+ clearTimeout(this._prov_commit_timer);
+ }
+ this._commit_provisioner_target();
+ });
+
+ // Active cores
+ const active_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ active_tile.tag().classify("card-title").text("Active Cores");
+ const active_body = active_tile.tag().classify("tile-metrics");
+ this._prov_tiles.active = active_body;
+
+ // Estimated cores
+ const est_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ est_tile.tag().classify("card-title").text("Estimated Cores");
+ const est_body = est_tile.tag().classify("tile-metrics");
+ this._prov_tiles.estimated = est_body;
+
+ // Agents
+ const agents_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ agents_tile.tag().classify("card-title").text("Agents");
+ const agents_body = agents_tile.tag().classify("tile-metrics");
+ this._prov_tiles.agents = agents_body;
+
+ // Draining
+ const drain_tile = this._prov_grid.tag().classify("card").classify("stats-tile");
+ drain_tile.tag().classify("card-title").text("Draining");
+ const drain_body = drain_tile.tag().classify("tile-metrics");
+ this._prov_tiles.draining = drain_body;
+ }
+
+ // Update values
+ const input = this._prov_tiles.target_input;
+ if (!this._prov_target_dirty && document.activeElement !== input)
+ {
+ input.value = prov.target_cores;
+ }
+ this._prov_last_target = prov.target_cores;
+
+ // Re-render metric tiles (clear and recreate content)
+ for (const key of ["active", "estimated", "agents", "draining"])
+ {
+ this._prov_tiles[key].inner().innerHTML = "";
+ }
+ this._metric(this._prov_tiles.active, Friendly.sep(prov.active_cores), "cores", true);
+ this._metric(this._prov_tiles.estimated, Friendly.sep(prov.estimated_cores), "cores", true);
+ this._metric(this._prov_tiles.agents, Friendly.sep(prov.agents), "active", true);
+ this._metric(this._prov_tiles.draining, Friendly.sep(prov.agents_draining || 0), "agents", true);
+ }
+
+ async _commit_provisioner_target()
+ {
+ const input = this._prov_tiles?.target_input;
+ if (!input || this._prov_committing)
+ {
+ return;
+ }
+ const value = parseInt(input.value, 10);
+ if (isNaN(value) || value < 0)
+ {
+ return;
+ }
+ if (value === this._prov_last_target)
+ {
+ this._prov_target_dirty = false;
+ return;
+ }
+ this._prov_committing = true;
+ try
+ {
+ const resp = await fetch("/orch/provisioner/target", {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({ target_cores: value }),
+ });
+ if (resp.ok)
+ {
+ this._prov_target_dirty = false;
+ console.log("Target cores set to", value);
+ }
+ else
+ {
+ const text = await resp.text();
+ console.error("Failed to set target cores: HTTP", resp.status, text);
+ }
+ }
+ catch (e)
+ {
+ console.error("Failed to set target cores:", e);
+ }
+ finally
+ {
+ this._prov_committing = false;
+ }
+ }
+
_metric(parent, value, label, hero = false)
{
const m = parent.tag().classify("tile-metric");
diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp
index c5f8724ca..108685eb9 100644
--- a/src/zenserver/main.cpp
+++ b/src/zenserver/main.cpp
@@ -14,7 +14,6 @@
#include <zencore/memory/memorytrace.h>
#include <zencore/memory/newdelete.h>
#include <zencore/scopeguard.h>
-#include <zencore/sentryintegration.h>
#include <zencore/session.h>
#include <zencore/string.h>
#include <zencore/thread.h>
@@ -169,7 +168,12 @@ AppMain(int argc, char* argv[])
if (IsDir(ServerOptions.DataDir))
{
ZEN_CONSOLE_INFO("Deleting files from '{}' ({})", ServerOptions.DataDir, DeleteReason);
- DeleteDirectories(ServerOptions.DataDir);
+ std::error_code Ec;
+ DeleteDirectories(ServerOptions.DataDir, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("could not fully clean '{}': {} (continuing anyway)", ServerOptions.DataDir, Ec.message());
+ }
}
}