aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde')
-rw-r--r--src/zenhorde/README.md17
-rw-r--r--src/zenhorde/hordeagent.cpp561
-rw-r--r--src/zenhorde/hordeagent.h127
-rw-r--r--src/zenhorde/hordeagentmessage.cpp581
-rw-r--r--src/zenhorde/hordeagentmessage.h153
-rw-r--r--src/zenhorde/hordebundle.cpp63
-rw-r--r--src/zenhorde/hordeclient.cpp89
-rw-r--r--src/zenhorde/hordecomputebuffer.cpp454
-rw-r--r--src/zenhorde/hordecomputebuffer.h136
-rw-r--r--src/zenhorde/hordecomputechannel.cpp37
-rw-r--r--src/zenhorde/hordecomputechannel.h32
-rw-r--r--src/zenhorde/hordecomputesocket.cpp410
-rw-r--r--src/zenhorde/hordecomputesocket.h109
-rw-r--r--src/zenhorde/hordeconfig.cpp16
-rw-r--r--src/zenhorde/hordeprovisioner.cpp682
-rw-r--r--src/zenhorde/hordetransport.cpp153
-rw-r--r--src/zenhorde/hordetransport.h67
-rw-r--r--src/zenhorde/hordetransportaes.cpp718
-rw-r--r--src/zenhorde/hordetransportaes.h51
-rw-r--r--src/zenhorde/include/zenhorde/hordeclient.h32
-rw-r--r--src/zenhorde/include/zenhorde/hordeconfig.h37
-rw-r--r--src/zenhorde/include/zenhorde/hordeprovisioner.h92
-rw-r--r--src/zenhorde/xmake.lua2
23 files changed, 2492 insertions, 2127 deletions
diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md
new file mode 100644
index 000000000..13beaa968
--- /dev/null
+++ b/src/zenhorde/README.md
@@ -0,0 +1,17 @@
+# Horde Compute integration
+
+Zen compute can use Horde to provision runner nodes.
+
+## Launch a coordinator instance
+
+Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface which allows zenserver instances to discover endpoints which they can submit actions to.
+
+```bash
+zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio
+```
+
+## Use a coordinator
+
+```bash
+zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558
+```
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp
index 819b2d0cb..029b98e55 100644
--- a/src/zenhorde/hordeagent.cpp
+++ b/src/zenhorde/hordeagent.cpp
@@ -8,290 +8,479 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <cstring>
-#include <unordered_map>
namespace zen::horde {
-HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
-{
- ZEN_TRACE_CPU("HordeAgent::Connect");
+// --- AsyncHordeAgent ---
- auto Transport = std::make_unique<TcpComputeTransport>(Info);
- if (!Transport->IsValid())
+static const char*
+GetStateName(AsyncHordeAgent::State S)
+{
+ switch (S)
{
- ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
- return;
+ case AsyncHordeAgent::State::Idle:
+ return "idle";
+ case AsyncHordeAgent::State::Connecting:
+ return "connect";
+ case AsyncHordeAgent::State::WaitAgentAttach:
+ return "agent-attach";
+ case AsyncHordeAgent::State::SentFork:
+ return "fork";
+ case AsyncHordeAgent::State::WaitChildAttach:
+ return "child-attach";
+ case AsyncHordeAgent::State::Uploading:
+ return "upload";
+ case AsyncHordeAgent::State::Executing:
+ return "execute";
+ case AsyncHordeAgent::State::Polling:
+ return "poll";
+ case AsyncHordeAgent::State::Done:
+ return "done";
+ default:
+ return "unknown";
}
+}
- // The 64-byte nonce is always sent unencrypted as the first thing on the wire.
- // The Horde agent uses this to identify which lease this connection belongs to.
- Transport->Send(Info.Nonce, sizeof(Info.Nonce));
+AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async"))
+{
+}
- std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport);
- if (Info.EncryptionMode == Encryption::AES)
+AsyncHordeAgent::~AsyncHordeAgent()
+{
+ Cancel();
+}
+
+void
+AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone)
+{
+ m_Config = std::move(Config);
+ m_OnDone = std::move(OnDone);
+ m_State = State::Connecting;
+ DoConnect();
+}
+
+void
+AsyncHordeAgent::Cancel()
+{
+ m_Cancelled = true;
+ if (m_Socket)
{
- FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport));
- if (!FinalTransport->IsValid())
- {
- ZEN_WARN("failed to create AES transport");
- return;
- }
+ m_Socket->Close();
+ }
+ else if (m_TcpTransport)
+ {
+ // Cancelled before handshake completed - tear down the pending TCP connect.
+ m_TcpTransport->Close();
}
+}
- // Create multiplexed socket and channels
- m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport));
+void
+AsyncHordeAgent::DoConnect()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect");
- // Channel 0 is the agent control channel (handles Attach/Fork handshake).
- // Channel 100 is the child I/O channel (handles file upload and remote execution).
- Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0);
- Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100);
+ m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext);
+
+ auto Self = shared_from_this();
+ m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); });
+}
- if (!AgentComputeChannel || !ChildComputeChannel)
+void
+AsyncHordeAgent::OnConnected(const std::error_code& Ec)
+{
+ if (Ec || m_Cancelled)
{
- ZEN_WARN("failed to create compute channels");
+ if (Ec)
+ {
+ ZEN_WARN("connect failed: {}", Ec.message());
+ }
+ Finish(false);
return;
}
- m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel));
- m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel));
+ // Optionally wrap with AES encryption
+ std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport);
+ if (m_Config.Machine.EncryptionMode == Encryption::AES)
+ {
+ FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext);
+ }
+
+ // Create the multiplexed socket and register channels. Ownership of the transport
+ // moves into the socket here - no need to retain a separate m_Transport field.
+ m_Socket = std::make_shared<AsyncComputeSocket>(std::move(FinalTransport), m_IoContext);
+
+ m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext);
+ m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext);
+
+ m_Socket->RegisterChannel(
+ 0,
+ [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); },
+ [this]() { m_AgentChannel->OnDetach(); });
- m_IsValid = true;
+ m_Socket->RegisterChannel(
+ 100,
+ [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); },
+ [this]() { m_ChildChannel->OnDetach(); });
+
+ m_Socket->StartRecvPump();
+
+ m_State = State::WaitAgentAttach;
+ DoWaitAgentAttach();
}
-HordeAgent::~HordeAgent()
+void
+AsyncHordeAgent::DoWaitAgentAttach()
{
- CloseConnection();
+ auto Self = shared_from_this();
+ m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnAgentResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::BeginCommunication()
+void
+AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
{
- ZEN_TRACE_CPU("HordeAgent::BeginCommunication");
-
- if (!m_IsValid)
+ if (m_Cancelled)
{
- return false;
+ Finish(false);
+ return;
}
- // Start the send/recv pump threads
- m_Socket->StartCommunication();
-
- // Wait for Attach on agent channel
- AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on agent channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- // Fork tells the remote agent to create child channel 100 with a 4MB buffer.
- // After this, the agent will send an Attach on the child channel.
+ m_State = State::SentFork;
+ DoSendFork();
+}
+
+void
+AsyncHordeAgent::DoSendFork()
+{
m_AgentChannel->Fork(100, 4 * 1024 * 1024);
- // Wait for Attach on child channel
- Type = m_ChildChannel->ReadResponse(5000);
+ m_State = State::WaitChildAttach;
+ DoWaitChildAttach();
+}
+
+void
+AsyncHordeAgent::DoWaitChildAttach()
+{
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnChildAttachResponse(Type, Data, Size);
+ });
+}
+
+void
+AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on child channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- return true;
+ m_State = State::Uploading;
+ m_CurrentBundleIndex = 0;
+ DoUploadNext();
}
-bool
-HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
+void
+AsyncHordeAgent::DoUploadNext()
{
- ZEN_TRACE_CPU("HordeAgent::UploadBinaries");
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
+ if (m_CurrentBundleIndex >= m_Config.Bundles.size())
+ {
+ // All bundles uploaded - proceed to execute
+ m_State = State::Executing;
+ DoExecute();
+ return;
+ }
- m_ChildChannel->UploadFiles("", BundleLocator.c_str());
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ m_ChildChannel->UploadFiles("", Locator.c_str());
- std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;
+ // Enter the ReadBlob/Blob upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
- std::string Key(Locator);
+void
+AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
+ if (Type == AgentMessageType::None)
+ {
+ if (m_ChildChannel->IsDetached())
{
- return It->second.get();
+ ZEN_WARN("connection lost during upload");
+ Finish(false);
+ return;
}
+ // Timeout - retry read
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+ return;
+ }
- const std::filesystem::path Path = BundleDir / (Key + ".blob");
- std::error_code Ec;
- auto File = std::make_unique<BasicFile>();
- File->Open(Path, BasicFile::Mode::kRead, Ec);
+ if (Type == AgentMessageType::WriteFilesResponse)
+ {
+ // This bundle upload is done - move to next
+ ++m_CurrentBundleIndex;
+ DoUploadNext();
+ return;
+ }
- if (Ec)
+ if (Type == AgentMessageType::Exception)
+ {
+ ExceptionInfo Ex;
+ if (!AsyncAgentMessageChannel::ReadException(Data, Size, Ex))
{
- ZEN_ERROR("cannot read blob file: '{}'", Path);
- return nullptr;
+ ZEN_ERROR("malformed Exception message during upload (size={})", Size);
+ Finish(false);
+ return;
}
+ ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
+ Finish(false);
+ return;
+ }
- BasicFile* Ptr = File.get();
- BlobFiles.emplace(std::move(Key), std::move(File));
- return Ptr;
- };
-
- // The upload protocol is request-driven: we send WriteFiles, then the remote agent
- // sends ReadBlob requests for each blob it needs. We respond with Blob data until
- // the agent sends WriteFilesResponse indicating the upload is complete.
- constexpr int32_t ReadResponseTimeoutMs = 1000;
+ if (Type != AgentMessageType::ReadBlob)
+ {
+ ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
+ Finish(false);
+ return;
+ }
- for (;;)
+ // Handle ReadBlob request
+ BlobRequest Req;
+ if (!AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req))
{
- bool TimedOut = false;
+ ZEN_ERROR("malformed ReadBlob message during upload (size={})", Size);
+ Finish(false);
+ return;
+ }
- if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
- {
- if (TimedOut)
- {
- continue;
- }
- // End of stream - check if it was a successful upload
- if (Type == AgentMessageType::WriteFilesResponse)
- {
- return true;
- }
- else if (Type == AgentMessageType::Exception)
- {
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
- }
- else
- {
- ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
- }
- return false;
- }
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob");
- BlobRequest Req;
- m_ChildChannel->ReadBlobRequest(Req);
+ std::error_code FsEc;
+ BasicFile File;
+ File.Open(BlobPath, BasicFile::Mode::kRead, FsEc);
- BasicFile* File = FindOrOpenBlob(Req.Locator);
- if (!File)
- {
- return false;
- }
+ if (FsEc)
+ {
+ ZEN_ERROR("cannot read blob file: '{}'", BlobPath);
+ Finish(false);
+ return;
+ }
- // Read from offset to end of file
- const uint64_t TotalSize = File->FileSize();
- const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
- if (Offset >= TotalSize)
- {
- ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
- m_ChildChannel->Blob(nullptr, 0);
- continue;
- }
+ const uint64_t TotalSize = File.FileSize();
+ const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
+ if (Offset >= TotalSize)
+ {
+ ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
+ m_ChildChannel->Blob(nullptr, 0);
+ }
+ else
+ {
+ const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
+ m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize());
+ }
+
+ // Continue the upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
- m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
+void
+AsyncHordeAgent::DoExecute()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute");
+
+ std::vector<const char*> ArgPtrs;
+ ArgPtrs.reserve(m_Config.Args.size());
+ for (const std::string& Arg : m_Config.Args)
+ {
+ ArgPtrs.push_back(Arg.c_str());
}
+
+ m_ChildChannel->Execute(m_Config.Executable.c_str(),
+ ArgPtrs.data(),
+ ArgPtrs.size(),
+ nullptr,
+ nullptr,
+ 0,
+ m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+
+ ZEN_INFO("remote execution started on [{}:{}] lease={}",
+ m_Config.Machine.GetConnectionAddress(),
+ m_Config.Machine.GetConnectionPort(),
+ m_Config.Machine.LeaseId);
+
+ m_State = State::Polling;
+ DoPoll();
}
void
-HordeAgent::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- bool UseWine)
+AsyncHordeAgent::DoPoll()
{
- ZEN_TRACE_CPU("HordeAgent::Execute");
- m_ChildChannel
- ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnPollResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::Poll(bool LogOutput)
+void
+AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
{
- constexpr int32_t ReadResponseTimeoutMs = 100;
- AgentMessageType Type;
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
+ switch (Type)
{
- switch (Type)
- {
- case AgentMessageType::ExecuteOutput:
+ case AgentMessageType::None:
+ if (m_ChildChannel->IsDetached())
+ {
+ ZEN_WARN("connection lost during execution");
+ Finish(false);
+ }
+ else
+ {
+ // Timeout - poll again
+ DoPoll();
+ }
+ break;
+
+ case AgentMessageType::ExecuteOutput:
+ // Silently consume remote stdout (matching LogOutput=false in provisioner)
+ DoPoll();
+ break;
+
+ case AgentMessageType::ExecuteResult:
+ {
+ int32_t ExitCode = -1;
+ if (!AsyncAgentMessageChannel::ReadExecuteResult(Data, Size, ExitCode))
{
- if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
- {
- const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
- size_t ResponseSize = m_ChildChannel->GetResponseSize();
-
- // Trim trailing newlines
- while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
- {
- --ResponseSize;
- }
-
- if (ResponseSize > 0)
- {
- const std::string_view Output(ResponseData, ResponseSize);
- ZEN_INFO("[remote] {}", Output);
- }
- }
+ // A remote with a malformed ExecuteResult cannot be trusted to report
+ // process outcome - treat as a protocol error and tear down rather than
+ // silently recording \"exited with -1\".
+ ZEN_ERROR("malformed ExecuteResult (size={}, lease={}) - disconnecting", Size, m_Config.Machine.LeaseId);
+ Finish(false);
break;
}
+ ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId);
+ Finish(ExitCode == 0, ExitCode);
+ }
+ break;
- case AgentMessageType::ExecuteResult:
+ case AgentMessageType::Exception:
+ {
+ ExceptionInfo Ex;
+ if (AsyncAgentMessageChannel::ReadException(Data, Size, Ex))
{
- if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
- {
- int32_t ExitCode;
- memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
- ZEN_INFO("remote process exited with code {}", ExitCode);
- }
- m_IsValid = false;
- return false;
+ ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
}
-
- case AgentMessageType::Exception:
+ else
{
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
- m_HasErrors = true;
- break;
+ ZEN_ERROR("malformed Exception message (size={})", Size);
}
+ Finish(false);
+ }
+ break;
- default:
- break;
- }
+ default:
+ DoPoll();
+ break;
}
-
- return m_IsValid && !m_HasErrors;
}
void
-HordeAgent::CloseConnection()
+AsyncHordeAgent::Finish(bool Success, int32_t ExitCode)
{
- if (m_ChildChannel)
+ if (m_State == State::Done)
{
- m_ChildChannel->Close();
+ return; // Already finished
}
- if (m_AgentChannel)
+
+ if (!Success)
{
- m_AgentChannel->Close();
+ ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId);
}
-}
-bool
-HordeAgent::IsValid() const
-{
- return m_IsValid && !m_HasErrors;
+ m_State = State::Done;
+
+ if (m_Socket)
+ {
+ m_Socket->Close();
+ }
+
+ if (m_OnDone)
+ {
+ AsyncAgentResult Result;
+ Result.Success = Success;
+ Result.ExitCode = ExitCode;
+ Result.CoreCount = m_Config.Machine.LogicalCores;
+
+ auto Handler = std::move(m_OnDone);
+ m_OnDone = nullptr;
+ Handler(Result);
+ }
}
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h
index e0ae89ead..a5b3248ab 100644
--- a/src/zenhorde/hordeagent.h
+++ b/src/zenhorde/hordeagent.h
@@ -10,68 +10,107 @@
#include <zencore/logbase.h>
#include <filesystem>
+#include <functional>
#include <memory>
#include <string>
+#include <vector>
+
+namespace asio {
+class io_context;
+}
namespace zen::horde {
-/** Manages the lifecycle of a single Horde compute agent.
+class AsyncComputeTransport;
+
+/** Result passed to the completion handler when an async agent finishes. */
+struct AsyncAgentResult
+{
+ bool Success = false;
+ int32_t ExitCode = -1;
+ uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine
+};
+
+/** Completion handler for async agent lifecycle. */
+using AsyncAgentCompletionHandler = std::function<void(const AsyncAgentResult&)>;
+
+/** Configuration for launching a remote zenserver instance via an async agent. */
+struct AsyncAgentConfig
+{
+ MachineInfo Machine;
+ std::vector<std::pair<std::string, std::filesystem::path>> Bundles; ///< (locator, bundleDir) pairs
+ std::string Executable;
+ std::vector<std::string> Args;
+ bool UseWine = false;
+};
+
+/** Async agent that manages the full lifecycle of a single Horde compute connection.
*
- * Handles the full connection sequence for one provisioned machine:
- * 1. Connect via TCP transport (with optional AES encryption wrapping)
- * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100)
- * 3. Perform the Attach/Fork handshake to establish the child channel
- * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol
- * 5. Execute zenserver remotely via ExecuteV2
- * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code)
+ * Driven by a state machine using callbacks on a shared io_context - no dedicated
+ * threads. Call Start() to begin the connection/handshake/upload/execute/poll
+ * sequence. The completion handler is invoked when the remote process exits or
+ * an error occurs.
*/
-class HordeAgent
+class AsyncHordeAgent : public std::enable_shared_from_this<AsyncHordeAgent>
{
public:
- explicit HordeAgent(const MachineInfo& Info);
- ~HordeAgent();
+ AsyncHordeAgent(asio::io_context& IoContext);
+ ~AsyncHordeAgent();
- HordeAgent(const HordeAgent&) = delete;
- HordeAgent& operator=(const HordeAgent&) = delete;
+ AsyncHordeAgent(const AsyncHordeAgent&) = delete;
+ AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete;
- /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel).
- * Returns false if the handshake times out or receives an unexpected message. */
- bool BeginCommunication();
+ /** Start the full agent lifecycle. The completion handler is called exactly once. */
+ void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone);
- /** Upload binary files to the remote agent.
- * @param BundleDir Directory containing .blob files.
- * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */
- bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator);
+ /** Cancel in-flight operations. The completion handler is still called (with Success=false). */
+ void Cancel();
- /** Execute a command on the remote machine. */
- void Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir = nullptr,
- const char* const* EnvVars = nullptr,
- size_t NumEnvVars = 0,
- bool UseWine = false);
+ const MachineInfo& GetMachineInfo() const { return m_Config.Machine; }
- /** Poll for output and results. Returns true if the agent is still running.
- * When LogOutput is true, remote stdout is logged via ZEN_INFO. */
- bool Poll(bool LogOutput = true);
-
- void CloseConnection();
- bool IsValid() const;
-
- const MachineInfo& GetMachineInfo() const { return m_MachineInfo; }
+ enum class State
+ {
+ Idle,
+ Connecting,
+ WaitAgentAttach,
+ SentFork,
+ WaitChildAttach,
+ Uploading,
+ Executing,
+ Polling,
+ Done
+ };
private:
LoggerRef Log() { return m_Log; }
- std::unique_ptr<ComputeSocket> m_Socket;
- std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control
- std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O
-
- LoggerRef m_Log;
- bool m_IsValid = false;
- bool m_HasErrors = false;
- MachineInfo m_MachineInfo;
+ void DoConnect();
+ void OnConnected(const std::error_code& Ec);
+ void DoWaitAgentAttach();
+ void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoSendFork();
+ void DoWaitChildAttach();
+ void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoUploadNext();
+ void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void DoExecute();
+ void DoPoll();
+ void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size);
+ void Finish(bool Success, int32_t ExitCode = -1);
+
+ asio::io_context& m_IoContext;
+ LoggerRef m_Log;
+ State m_State = State::Idle;
+ bool m_Cancelled = false;
+
+ AsyncAgentConfig m_Config;
+ AsyncAgentCompletionHandler m_OnDone;
+ size_t m_CurrentBundleIndex = 0;
+
+ std::unique_ptr<AsyncTcpComputeTransport> m_TcpTransport;
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ std::unique_ptr<AsyncAgentMessageChannel> m_AgentChannel;
+ std::unique_ptr<AsyncAgentMessageChannel> m_ChildChannel;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
index 998134a96..bef1bdda8 100644
--- a/src/zenhorde/hordeagentmessage.cpp
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -4,337 +4,496 @@
#include <zencore/intmath.h>
-#include <cassert>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <zencore/except_fmt.h>
+#include <zencore/logging.h>
+
#include <cstring>
+#include <limits>
namespace zen::horde {
-AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
-{
-}
-
-AgentMessageChannel::~AgentMessageChannel() = default;
+// --- AsyncAgentMessageChannel ---
-void
-AgentMessageChannel::Close()
+AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext)
+: m_Socket(std::move(Socket))
+, m_ChannelId(ChannelId)
+, m_IoContext(IoContext)
+, m_TimeoutTimer(std::make_unique<asio::steady_timer>(m_Socket->GetStrand()))
{
- CreateMessage(AgentMessageType::None, 0);
- FlushMessage();
}
-void
-AgentMessageChannel::Ping()
+AsyncAgentMessageChannel::~AsyncAgentMessageChannel()
{
- CreateMessage(AgentMessageType::Ping, 0);
- FlushMessage();
+ if (m_TimeoutTimer)
+ {
+ m_TimeoutTimer->cancel();
+ }
}
-void
-AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+// --- Message building helpers ---
+
+std::vector<uint8_t>
+AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload)
{
- CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
- WriteInt32(ChannelId);
- WriteInt32(BufferSize);
- FlushMessage();
+ std::vector<uint8_t> Buf;
+ Buf.reserve(MessageHeaderLength + ReservePayload);
+ Buf.push_back(static_cast<uint8_t>(Type));
+ Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder
+ return Buf;
}
void
-AgentMessageChannel::Attach()
+AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg)
{
- CreateMessage(AgentMessageType::Attach, 0);
- FlushMessage();
+ const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength);
+ memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t));
+ m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg));
}
void
-AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value)
{
- CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
- WriteString(Path);
- WriteString(Locator);
- FlushMessage();
+ const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value);
+ Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int));
}
-void
-AgentMessageChannel::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- ExecuteProcessFlags Flags)
+int
+AsyncAgentMessageChannel::ReadInt32(ReadCursor& C)
{
- size_t RequiredSize = 50 + strlen(Exe);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- RequiredSize += strlen(Args[i]) + 10;
- }
- if (WorkingDir)
- {
- RequiredSize += strlen(WorkingDir) + 10;
- }
- for (size_t i = 0; i < NumEnvVars; ++i)
+ if (!C.CheckAvailable(sizeof(int32_t)))
{
- RequiredSize += strlen(EnvVars[i]) + 20;
+ return 0;
}
+ int32_t Value;
+ memcpy(&Value, C.Pos, sizeof(int32_t));
+ C.Pos += sizeof(int32_t);
+ return Value;
+}
- CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
- WriteString(Exe);
+void
+AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length)
+{
+ Buf.insert(Buf.end(), Data, Data + Length);
+}
- WriteUnsignedVarInt(NumArgs);
- for (size_t i = 0; i < NumArgs; ++i)
+const uint8_t*
+AsyncAgentMessageChannel::ReadFixedLengthBytes(ReadCursor& C, size_t Length)
+{
+ if (!C.CheckAvailable(Length))
{
- WriteString(Args[i]);
+ return nullptr;
}
+ const uint8_t* Data = C.Pos;
+ C.Pos += Length;
+ return Data;
+}
- WriteOptionalString(WorkingDir);
-
- // ExecuteV2 protocol requires env vars as separate key/value pairs.
- // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
- WriteUnsignedVarInt(NumEnvVars);
- for (size_t i = 0; i < NumEnvVars; ++i)
+size_t
+AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+{
+ if (Value == 0)
{
- const char* Eq = strchr(EnvVars[i], '=');
- assert(Eq != nullptr);
-
- WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
- if (*(Eq + 1) == '\0')
- {
- WriteOptionalString(nullptr);
- }
- else
- {
- WriteOptionalString(Eq + 1);
- }
+ return 1;
}
-
- WriteInt32(static_cast<int>(Flags));
- FlushMessage();
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
void
-AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value)
{
- // Blob responses are chunked to fit within the compute buffer's chunk size.
- // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
- const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
- for (size_t ChunkOffset = 0; ChunkOffset < Length;)
- {
- const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
-
- CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
- WriteInt32(static_cast<int>(ChunkOffset));
- WriteInt32(static_cast<int>(Length));
- WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
- FlushMessage();
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ const size_t StartPos = Buf.size();
+ Buf.resize(StartPos + ByteCount);
- ChunkOffset += ChunkLength;
+ uint8_t* Output = Buf.data() + StartPos;
+ for (size_t i = 1; i < ByteCount; ++i)
+ {
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
}
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
}
-AgentMessageType
-AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+size_t
+AsyncAgentMessageChannel::ReadUnsignedVarInt(ReadCursor& C)
{
- // Deferred advance: the previous response's buffer is only released when the next
- // ReadResponse is called. This allows callers to read response data between calls
- // without copying, since the pointer comes directly from the ring buffer.
- if (m_ResponseData)
+ // Need at least the leading byte to determine the encoded length.
+ if (!C.CheckAvailable(1))
{
- m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
- m_ResponseData = nullptr;
- m_ResponseLength = 0;
+ return 0;
}
- const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
- if (!Header)
+ const uint8_t FirstByte = C.Pos[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+
+ // The encoded length implied by the leading 0xFF-run may be 1..9 bytes; ensure the remaining bytes are in-bounds.
+ if (!C.CheckAvailable(NumBytes))
{
- return AgentMessageType::None;
+ return 0;
}
- uint32_t Length;
- memcpy(&Length, Header + 1, sizeof(uint32_t));
-
- Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
- if (!Header)
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
{
- return AgentMessageType::None;
+ Value <<= 8;
+ Value |= C.Pos[i];
}
- m_ResponseType = static_cast<AgentMessageType>(Header[0]);
- m_ResponseData = Header + MessageHeaderLength;
- m_ResponseLength = Length;
-
- return m_ResponseType;
+ C.Pos += NumBytes;
+ return Value;
}
void
-AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text)
{
- assert(m_ResponseType == AgentMessageType::Exception);
- const uint8_t* Pos = m_ResponseData;
- Ex.Message = ReadString(&Pos);
- Ex.Description = ReadString(&Pos);
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
}
-int
-AgentMessageChannel::ReadExecuteResult()
+void
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text)
{
- assert(m_ResponseType == AgentMessageType::ExecuteResult);
- const uint8_t* Pos = m_ResponseData;
- return ReadInt32(&Pos);
+ WriteUnsignedVarInt(Buf, Text.size());
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
}
-void
-AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+std::string_view
+AsyncAgentMessageChannel::ReadString(ReadCursor& C)
{
- assert(m_ResponseType == AgentMessageType::ReadBlob);
- const uint8_t* Pos = m_ResponseData;
- Req.Locator = ReadString(&Pos);
- Req.Offset = ReadUnsignedVarInt(&Pos);
- Req.Length = ReadUnsignedVarInt(&Pos);
+ const size_t Length = ReadUnsignedVarInt(C);
+ const uint8_t* Start = ReadFixedLengthBytes(C, Length);
+ if (C.ParseError || !Start)
+ {
+ return {};
+ }
+ return std::string_view(reinterpret_cast<const char*>(Start), Length);
}
void
-AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text)
{
- m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
- m_RequestData[0] = static_cast<uint8_t>(Type);
- m_MaxRequestSize = MaxLength;
- m_RequestSize = 0;
+ if (!Text)
+ {
+ WriteUnsignedVarInt(Buf, 0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length + 1);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
+ }
}
+// --- Send methods ---
+
void
-AgentMessageChannel::FlushMessage()
+AsyncAgentMessageChannel::Close()
{
- const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
- memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
- m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
- m_RequestSize = 0;
- m_MaxRequestSize = 0;
- m_RequestData = nullptr;
+ auto Msg = BeginMessage(AgentMessageType::None, 0);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteInt32(int Value)
+AsyncAgentMessageChannel::Ping()
{
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+ auto Msg = BeginMessage(AgentMessageType::Ping, 0);
+ FinalizeAndSend(std::move(Msg));
}
-int
-AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize)
{
- int Value;
- memcpy(&Value, *Pos, sizeof(int));
- *Pos += sizeof(int);
- return Value;
+ auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(Msg, ChannelId);
+ WriteInt32(Msg, BufferSize);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::Attach()
{
- assert(m_RequestSize + Length <= m_MaxRequestSize);
- memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
- m_RequestSize += Length;
+ auto Msg = BeginMessage(AgentMessageType::Attach, 0);
+ FinalizeAndSend(std::move(Msg));
}
-const uint8_t*
-AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+void
+AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
{
- const uint8_t* Data = *Pos;
- *Pos += Length;
- return Data;
+ auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Msg, Path);
+ WriteString(Msg, Locator);
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+void
+AsyncAgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
{
- if (Value == 0)
+ size_t ReserveSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- return 1;
+ ReserveSize += strlen(Args[i]) + 10;
+ }
+ if (WorkingDir)
+ {
+ ReserveSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ ReserveSize += strlen(EnvVars[i]) + 20;
}
- return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
-}
-void
-AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
-{
- const size_t ByteCount = MeasureUnsignedVarInt(Value);
- assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+ auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize);
+ WriteString(Msg, Exe);
- uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
- for (size_t i = 1; i < ByteCount; ++i)
+ WriteUnsignedVarInt(Msg, NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- Output[ByteCount - i] = static_cast<uint8_t>(Value);
- Value >>= 8;
+ WriteString(Msg, Args[i]);
}
- Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
- m_RequestSize += ByteCount;
+ WriteOptionalString(Msg, WorkingDir);
+
+ WriteUnsignedVarInt(Msg, NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ if (Eq == nullptr)
+ {
+ // assert() would be compiled out in release and leave *(Eq+1) as UB -
+ // refuse to build the message for a malformed KEY=VALUE string instead.
+ throw zen::runtime_error("horde agent env var at index {} missing '=' separator", i);
+ }
+
+ WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(Msg, nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Msg, Eq + 1);
+ }
+ }
+
+ WriteInt32(Msg, static_cast<int>(Flags));
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
{
- const uint8_t* Data = *Pos;
- const uint8_t FirstByte = Data[0];
- const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+ static constexpr size_t MaxBlobChunkSize = 512 * 1024;
- size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
- for (size_t i = 1; i < NumBytes; ++i)
+ // The Horde ReadBlobResponse wire format encodes both the chunk Offset and the total
+ // Length as int32. Lengths of 2 GiB or more would wrap to negative and confuse the
+ // remote parser. Refuse the send rather than produce a protocol violation.
+ if (Length > static_cast<size_t>(std::numeric_limits<int32_t>::max()))
{
- Value <<= 8;
- Value |= Data[i];
+ throw zen::runtime_error("horde ReadBlobResponse length {} exceeds int32 wire limit", Length);
}
- *Pos += NumBytes;
- return Value;
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ {
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize);
+
+ auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(Msg, static_cast<int32_t>(ChunkOffset));
+ WriteInt32(Msg, static_cast<int32_t>(Length));
+ WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength);
+ FinalizeAndSend(std::move(Msg));
+
+ ChunkOffset += ChunkLength;
+ }
}
-size_t
-AgentMessageChannel::MeasureString(const char* Text) const
+// --- Async response reading ---
+
+void
+AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler)
{
- const size_t Length = strlen(Text);
- return MeasureUnsignedVarInt(Length) + Length;
+ // Serialize all access to m_IncomingFrames / m_PendingHandler / m_TimeoutTimer onto
+ // the socket's strand; OnFrame/OnDetach also run on that strand. Without this, the
+ // timer wait completion would run on a bare io_context thread (3 concurrent run()
+ // loops in the provisioner) and race with OnFrame on m_PendingHandler.
+ asio::dispatch(m_Socket->GetStrand(), [this, TimeoutMs, Handler = std::move(Handler)]() mutable {
+ if (!m_IncomingFrames.empty())
+ {
+ std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front());
+ m_IncomingFrames.pop_front();
+
+ if (Frame.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]);
+ const uint8_t* Data = Frame.data() + MessageHeaderLength;
+ size_t Size = Frame.size() - MessageHeaderLength;
+ asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable {
+ // The Frame is captured to keep Data pointer valid
+ Handler(Type, Data, Size);
+ });
+ }
+ else
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ }
+ return;
+ }
+
+ if (m_Detached)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ return;
+ }
+
+ // No frames queued - store pending handler and arm timeout
+ m_PendingHandler = std::move(Handler);
+
+ if (TimeoutMs >= 0)
+ {
+ m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_TimeoutTimer->async_wait(asio::bind_executor(m_Socket->GetStrand(), [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return; // Cancelled - frame arrived before timeout
+ }
+
+ // Already on the strand: safe to mutate m_PendingHandler.
+ if (m_PendingHandler)
+ {
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }));
+ }
+ });
}
void
-AgentMessageChannel::WriteString(const char* Text)
+AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data)
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ if (m_PendingHandler)
+ {
+ // Cancel the timeout timer
+ m_TimeoutTimer->cancel();
+
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+
+ if (Data.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Data[0]);
+ const uint8_t* Payload = Data.data() + MessageHeaderLength;
+ size_t PayloadSize = Data.size() - MessageHeaderLength;
+ Handler(Type, Payload, PayloadSize);
+ }
+ else
+ {
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }
+ else
+ {
+ m_IncomingFrames.push_back(std::move(Data));
+ }
}
void
-AgentMessageChannel::WriteString(std::string_view Text)
+AsyncAgentMessageChannel::OnDetach()
{
- WriteUnsignedVarInt(Text.size());
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ m_Detached = true;
+
+ if (m_PendingHandler)
+ {
+ m_TimeoutTimer->cancel();
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
}
-std::string_view
-AgentMessageChannel::ReadString(const uint8_t** Pos)
+// --- Response parsing helpers ---
+
+bool
+AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex)
{
- const size_t Length = ReadUnsignedVarInt(Pos);
- const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
- return std::string_view(Start, Length);
+ ReadCursor C{Data, Data + Size, false};
+ Ex.Message = ReadString(C);
+ Ex.Description = ReadString(C);
+ if (C.ParseError)
+ {
+ Ex = {};
+ return false;
+ }
+ return true;
}
-void
-AgentMessageChannel::WriteOptionalString(const char* Text)
+bool
+AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode)
{
- // Optional strings use length+1 encoding: 0 means null/absent,
- // N>0 means a string of length N-1 follows. This matches the UE
- // FAgentMessageChannel serialization convention.
- if (!Text)
+ ReadCursor C{Data, Data + Size, false};
+ OutExitCode = ReadInt32(C);
+ return !C.ParseError;
+}
+
+static bool
+IsSafeLocator(std::string_view Locator)
+{
+ // Reject empty, overlong, path-separator-containing, parent-relative, absolute, or
+ // control-character-containing locators. The locator is used as a filename component
+ // joined with a trusted BundleDir, so the only safe characters are a restricted
+ // filename alphabet.
+ if (Locator.empty() || Locator.size() > 255)
{
- WriteUnsignedVarInt(0);
+ return false;
}
- else
+ if (Locator == "." || Locator == "..")
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length + 1);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ return false;
+ }
+ for (char Ch : Locator)
+ {
+ const unsigned char U = static_cast<unsigned char>(Ch);
+ if (U < 0x20 || U == 0x7F)
+ {
+ return false; // control / NUL / DEL
+ }
+ if (Ch == '/' || Ch == '\\' || Ch == ':')
+ {
+ return false; // path separators / drive letters
+ }
+ }
+ // Disallow leading/trailing dot or whitespace (Windows quirks + hidden-file dodges)
+ if (Locator.front() == '.' || Locator.front() == ' ' || Locator.back() == '.' || Locator.back() == ' ')
+ {
+ return false;
+ }
+ return true;
+}
+
+bool
+AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req)
+{
+ ReadCursor C{Data, Data + Size, false};
+ Req.Locator = ReadString(C);
+ Req.Offset = ReadUnsignedVarInt(C);
+ Req.Length = ReadUnsignedVarInt(C);
+ if (C.ParseError || !IsSafeLocator(Req.Locator))
+ {
+ Req = {};
+ return false;
}
+ return true;
}
} // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h
index 38c4375fd..fb7c5ed29 100644
--- a/src/zenhorde/hordeagentmessage.h
+++ b/src/zenhorde/hordeagentmessage.h
@@ -4,14 +4,22 @@
#include <zenbase/zenbase.h>
-#include "hordecomputechannel.h"
+#include "hordecomputesocket.h"
#include <cstddef>
#include <cstdint>
+#include <deque>
+#include <functional>
+#include <memory>
#include <string>
#include <string_view>
+#include <system_error>
#include <vector>
+namespace asio {
+class io_context;
+} // namespace asio
+
namespace zen::horde {
/** Agent message types matching the UE EAgentMessageType byte values.
@@ -55,45 +63,34 @@ struct BlobRequest
size_t Length = 0;
};
-/** Channel for sending and receiving agent messages over a ComputeChannel.
+/** Handler for async response reads. Receives the message type and a view of the payload data.
+ * The payload vector is valid until the next AsyncReadResponse call. */
+using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>;
+
+/** Async channel for sending and receiving agent messages over an AsyncComputeSocket.
*
- * Implements the Horde agent message protocol, matching the UE
- * FAgentMessageChannel serialization format exactly. Messages are framed as
- * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8;
- * integers use variable-length encoding.
+ * Send methods build messages into vectors and submit them via AsyncComputeSocket.
+ * Receives are delivered via the socket's FrameHandler callback and queued internally.
+ * AsyncReadResponse checks the queue and invokes the handler, with optional timeout.
*
- * The protocol has two directions:
- * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob
- * - Responses (remote -> initiator): ReadResponse returns the type, then call the
- * appropriate Read* method to parse the payload.
+ * All operations must be externally serialized (e.g. via the socket's strand).
*/
-class AgentMessageChannel
+class AsyncAgentMessageChannel
{
public:
- explicit AgentMessageChannel(Ref<ComputeChannel> Channel);
- ~AgentMessageChannel();
+ AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext);
+ ~AsyncAgentMessageChannel();
- AgentMessageChannel(const AgentMessageChannel&) = delete;
- AgentMessageChannel& operator=(const AgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete;
- // --- Requests (Initiator -> Remote) ---
+ // --- Requests (fire-and-forget sends) ---
- /** Close the channel. */
void Close();
-
- /** Send a keepalive ping. */
void Ping();
-
- /** Fork communication to a new channel with the given ID and buffer size. */
void Fork(int ChannelId, int BufferSize);
-
- /** Send an attach request (used during channel setup handshake). */
void Attach();
-
- /** Request the remote agent to write files from the given bundle locator. */
void UploadFiles(const char* Path, const char* Locator);
-
- /** Execute a process on the remote machine. */
void Execute(const char* Exe,
const char* const* Args,
size_t NumArgs,
@@ -101,61 +98,85 @@ public:
const char* const* EnvVars,
size_t NumEnvVars,
ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
-
- /** Send blob data in response to a ReadBlob request. */
void Blob(const uint8_t* Data, size_t Length);
- // --- Responses (Remote -> Initiator) ---
-
- /** Read the next response message. Returns the message type, or None on timeout.
- * After this returns, use GetResponseData()/GetResponseSize() or the typed
- * Read* methods to access the payload. */
- AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr);
+ // --- Async response reading ---
- const void* GetResponseData() const { return m_ResponseData; }
- size_t GetResponseSize() const { return m_ResponseLength; }
+ /** Read the next response. If a frame is already queued, the handler is posted immediately.
+ * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler
+ * with AgentMessageType::None. */
+ void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler);
- /** Parse an Exception response payload. */
- void ReadException(ExceptionInfo& Ex);
+ /** Called by the socket's FrameHandler when a frame arrives for this channel. */
+ void OnFrame(std::vector<uint8_t> Data);
- /** Parse an ExecuteResult response payload. Returns the exit code. */
- int ReadExecuteResult();
+ /** Called by the socket's DetachHandler. */
+ void OnDetach();
- /** Parse a ReadBlob response payload into a BlobRequest. */
- void ReadBlobRequest(BlobRequest& Req);
+ /** Returns true if the channel has been detached (connection lost). */
+ bool IsDetached() const { return m_Detached; }
-private:
- static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)]
+ // --- Response parsing helpers ---
- Ref<ComputeChannel> m_Channel;
+ /** Parse an Exception message payload. Returns false on malformed/truncated input. */
+ [[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex);
- uint8_t* m_RequestData = nullptr;
- size_t m_RequestSize = 0;
- size_t m_MaxRequestSize = 0;
+ /** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. */
+ [[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode);
- AgentMessageType m_ResponseType = AgentMessageType::None;
- const uint8_t* m_ResponseData = nullptr;
- size_t m_ResponseLength = 0;
+ /** Parse a ReadBlob message payload. Returns false on malformed/truncated input or
+ * if the Locator contains characters that would not be safe to use as a path component. */
+ [[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req);
- void CreateMessage(AgentMessageType Type, size_t MaxLength);
- void FlushMessage();
+private:
+ static constexpr size_t MessageHeaderLength = 5;
+
+ // Message building helpers
+ std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload);
+ void FinalizeAndSend(std::vector<uint8_t> Msg);
+
+ /** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */
+ struct ReadCursor
+ {
+ const uint8_t* Pos = nullptr;
+ const uint8_t* End = nullptr;
+ bool ParseError = false;
+
+ [[nodiscard]] bool CheckAvailable(size_t N)
+ {
+ if (ParseError || static_cast<size_t>(End - Pos) < N)
+ {
+ ParseError = true;
+ return false;
+ }
+ return true;
+ }
+ };
+
+ static void WriteInt32(std::vector<uint8_t>& Buf, int Value);
+ static int ReadInt32(ReadCursor& C);
+
+ static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length);
+ static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length);
- void WriteInt32(int Value);
- static int ReadInt32(const uint8_t** Pos);
+ static size_t MeasureUnsignedVarInt(size_t Value);
+ static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value);
+ static size_t ReadUnsignedVarInt(ReadCursor& C);
- void WriteFixedLengthBytes(const uint8_t* Data, size_t Length);
- static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length);
+ static void WriteString(std::vector<uint8_t>& Buf, const char* Text);
+ static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text);
+ static std::string_view ReadString(ReadCursor& C);
- static size_t MeasureUnsignedVarInt(size_t Value);
- void WriteUnsignedVarInt(size_t Value);
- static size_t ReadUnsignedVarInt(const uint8_t** Pos);
+ static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text);
- size_t MeasureString(const char* Text) const;
- void WriteString(const char* Text);
- void WriteString(std::string_view Text);
- static std::string_view ReadString(const uint8_t** Pos);
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ int m_ChannelId;
+ asio::io_context& m_IoContext;
- void WriteOptionalString(const char* Text);
+ std::deque<std::vector<uint8_t>> m_IncomingFrames;
+ AsyncResponseHandler m_PendingHandler;
+ std::unique_ptr<asio::steady_timer> m_TimeoutTimer;
+ bool m_Detached = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp
index d3974bc28..8493a9456 100644
--- a/src/zenhorde/hordebundle.cpp
+++ b/src/zenhorde/hordebundle.cpp
@@ -10,6 +10,7 @@
#include <zencore/logging.h>
#include <zencore/process.h>
#include <zencore/trace.h>
+#include <zencore/uid.h>
#include <algorithm>
#include <chrono>
@@ -48,7 +49,7 @@ static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1
static constexpr size_t BlobTypeSize = 20;
-// ─── VarInt helpers (UE format) ─────────────────────────────────────────────
+// --- VarInt helpers (UE format) ---------------------------------------------
static size_t
MeasureVarInt(size_t Value)
@@ -57,7 +58,7 @@ MeasureVarInt(size_t Value)
{
return 1;
}
- return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1;
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
static void
@@ -76,7 +77,7 @@ WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value)
Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
}
-// ─── Binary helpers ─────────────────────────────────────────────────────────
+// --- Binary helpers ---------------------------------------------------------
static void
WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value)
@@ -121,7 +122,7 @@ PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value)
memcpy(Buffer.data() + Offset, &Value, 4);
}
-// ─── Packet builder ─────────────────────────────────────────────────────────
+// --- Packet builder ---------------------------------------------------------
// Builds a single uncompressed Horde V2 packet. Layout:
// [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header)
@@ -229,7 +230,7 @@ struct PacketBuilder
{
AlignTo4(Data);
- // ── Type table: count(int32) + count * BlobTypeSize bytes ──
+ // -- Type table: count(int32) + count * BlobTypeSize bytes --
const int32_t TypeTableOffset = static_cast<int32_t>(Data.size());
WriteLE32(Data, static_cast<int32_t>(Types.size()));
for (const uint8_t* TypeEntry : Types)
@@ -237,12 +238,12 @@ struct PacketBuilder
WriteBytes(Data, TypeEntry, BlobTypeSize);
}
- // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ──
+ // -- Import table: count(int32) + (count+1) offsets(int32 each) + import data --
const int32_t ImportTableOffset = static_cast<int32_t>(Data.size());
const int32_t ImportCount = static_cast<int32_t>(Imports.size());
WriteLE32(Data, ImportCount);
- // Reserve space for (count+1) offset entries — will be patched below
+ // Reserve space for (count+1) offset entries - will be patched below
const size_t ImportOffsetsStart = Data.size();
for (int32_t i = 0; i <= ImportCount; ++i)
{
@@ -266,7 +267,7 @@ struct PacketBuilder
// Sentinel offset (points past the last import's data)
PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size()));
- // ── Export table: count(int32) + (count+1) offsets(int32 each) ──
+ // -- Export table: count(int32) + (count+1) offsets(int32 each) --
const int32_t ExportTableOffset = static_cast<int32_t>(Data.size());
const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size());
WriteLE32(Data, ExportCount);
@@ -278,7 +279,7 @@ struct PacketBuilder
// Sentinel: points to the start of the type table (end of export data region)
WriteLE32(Data, TypeTableOffset);
- // ── Patch header ──
+ // -- Patch header --
// PacketLength = total packet size including the 8-byte header
const int32_t PacketLength = static_cast<int32_t>(Data.size());
PatchLE32(Data, 4, PacketLength);
@@ -290,7 +291,7 @@ struct PacketBuilder
}
};
-// ─── Encoded packet wrapper ─────────────────────────────────────────────────
+// --- Encoded packet wrapper -------------------------------------------------
// Wraps an uncompressed packet with the encoded header:
// [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes
@@ -327,24 +328,22 @@ EncodePacket(std::vector<uint8_t> UncompressedPacket)
return Encoded;
}
-// ─── Bundle blob name generation ────────────────────────────────────────────
+// --- Bundle blob name generation --------------------------------------------
static std::string
GenerateBlobName()
{
- static std::atomic<uint32_t> s_Counter{0};
-
- const int Pid = GetCurrentProcessId();
-
- auto Now = std::chrono::steady_clock::now().time_since_epoch();
- auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count();
-
- ExtendableStringBuilder<64> Name;
- Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1);
- return std::string(Name.ToView());
+ // Oid is a 12-byte identifier built from a timestamp, a monotonic serial number
+ // initialised from std::random_device, and a per-process run id also drawn from
+ // std::random_device. The 24-hex-char rendering gives ~80 bits of effective
+ // name-prediction entropy, so a local attacker cannot race-create the blob
+ // path before we open it. Previously the name was pid+ms+counter, which two
+ // zenserver processes with the same PID could collide on and which was
+ // entirely predictable.
+ return zen::Oid::NewOid().ToString();
}
-// ─── File info for bundling ─────────────────────────────────────────────────
+// --- File info for bundling -------------------------------------------------
struct FileInfo
{
@@ -357,7 +356,7 @@ struct FileInfo
IoHash RootExportHash; // IoHash of the root export for this file
};
-// ─── CreateBundle implementation ────────────────────────────────────────────
+// --- CreateBundle implementation --------------------------------------------
bool
BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult)
@@ -534,7 +533,7 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil
FileInfo& Info = ValidFiles[i];
DirImports.push_back(Info.DirectoryExportImportIndex);
- // IoHash of target (20 bytes) — import is consumed sequentially from the
+ // IoHash of target (20 bytes) - import is consumed sequentially from the
// export's import list by ReadBlobRef, not encoded in the payload
WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash));
// name (string)
@@ -557,8 +556,16 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil
std::vector<uint8_t> UncompressedPacket = Packet.Finish();
std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket));
- // Write .blob file
+ // Write .blob file. Refuse to proceed if a file with this name already exists -
+ // the Oid-based BlobName should make collisions astronomically unlikely, so an
+ // existing file implies either an extraordinary collision or an attacker having
+ // pre-seeded the path; either way, we do not want to overwrite it.
const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob");
+ if (std::filesystem::exists(BlobFilePath, Ec))
+ {
+ ZEN_ERROR("blob file already exists at {} - refusing to overwrite", BlobFilePath.string());
+ return false;
+ }
{
BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec);
if (Ec)
@@ -574,8 +581,10 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil
Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex;
const std::string LocatorStr(Locator.ToView());
- // Write .ref file (use first file's name as the ref base)
- const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref");
+ // Write .ref file. Include the Oid-based BlobName so that two concurrent
+ // CreateBundle() calls into the same OutputDir that happen to share the first
+ // filename don't clobber each other's ref file.
+ const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + "." + BlobName + ".Bundle.ref");
{
BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec);
if (Ec)
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp
index fb981f0ba..762edce06 100644
--- a/src/zenhorde/hordeclient.cpp
+++ b/src/zenhorde/hordeclient.cpp
@@ -4,6 +4,7 @@
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/memoryview.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhorde/hordeclient.h>
#include <zenhttp/httpclient.h>
@@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client"))
+HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client")
{
}
@@ -32,19 +33,24 @@ HordeClient::Initialize()
Settings.RetryCount = 1;
Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests};
- if (!m_Config.AuthToken.empty())
+ if (m_Config.AccessTokenProvider)
{
+ Settings.AccessTokenProvider = m_Config.AccessTokenProvider;
+ }
+ else if (!m_Config.AuthToken.empty())
+ {
+ // Static tokens have no wire-provided expiry. Synthesising "now + 24h" is wrong
+ // in both directions: if the real token expires before 24h we keep sending it after
+ // it dies; if it's long-lived we force unnecessary re-auth churn every day. Use the
+ // never-expires sentinel, matching zenhttp's CreateFromStaticToken.
Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken {
- HttpClientAccessToken Token;
- Token.Value = token;
- Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24};
- return Token;
+ return HttpClientAccessToken(token, HttpClientAccessToken::TimePoint::max());
};
}
m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings);
- if (!m_Config.AuthToken.empty())
+ if (Settings.AccessTokenProvider)
{
if (!m_Http->Authenticate())
{
@@ -66,24 +72,21 @@ HordeClient::BuildRequestBody() const
Requirements["pool"] = m_Config.Pool;
}
- std::string Condition;
-#if ZEN_PLATFORM_WINDOWS
ExtendableStringBuilder<256> CondBuf;
+#if ZEN_PLATFORM_WINDOWS
CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')";
- Condition = std::string(CondBuf);
#elif ZEN_PLATFORM_MAC
- Condition = "OSFamily == 'MacOS'";
+ CondBuf << "OSFamily == 'MacOS'";
#else
- Condition = "OSFamily == 'Linux'";
+ CondBuf << "OSFamily == 'Linux'";
#endif
if (!m_Config.Condition.empty())
{
- Condition += " ";
- Condition += m_Config.Condition;
+ CondBuf << " " << m_Config.Condition;
}
- Requirements["condition"] = Condition;
+ Requirements["condition"] = std::string(CondBuf);
Requirements["exclusive"] = true;
json11::Json::object Connection;
@@ -159,39 +162,27 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus
return false;
}
- OutCluster.ClusterId = ClusterIdVal.string_value();
- return true;
-}
-
-bool
-HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize)
-{
- if (Hex.size() != OutSize * 2)
+ // A server-returned ClusterId is interpolated directly into the request URL below
+ // (api/v2/compute/<ClusterId>), so a compromised or MITM'd Horde server could
+ // otherwise inject additional path segments or query strings. Constrain to a
+ // conservative identifier alphabet.
+ const std::string& ClusterIdStr = ClusterIdVal.string_value();
+ if (ClusterIdStr.size() > 64)
{
+ ZEN_WARN("rejecting overlong clusterId ({} bytes) in cluster resolution response", ClusterIdStr.size());
return false;
}
-
- for (size_t i = 0; i < OutSize; ++i)
+ static constexpr AsciiSet ValidClusterIdCharactersSet{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-"};
+ if (!AsciiSet::HasOnly(ClusterIdStr, ValidClusterIdCharactersSet))
{
- auto HexToByte = [](char c) -> int {
- if (c >= '0' && c <= '9')
- return c - '0';
- if (c >= 'a' && c <= 'f')
- return c - 'a' + 10;
- if (c >= 'A' && c <= 'F')
- return c - 'A' + 10;
- return -1;
- };
-
- const int Hi = HexToByte(Hex[i * 2]);
- const int Lo = HexToByte(Hex[i * 2 + 1]);
- if (Hi < 0 || Lo < 0)
- {
- return false;
- }
- Out[i] = static_cast<uint8_t>((Hi << 4) | Lo);
+ ZEN_WARN("rejecting clusterId with unsafe character in cluster resolution response");
+ return false;
}
+ OutCluster.ClusterId = ClusterIdStr;
+
+ ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId);
+
return true;
}
@@ -200,8 +191,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
{
ZEN_TRACE_CPU("HordeClient::RequestMachine");
- ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str());
-
ExtendableStringBuilder<128> ResourcePath;
ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str());
@@ -321,11 +310,15 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
}
else if (Prop.starts_with("LogicalCores="))
{
- LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13));
+ LogicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(13)).value_or(0);
}
else if (Prop.starts_with("PhysicalCores="))
{
- PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14));
+ PhysicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(14)).value_or(0);
+ }
+ else if (Prop.starts_with("Pool="))
+ {
+ OutMachine.Pool = Prop.substr(5);
}
}
}
@@ -370,10 +363,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
OutMachine.LeaseId = LeaseIdVal.string_value();
}
- ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}",
+ ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}",
OutMachine.GetConnectionAddress(),
OutMachine.GetConnectionPort(),
+ ToString(OutMachine.Mode),
OutMachine.LogicalCores,
+ OutMachine.Pool,
OutMachine.LeaseId);
return true;
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp
deleted file mode 100644
index 0d032b5d5..000000000
--- a/src/zenhorde/hordecomputebuffer.cpp
+++ /dev/null
@@ -1,454 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "hordecomputebuffer.h"
-
-#include <algorithm>
-#include <cassert>
-#include <chrono>
-#include <condition_variable>
-#include <cstring>
-
-namespace zen::horde {
-
-// Simplified ring buffer implementation for in-process use only.
-// Uses a single contiguous buffer with write/read cursors and
-// mutex+condvar for synchronization. This is simpler than the UE version
-// which uses lock-free atomics and shared memory, but sufficient for our
-// use case where we're the initiator side of the compute protocol.
-
-struct ComputeBuffer::Detail : TRefCounted<Detail>
-{
- std::vector<uint8_t> Data;
- size_t NumChunks = 0;
- size_t ChunkLength = 0;
-
- // Current write state
- size_t WriteChunkIdx = 0;
- size_t WriteOffset = 0;
- bool WriteComplete = false;
-
- // Current read state
- size_t ReadChunkIdx = 0;
- size_t ReadOffset = 0;
- bool Detached = false;
-
- // Per-chunk written length
- std::vector<size_t> ChunkWrittenLength;
- std::vector<bool> ChunkFinished; // Writer moved to next chunk
-
- std::mutex Mutex;
- std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes
- std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space
-
- bool HasWriter = false;
- bool HasReader = false;
-
- uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
- const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
-};
-
-// ComputeBuffer
-
-ComputeBuffer::ComputeBuffer()
-{
-}
-ComputeBuffer::~ComputeBuffer()
-{
-}
-
-bool
-ComputeBuffer::CreateNew(const Params& InParams)
-{
- auto* NewDetail = new Detail();
- NewDetail->NumChunks = InParams.NumChunks;
- NewDetail->ChunkLength = InParams.ChunkLength;
- NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
- NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
- NewDetail->ChunkFinished.resize(InParams.NumChunks, false);
-
- m_Detail = NewDetail;
- return true;
-}
-
-void
-ComputeBuffer::Close()
-{
- m_Detail = nullptr;
-}
-
-bool
-ComputeBuffer::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-ComputeBufferReader
-ComputeBuffer::CreateReader()
-{
- assert(m_Detail);
- m_Detail->HasReader = true;
- return ComputeBufferReader(m_Detail);
-}
-
-ComputeBufferWriter
-ComputeBuffer::CreateWriter()
-{
- assert(m_Detail);
- m_Detail->HasWriter = true;
- return ComputeBufferWriter(m_Detail);
-}
-
-// ComputeBufferReader
-
-ComputeBufferReader::ComputeBufferReader()
-{
-}
-ComputeBufferReader::~ComputeBufferReader()
-{
-}
-
-ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default;
-ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default;
-ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default;
-ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;
-
-ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
-{
-}
-
-void
-ComputeBufferReader::Close()
-{
- m_Detail = nullptr;
-}
-
-void
-ComputeBufferReader::Detach()
-{
- if (m_Detail)
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- m_Detail->Detached = true;
- m_Detail->ReadCV.notify_all();
- }
-}
-
-bool
-ComputeBufferReader::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-bool
-ComputeBufferReader::IsComplete() const
-{
- if (!m_Detail)
- {
- return true;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- if (m_Detail->Detached)
- {
- return true;
- }
- return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
- m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
-}
-
-void
-ComputeBufferReader::AdvanceReadPosition(size_t Size)
-{
- if (!m_Detail)
- {
- return;
- }
-
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
-
- m_Detail->ReadOffset += Size;
-
- // Check if we need to move to next chunk
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
- {
- const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
- m_Detail->ReadChunkIdx = NextChunk;
- m_Detail->ReadOffset = 0;
- m_Detail->WriteCV.notify_all();
- }
-
- m_Detail->ReadCV.notify_all();
-}
-
-size_t
-ComputeBufferReader::GetMaxReadSize() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-}
-
-const uint8_t*
-ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut)
-{
- if (!m_Detail)
- {
- return nullptr;
- }
-
- std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
-
- auto Predicate = [&]() -> bool {
- if (m_Detail->Detached)
- {
- return true;
- }
-
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-
- if (Available >= MinSize)
- {
- return true;
- }
-
- // If chunk is finished and we've read everything, try to move to next
- if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
- {
- if (m_Detail->WriteComplete)
- {
- return true; // End of stream
- }
- // Move to next chunk
- const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
- m_Detail->ReadChunkIdx = NextChunk;
- m_Detail->ReadOffset = 0;
- m_Detail->WriteCV.notify_all();
- return false; // Re-check with new chunk
- }
-
- if (m_Detail->WriteComplete)
- {
- return true; // End of stream
- }
-
- return false;
- };
-
- if (TimeoutMs < 0)
- {
- m_Detail->ReadCV.wait(Lock, Predicate);
- }
- else
- {
- if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
- {
- if (OutTimedOut)
- {
- *OutTimedOut = true;
- }
- return nullptr;
- }
- }
-
- if (m_Detail->Detached)
- {
- return nullptr;
- }
-
- const size_t ReadChunk = m_Detail->ReadChunkIdx;
- const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
-
- if (Available < MinSize)
- {
- return nullptr; // End of stream
- }
-
- return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset;
-}
-
-size_t
-ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut)
-{
- const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut);
- if (!Data)
- {
- return 0;
- }
-
- const size_t Available = GetMaxReadSize();
- const size_t ToCopy = std::min(Available, MaxSize);
- memcpy(Buffer, Data, ToCopy);
- AdvanceReadPosition(ToCopy);
- return ToCopy;
-}
-
-// ComputeBufferWriter
-
-ComputeBufferWriter::ComputeBufferWriter() = default;
-ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default;
-ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default;
-ComputeBufferWriter::~ComputeBufferWriter() = default;
-ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default;
-ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default;
-
-ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
-{
-}
-
-void
-ComputeBufferWriter::Close()
-{
- if (m_Detail)
- {
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- if (!m_Detail->WriteComplete)
- {
- m_Detail->WriteComplete = true;
- m_Detail->ReadCV.notify_all();
- }
- }
- m_Detail = nullptr;
- }
-}
-
-bool
-ComputeBufferWriter::IsValid() const
-{
- return static_cast<bool>(m_Detail);
-}
-
-void
-ComputeBufferWriter::MarkComplete()
-{
- if (m_Detail)
- {
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- m_Detail->WriteComplete = true;
- m_Detail->ReadCV.notify_all();
- }
-}
-
-void
-ComputeBufferWriter::AdvanceWritePosition(size_t Size)
-{
- if (!m_Detail || Size == 0)
- {
- return;
- }
-
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- m_Detail->ChunkWrittenLength[WriteChunk] += Size;
- m_Detail->WriteOffset += Size;
- m_Detail->ReadCV.notify_all();
-}
-
-size_t
-ComputeBufferWriter::GetMaxWriteSize() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
-}
-
-size_t
-ComputeBufferWriter::GetChunkMaxLength() const
-{
- if (!m_Detail)
- {
- return 0;
- }
- return m_Detail->ChunkLength;
-}
-
-size_t
-ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
-{
- uint8_t* Dest = WaitToWrite(1, TimeoutMs);
- if (!Dest)
- {
- return 0;
- }
-
- const size_t Available = GetMaxWriteSize();
- const size_t ToCopy = std::min(Available, MaxSize);
- memcpy(Dest, Buffer, ToCopy);
- AdvanceWritePosition(ToCopy);
- return ToCopy;
-}
-
-uint8_t*
-ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
-{
- if (!m_Detail)
- {
- return nullptr;
- }
-
- std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
-
- if (m_Detail->WriteComplete)
- {
- return nullptr;
- }
-
- const size_t WriteChunk = m_Detail->WriteChunkIdx;
- const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
-
- // If current chunk has enough space, return pointer
- if (Available >= MinSize)
- {
- return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
- }
-
- // Current chunk is full - mark it as finished and move to next.
- // The writer cannot advance until the reader has fully consumed the next chunk,
- // preventing the writer from overwriting data the reader hasn't processed yet.
- m_Detail->ChunkFinished[WriteChunk] = true;
- m_Detail->ReadCV.notify_all();
-
- const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks;
-
- // Wait until reader has consumed the next chunk
- auto Predicate = [&]() -> bool {
- // Check if read has moved past this chunk
- return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached;
- };
-
- if (TimeoutMs < 0)
- {
- m_Detail->WriteCV.wait(Lock, Predicate);
- }
- else
- {
- if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
- {
- return nullptr;
- }
- }
-
- if (m_Detail->Detached)
- {
- return nullptr;
- }
-
- // Reset next chunk
- m_Detail->ChunkWrittenLength[NextChunk] = 0;
- m_Detail->ChunkFinished[NextChunk] = false;
- m_Detail->WriteChunkIdx = NextChunk;
- m_Detail->WriteOffset = 0;
-
- return m_Detail->ChunkPtr(NextChunk);
-}
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h
deleted file mode 100644
index 64ef91b7a..000000000
--- a/src/zenhorde/hordecomputebuffer.h
+++ /dev/null
@@ -1,136 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zenbase/refcount.h>
-
-#include <cstddef>
-#include <cstdint>
-#include <mutex>
-#include <vector>
-
-namespace zen::horde {
-
-class ComputeBufferReader;
-class ComputeBufferWriter;
-
-/** Simplified in-process ring buffer for the Horde compute protocol.
- *
- * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files,
- * this implementation uses plain heap-allocated memory since we only need in-process
- * communication between channel and transport threads. The buffer is divided into
- * fixed-size chunks; readers and writers block when no space is available.
- */
-class ComputeBuffer
-{
-public:
- struct Params
- {
- size_t NumChunks = 2;
- size_t ChunkLength = 512 * 1024;
- };
-
- ComputeBuffer();
- ~ComputeBuffer();
-
- ComputeBuffer(const ComputeBuffer&) = delete;
- ComputeBuffer& operator=(const ComputeBuffer&) = delete;
-
- bool CreateNew(const Params& InParams);
- void Close();
-
- bool IsValid() const;
-
- ComputeBufferReader CreateReader();
- ComputeBufferWriter CreateWriter();
-
-private:
- struct Detail;
- Ref<Detail> m_Detail;
-
- friend class ComputeBufferReader;
- friend class ComputeBufferWriter;
-};
-
-/** Read endpoint for a ComputeBuffer.
- *
- * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer
- * directly into the buffer memory (zero-copy); the caller must call
- * AdvanceReadPosition() after consuming the data.
- */
-class ComputeBufferReader
-{
-public:
- ComputeBufferReader();
- ComputeBufferReader(const ComputeBufferReader&);
- ComputeBufferReader(ComputeBufferReader&&) noexcept;
- ~ComputeBufferReader();
-
- ComputeBufferReader& operator=(const ComputeBufferReader&);
- ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept;
-
- void Close();
- void Detach();
- bool IsValid() const;
- bool IsComplete() const;
-
- void AdvanceReadPosition(size_t Size);
- size_t GetMaxReadSize() const;
-
- /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */
- size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
-
- /** Wait until at least MinSize bytes are available and return a direct pointer.
- * Returns nullptr on timeout or if the writer has completed. */
- const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr);
-
-private:
- friend class ComputeBuffer;
- explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail);
-
- Ref<ComputeBuffer::Detail> m_Detail;
-};
-
-/** Write endpoint for a ComputeBuffer.
- *
- * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer
- * directly into the buffer memory (zero-copy); the caller must call
- * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal
- * that no more data will be written.
- */
-class ComputeBufferWriter
-{
-public:
- ComputeBufferWriter();
- ComputeBufferWriter(const ComputeBufferWriter&);
- ComputeBufferWriter(ComputeBufferWriter&&) noexcept;
- ~ComputeBufferWriter();
-
- ComputeBufferWriter& operator=(const ComputeBufferWriter&);
- ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept;
-
- void Close();
- bool IsValid() const;
-
- /** Signal that no more data will be written. Unblocks any waiting readers. */
- void MarkComplete();
-
- void AdvanceWritePosition(size_t Size);
- size_t GetMaxWriteSize() const;
- size_t GetChunkMaxLength() const;
-
- /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */
- size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1);
-
- /** Wait until at least MinSize bytes of write space are available and return a direct pointer.
- * Returns nullptr on timeout. */
- uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1);
-
-private:
- friend class ComputeBuffer;
- explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail);
-
- Ref<ComputeBuffer::Detail> m_Detail;
-};
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp
deleted file mode 100644
index ee2a6f327..000000000
--- a/src/zenhorde/hordecomputechannel.cpp
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "hordecomputechannel.h"
-
-namespace zen::horde {
-
-ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter)
-: Reader(std::move(InReader))
-, Writer(std::move(InWriter))
-{
-}
-
-bool
-ComputeChannel::IsValid() const
-{
- return Reader.IsValid() && Writer.IsValid();
-}
-
-size_t
-ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs)
-{
- return Writer.Write(Data, Size, TimeoutMs);
-}
-
-size_t
-ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs)
-{
- return Reader.Read(Data, Size, TimeoutMs);
-}
-
-void
-ComputeChannel::MarkComplete()
-{
- Writer.MarkComplete();
-}
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h
deleted file mode 100644
index c1dff20e4..000000000
--- a/src/zenhorde/hordecomputechannel.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include "hordecomputebuffer.h"
-
-namespace zen::horde {
-
-/** Bidirectional communication channel using a pair of compute buffers.
- *
- * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter
- * (for sending data). Used by ComputeSocket to represent one logical channel
- * within a multiplexed connection.
- */
-class ComputeChannel : public TRefCounted<ComputeChannel>
-{
-public:
- ComputeBufferReader Reader;
- ComputeBufferWriter Writer;
-
- ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter);
-
- bool IsValid() const;
-
- size_t Send(const void* Data, size_t Size, int TimeoutMs = -1);
- size_t Recv(void* Data, size_t Size, int TimeoutMs = -1);
-
- /** Signal that no more data will be sent on this channel. */
- void MarkComplete();
-};
-
-} // namespace zen::horde
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp
index 6ef67760c..92a56c077 100644
--- a/src/zenhorde/hordecomputesocket.cpp
+++ b/src/zenhorde/hordecomputesocket.cpp
@@ -6,198 +6,326 @@
namespace zen::horde {
-ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport)
-: m_Log(zen::logging::Get("horde.socket"))
+AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext)
+: m_Log(zen::logging::Get("horde.socket.async"))
, m_Transport(std::move(Transport))
+, m_Strand(asio::make_strand(IoContext))
+, m_PingTimer(m_Strand)
{
}
-ComputeSocket::~ComputeSocket()
+AsyncComputeSocket::~AsyncComputeSocket()
{
- // Shutdown order matters: first stop the ping thread, then unblock send threads
- // by detaching readers, then join send threads, and finally close the transport
- // to unblock the recv thread (which is blocked on RecvMessage).
- {
- std::lock_guard<std::mutex> Lock(m_PingMutex);
- m_PingShouldStop = true;
- m_PingCV.notify_all();
- }
-
- for (auto& Reader : m_Readers)
- {
- Reader.Detach();
- }
-
- for (auto& [Id, Thread] : m_SendThreads)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
+ Close();
+}
- m_Transport->Close();
+void
+AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach)
+{
+ m_FrameHandlers[ChannelId] = std::move(OnFrame);
+ m_DetachHandlers[ChannelId] = std::move(OnDetach);
+}
- if (m_RecvThread.joinable())
- {
- m_RecvThread.join();
- }
- if (m_PingThread.joinable())
- {
- m_PingThread.join();
- }
+void
+AsyncComputeSocket::StartRecvPump()
+{
+ StartPingTimer();
+ DoRecvHeader();
}
-Ref<ComputeChannel>
-ComputeSocket::CreateChannel(int ChannelId)
+void
+AsyncComputeSocket::DoRecvHeader()
{
- ComputeBuffer::Params Params;
+ auto Self = shared_from_this();
+ m_Transport->AsyncRead(&m_RecvHeader,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
- ComputeBuffer RecvBuffer;
- if (!RecvBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_Closed)
+ {
+ return;
+ }
- ComputeBuffer SendBuffer;
- if (!SendBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_RecvHeader.Size >= 0)
+ {
+ DoRecvPayload(m_RecvHeader);
+ }
+ else if (m_RecvHeader.Size == ControlDetach)
+ {
+ if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second)
+ {
+ It->second();
+ }
+ DoRecvHeader();
+ }
+ else if (m_RecvHeader.Size == ControlPing)
+ {
+ DoRecvHeader();
+ }
+ else
+ {
+ ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size);
+ }
+ }));
+}
- Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter()));
+void
+AsyncComputeSocket::DoRecvPayload(FrameHeader Header)
+{
+ auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size));
+ auto Self = shared_from_this();
- // Attach recv buffer writer (transport recv thread writes into this)
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter());
- }
+ m_Transport->AsyncRead(PayloadBuf->data(),
+ PayloadBuf->size(),
+ asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message());
+ HandleError();
+ }
+ return;
+ }
- // Attach send buffer reader (send thread reads from this)
- {
- ComputeBufferReader Reader = SendBuffer.CreateReader();
- m_Readers.push_back(Reader);
- m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader)));
- }
+ if (m_Closed)
+ {
+ return;
+ }
+
+ if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second)
+ {
+ It->second(std::move(*PayloadBuf));
+ }
+ else
+ {
+ ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
+ }
- return Channel;
+ DoRecvHeader();
+ }));
}
void
-ComputeSocket::StartCommunication()
+AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler)
{
- m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this);
- m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this);
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
+ {
+ if (Handler)
+ {
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
+ }
+ return;
+ }
+
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = static_cast<int32_t>(Data.size());
+ Write.Data = std::move(Data);
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::PingThreadProc()
+AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler)
{
- while (true)
- {
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
{
- std::unique_lock<std::mutex> Lock(m_PingMutex);
- if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; }))
+ if (Handler)
{
- break;
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
}
+ return;
}
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- FrameHeader Header;
- Header.Channel = 0;
- Header.Size = ControlPing;
- m_Transport->SendMessage(&Header, sizeof(Header));
- }
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = ControlDetach;
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::RecvThreadProc()
+AsyncComputeSocket::FlushNextSend()
{
- // Writers are cached locally to avoid taking m_WritersMutex on every frame.
- // The shared m_Writers map is only accessed when a channel is seen for the first time.
- std::unordered_map<int, ComputeBufferWriter> CachedWriters;
+ if (m_SendQueue.empty() || m_Closed)
+ {
+ return;
+ }
- FrameHeader Header;
- while (m_Transport->RecvMessage(&Header, sizeof(Header)))
+ PendingWrite& Front = m_SendQueue.front();
+
+ if (Front.Data.empty())
{
- if (Header.Size >= 0)
- {
- // Data frame
- auto It = CachedWriters.find(Header.Channel);
- if (It == CachedWriters.end())
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto WIt = m_Writers.find(Header.Channel);
- if (WIt == m_Writers.end())
- {
- ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
- // Skip the data
- std::vector<uint8_t> Discard(Header.Size);
- m_Transport->RecvMessage(Discard.data(), Header.Size);
- continue;
- }
- It = CachedWriters.emplace(Header.Channel, WIt->second).first;
- }
+ // Control frame - header only
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
- ComputeBufferWriter& Writer = It->second;
- uint8_t* Dest = Writer.WaitToWrite(Header.Size);
- if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size))
- {
- ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size);
- return;
- }
- Writer.AdvanceWritePosition(Header.Size);
- }
- else if (Header.Size == ControlDetach)
- {
- // Detach the recv buffer for this channel
- CachedWriters.erase(Header.Channel);
+ if (Handler)
+ {
+ Handler(Ec);
+ }
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto It = m_Writers.find(Header.Channel);
- if (It != m_Writers.end())
- {
- It->second.MarkComplete();
- m_Writers.erase(It);
- }
- }
- else if (Header.Size == ControlPing)
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }
+ else
+ {
+ // Data frame - write header first, then payload
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ PendingWrite& Payload = m_SendQueue.front();
+ m_Transport->AsyncWrite(
+ Payload.Data.data(),
+ Payload.Data.size(),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send payload error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }));
+ }
+}
+
+void
+AsyncComputeSocket::StartPingTimer()
+{
+ if (m_Closed)
+ {
+ return;
+ }
+
+ m_PingTimer.expires_after(std::chrono::seconds(2));
+
+ auto Self = shared_from_this();
+ m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) {
+ if (Ec || m_Closed)
{
- // Ping response - ignore
+ return;
}
- else
+
+ // Enqueue a ping control frame
+ PendingWrite Write;
+ Write.Header.Channel = 0;
+ Write.Header.Size = ControlPing;
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
{
- ZEN_WARN("invalid frame header size: {}", Header.Size);
- return;
+ FlushNextSend();
}
- }
+
+ StartPingTimer();
+ }));
}
void
-ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader)
+AsyncComputeSocket::HandleError()
{
- // Each channel has its own send thread. All send threads share m_SendMutex
- // to serialize writes to the transport, since TCP requires atomic frame writes.
- FrameHeader Header;
- Header.Channel = Channel;
+ if (m_Closed)
+ {
+ return;
+ }
+
+ Close();
- const uint8_t* Data;
- while ((Data = Reader.WaitToRead(1)) != nullptr)
+ // Notify all channels that the connection is gone so agents can clean up
+ for (auto& [ChannelId, Handler] : m_DetachHandlers)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
+ if (Handler)
+ {
+ Handler();
+ }
+ }
+}
- Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize());
- m_Transport->SendMessage(&Header, sizeof(Header));
- m_Transport->SendMessage(Data, Header.Size);
- Reader.AdvanceReadPosition(Header.Size);
+void
+AsyncComputeSocket::Close()
+{
+ if (m_Closed)
+ {
+ return;
}
- if (Reader.IsComplete())
+ m_Closed = true;
+ m_PingTimer.cancel();
+
+ if (m_Transport)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- Header.Size = ControlDetach;
- m_Transport->SendMessage(&Header, sizeof(Header));
+ m_Transport->Close();
}
}
diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h
index 0c3cb4195..6c494603a 100644
--- a/src/zenhorde/hordecomputesocket.h
+++ b/src/zenhorde/hordecomputesocket.h
@@ -2,45 +2,74 @@
#pragma once
-#include "hordecomputebuffer.h"
-#include "hordecomputechannel.h"
#include "hordetransport.h"
#include <zencore/logbase.h>
-#include <condition_variable>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# undef SendMessage
+#endif
+
+#include <deque>
+#include <functional>
#include <memory>
-#include <mutex>
-#include <thread>
+#include <system_error>
#include <unordered_map>
#include <vector>
namespace zen::horde {
-/** Multiplexed socket that routes data between multiple channels over a single transport.
+class AsyncComputeTransport;
+
+/** Handler called when a data frame arrives for a channel. */
+using FrameHandler = std::function<void(std::vector<uint8_t> Data)>;
+
+/** Handler called when a channel is detached by the remote peer. */
+using DetachHandler = std::function<void()>;
+
+/** Handler for async send completion. */
+using SendHandler = std::function<void(const std::error_code&)>;
+
+/** Async multiplexed socket that routes data between channels over a single transport.
*
- * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers.
- * A recv thread demultiplexes incoming frames to channel-specific buffers, while
- * per-channel send threads multiplex outgoing data onto the shared transport.
+ * Uses an async recv pump, a serialized send queue, and a periodic ping timer -
+ * all running on a shared io_context.
*
- * Wire format per frame: [channelId (4B)][size (4B)][data]
- * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping.
+ * Wire format per frame: [channelId(4B)][size(4B)][data].
+ * Control messages use negative sizes: -2 = detach, -3 = ping.
*/
-class ComputeSocket
+class AsyncComputeSocket : public std::enable_shared_from_this<AsyncComputeSocket>
{
public:
- explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport);
- ~ComputeSocket();
+ AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext);
+ ~AsyncComputeSocket();
+
+ AsyncComputeSocket(const AsyncComputeSocket&) = delete;
+ AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete;
+
+ /** Register callbacks for a channel. Must be called before StartRecvPump(). */
+ void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach);
- ComputeSocket(const ComputeSocket&) = delete;
- ComputeSocket& operator=(const ComputeSocket&) = delete;
+ /** Begin the async recv pump and ping timer. */
+ void StartRecvPump();
- /** Create a channel with the given ID.
- * Allocates anonymous in-process buffers and spawns a send thread for the channel. */
- Ref<ComputeChannel> CreateChannel(int ChannelId);
+ /** Enqueue a data frame for async transmission. */
+ void AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler = {});
- /** Start the recv pump and ping threads. Must be called after all channels are created. */
- void StartCommunication();
+ /** Send a control frame (detach) for a channel. */
+ void AsyncSendDetach(int ChannelId, SendHandler Handler = {});
+
+ /** Close the transport and cancel all pending operations. */
+ void Close();
+
+ /** The strand on which all socket I/O callbacks run. Channels that need to serialize
+ * their own state with OnFrame/OnDetach (which are invoked from this strand) should
+ * bind their timers and async operations to it as well. */
+ asio::strand<asio::any_io_executor>& GetStrand() { return m_Strand; }
private:
struct FrameHeader
@@ -49,31 +78,35 @@ private:
int32_t Size = 0;
};
+ struct PendingWrite
+ {
+ FrameHeader Header;
+ std::vector<uint8_t> Data;
+ SendHandler Handler;
+ };
+
static constexpr int32_t ControlDetach = -2;
static constexpr int32_t ControlPing = -3;
LoggerRef Log() { return m_Log; }
- void RecvThreadProc();
- void SendThreadProc(int Channel, ComputeBufferReader Reader);
- void PingThreadProc();
-
- LoggerRef m_Log;
- std::unique_ptr<ComputeTransport> m_Transport;
- std::mutex m_SendMutex; ///< Serializes writes to the transport
-
- std::mutex m_WritersMutex;
- std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID
+ void DoRecvHeader();
+ void DoRecvPayload(FrameHeader Header);
+ void FlushNextSend();
+ void StartPingTimer();
+ void HandleError();
- std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction
- std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel
+ LoggerRef m_Log;
+ std::unique_ptr<AsyncComputeTransport> m_Transport;
+ asio::strand<asio::any_io_executor> m_Strand;
+ asio::steady_timer m_PingTimer;
- std::thread m_RecvThread;
- std::thread m_PingThread;
+ std::unordered_map<int, FrameHandler> m_FrameHandlers;
+ std::unordered_map<int, DetachHandler> m_DetachHandlers;
- bool m_PingShouldStop = false;
- std::mutex m_PingMutex;
- std::condition_variable m_PingCV;
+ FrameHeader m_RecvHeader;
+ std::deque<PendingWrite> m_SendQueue;
+ bool m_Closed = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp
index 2dca228d9..9f6125c64 100644
--- a/src/zenhorde/hordeconfig.cpp
+++ b/src/zenhorde/hordeconfig.cpp
@@ -1,5 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <zencore/logging.h>
+#include <zencore/string.h>
#include <zenhorde/hordeconfig.h>
namespace zen::horde {
@@ -9,12 +11,14 @@ HordeConfig::Validate() const
{
if (ServerUrl.empty())
{
+ ZEN_WARN("Horde server URL is not configured");
return false;
}
// Relay mode implies AES encryption
if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES)
{
+ ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode));
return false;
}
@@ -52,37 +56,39 @@ ToString(Encryption Enc)
bool
FromString(ConnectionMode& OutMode, std::string_view Str)
{
- if (Str == "direct")
+ if (StrCaseCompare(Str, "direct") == 0)
{
OutMode = ConnectionMode::Direct;
return true;
}
- if (Str == "tunnel")
+ if (StrCaseCompare(Str, "tunnel") == 0)
{
OutMode = ConnectionMode::Tunnel;
return true;
}
- if (Str == "relay")
+ if (StrCaseCompare(Str, "relay") == 0)
{
OutMode = ConnectionMode::Relay;
return true;
}
+ ZEN_WARN("unrecognized Horde connection mode: '{}'", Str);
return false;
}
bool
FromString(Encryption& OutEnc, std::string_view Str)
{
- if (Str == "none")
+ if (StrCaseCompare(Str, "none") == 0)
{
OutEnc = Encryption::None;
return true;
}
- if (Str == "aes")
+ if (StrCaseCompare(Str, "aes") == 0)
{
OutEnc = Encryption::AES;
return true;
}
+ ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str);
return false;
}
diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp
index f88c95da2..ea0ea1e83 100644
--- a/src/zenhorde/hordeprovisioner.cpp
+++ b/src/zenhorde/hordeprovisioner.cpp
@@ -6,49 +6,83 @@
#include "hordeagent.h"
#include "hordebundle.h"
+#include <zencore/compactbinary.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/scopeguard.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
+#include <zenhttp/httpclient.h>
+#include <zenutil/workerpools.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <algorithm>
#include <chrono>
#include <thread>
namespace zen::horde {
-struct HordeProvisioner::AgentWrapper
-{
- std::thread Thread;
- std::atomic<bool> ShouldExit{false};
-};
-
HordeProvisioner::HordeProvisioner(const HordeConfig& Config,
const std::filesystem::path& BinariesPath,
const std::filesystem::path& WorkingDir,
- std::string_view OrchestratorEndpoint)
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession,
+ bool CleanStart,
+ std::string_view TraceHost)
: m_Config(Config)
, m_BinariesPath(BinariesPath)
, m_WorkingDir(WorkingDir)
, m_OrchestratorEndpoint(OrchestratorEndpoint)
+, m_CoordinatorSession(CoordinatorSession)
+, m_CleanStart(CleanStart)
+, m_TraceHost(TraceHost)
, m_Log(zen::logging::Get("horde.provisioner"))
{
+ m_IoContext = std::make_unique<asio::io_context>();
+
+ auto Work = asio::make_work_guard(*m_IoContext);
+ for (int i = 0; i < IoThreadCount; ++i)
+ {
+ m_IoThreads.emplace_back([this, i, Work] {
+ zen::SetCurrentThreadName(fmt::format("horde_io_{}", i));
+ m_IoContext->run();
+ });
+ }
}
HordeProvisioner::~HordeProvisioner()
{
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- for (auto& Agent : m_Agents)
+ m_AskForAgents.store(false);
+ m_ShutdownEvent.Set();
+
+ // Shut down async agents and io_context
{
- Agent->ShouldExit.store(true);
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (auto& Entry : m_AsyncAgents)
+ {
+ Entry.Agent->Cancel();
+ }
+ m_AsyncAgents.clear();
}
- for (auto& Agent : m_Agents)
+
+ m_IoContext->stop();
+
+ for (auto& Thread : m_IoThreads)
{
- if (Agent->Thread.joinable())
+ if (Thread.joinable())
{
- Agent->Thread.join();
+ Thread.join();
}
}
+
+ // Wait for all pool work items to finish before destroying members they reference
+ if (m_PendingWorkItems.load() > 0)
+ {
+ m_AllWorkDone.Wait();
+ }
}
void
@@ -56,9 +90,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count)
{
ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount");
- m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)));
+ const uint32_t ClampedCount = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores));
+ const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount);
+
+ if (ClampedCount != PreviousTarget)
+ {
+ ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})",
+ PreviousTarget,
+ ClampedCount,
+ m_ActiveCoreCount.load(),
+ m_EstimatedCoreCount.load());
+ }
- while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load())
+ // Only provision if the gap is at least one agent-sized chunk. Without
+ // this, draining a 32-core agent to cover a 28-core excess would leave a
+ // 4-core gap that triggers a 32-core provision, which triggers another
+ // drain, ad infinitum.
+ while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load())
{
if (!m_AskForAgents.load())
{
@@ -67,21 +115,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count)
RequestAgent();
}
- // Clean up finished agent threads
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- for (auto It = m_Agents.begin(); It != m_Agents.end();)
+ // Scale down async agents
{
- if ((*It)->ShouldExit.load())
+ std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock);
+
+ uint32_t AsyncActive = m_ActiveCoreCount.load();
+ uint32_t AsyncTarget = m_TargetCoreCount.load();
+
+ uint32_t AlreadyDrainingCores = 0;
+ for (const auto& Entry : m_AsyncAgents)
{
- if ((*It)->Thread.joinable())
+ if (Entry.Draining)
{
- (*It)->Thread.join();
+ AlreadyDrainingCores += Entry.CoreCount;
}
- It = m_Agents.erase(It);
}
- else
+
+ uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? AsyncActive - AlreadyDrainingCores : 0;
+
+ if (EffectiveAsync > AsyncTarget)
{
- ++It;
+ struct Candidate
+ {
+ AsyncAgentEntry* Entry;
+ int Workload;
+ };
+ std::vector<Candidate> Candidates;
+
+ for (auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.Draining || Entry.RemoteEndpoint.empty())
+ {
+ continue;
+ }
+
+ int Workload = 0;
+ bool Reachable = false;
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.drain";
+ Settings.ConnectTimeout = std::chrono::milliseconds{2000};
+ Settings.Timeout = std::chrono::milliseconds{3000};
+ try
+ {
+ HttpClient Client(Entry.RemoteEndpoint, Settings);
+ HttpClient::Response Resp = Client.Get("/compute/session/status");
+ if (Resp.IsSuccess())
+ {
+ CbObject Status = Resp.AsObject();
+ Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0);
+ Reachable = true;
+ }
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what());
+ }
+
+ if (Reachable)
+ {
+ Candidates.push_back({&Entry, Workload});
+ }
+ }
+
+ const uint32_t ExcessCores = EffectiveAsync - AsyncTarget;
+ uint32_t CoresDrained = 0;
+
+ while (CoresDrained < ExcessCores && !Candidates.empty())
+ {
+ const uint32_t Remaining = ExcessCores - CoresDrained;
+
+ Candidates.erase(std::remove_if(Candidates.begin(),
+ Candidates.end(),
+ [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }),
+ Candidates.end());
+
+ if (Candidates.empty())
+ {
+ break;
+ }
+
+ Candidate* Best = &Candidates[0];
+ for (auto& C : Candidates)
+ {
+ if (C.Entry->CoreCount > Best->Entry->CoreCount ||
+ (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload))
+ {
+ Best = &C;
+ }
+ }
+
+ ZEN_INFO("draining async agent lease={} ({} cores, workload={})",
+ Best->Entry->LeaseId,
+ Best->Entry->CoreCount,
+ Best->Workload);
+
+ DrainAsyncAgent(*Best->Entry);
+ CoresDrained += Best->Entry->CoreCount;
+
+ AsyncAgentEntry* Drained = Best->Entry;
+ Candidates.erase(
+ std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }),
+ Candidates.end());
+ }
}
}
}
@@ -101,266 +236,395 @@ HordeProvisioner::GetStats() const
uint32_t
HordeProvisioner::GetAgentCount() const
{
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
- return static_cast<uint32_t>(m_Agents.size());
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ return static_cast<uint32_t>(m_AsyncAgents.size());
}
-void
-HordeProvisioner::RequestAgent()
+compute::AgentProvisioningStatus
+HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const
{
- m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
+ // Worker IDs are "horde-{LeaseId}" - strip the prefix to match lease ID
+ constexpr std::string_view Prefix = "horde-";
+ if (!WorkerId.starts_with(Prefix))
+ {
+ return compute::AgentProvisioningStatus::Unknown;
+ }
+ std::string_view LeaseId = WorkerId.substr(Prefix.size());
- std::lock_guard<std::mutex> Lock(m_AgentsLock);
+ std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock);
+ for (const auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.LeaseId == LeaseId)
+ {
+ if (Entry.Draining)
+ {
+ return compute::AgentProvisioningStatus::Draining;
+ }
+ return compute::AgentProvisioningStatus::Active;
+ }
+ }
- auto Wrapper = std::make_unique<AgentWrapper>();
- AgentWrapper& Ref = *Wrapper;
- Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); });
+ // Check recently-drained agents that have already been cleaned up
+ std::string WorkerIdStr(WorkerId);
+ if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0)
+ {
+ // Also remove from the ordering queue so size accounting stays consistent.
+ auto It = std::find(m_RecentlyDrainedOrder.begin(), m_RecentlyDrainedOrder.end(), WorkerIdStr);
+ if (It != m_RecentlyDrainedOrder.end())
+ {
+ m_RecentlyDrainedOrder.erase(It);
+ }
+ return compute::AgentProvisioningStatus::Draining;
+ }
- m_Agents.push_back(std::move(Wrapper));
+ return compute::AgentProvisioningStatus::Unknown;
}
-void
-HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
+std::vector<std::string>
+HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const
{
- ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent");
+ std::vector<std::string> Args;
+ Args.emplace_back("compute");
+ Args.emplace_back("--http=asio");
+ Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
+ Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen");
- static std::atomic<uint32_t> ThreadIndex{0};
- const uint32_t CurrentIndex = ThreadIndex.fetch_add(1);
+ if (m_CleanStart)
+ {
+ Args.emplace_back("--clean");
+ }
- zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex));
+ if (!m_OrchestratorEndpoint.empty())
+ {
+ ExtendableStringBuilder<256> CoordArg;
+ CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
+ Args.emplace_back(CoordArg.ToView());
+ }
- std::unique_ptr<HordeAgent> Agent;
- uint32_t MachineCoreCount = 0;
+ {
+ ExtendableStringBuilder<128> IdArg;
+ IdArg << "--instance-id=horde-" << Machine.LeaseId;
+ Args.emplace_back(IdArg.ToView());
+ }
- auto _ = MakeGuard([&] {
- if (Agent)
- {
- Agent->CloseConnection();
- }
- Wrapper.ShouldExit.store(true);
- });
+ if (!m_CoordinatorSession.empty())
+ {
+ ExtendableStringBuilder<128> SessionArg;
+ SessionArg << "--coordinator-session=" << m_CoordinatorSession;
+ Args.emplace_back(SessionArg.ToView());
+ }
+ if (!m_TraceHost.empty())
{
- // EstimatedCoreCount is incremented speculatively when the agent is requested
- // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
- auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
+ ExtendableStringBuilder<128> TraceArg;
+ TraceArg << "--tracehost=" << m_TraceHost;
+ Args.emplace_back(TraceArg.ToView());
+ }
+ // In relay mode, the remote zenserver's local address is not reachable from the
+ // orchestrator. Pass the relay-visible endpoint so it announces the correct URL.
+ if (Machine.Mode == ConnectionMode::Relay)
+ {
+ const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (Addr.find(':') != std::string::npos)
+ {
+ Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port));
+ }
+ else
{
- ZEN_TRACE_CPU("HordeProvisioner::CreateBundles");
+ Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port));
+ }
+ }
- std::lock_guard<std::mutex> BundleLock(m_BundleLock);
+ return Args;
+}
- if (!m_BundlesCreated)
- {
- const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+bool
+HordeProvisioner::InitializeHordeClient()
+{
+ ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient");
+
+ std::lock_guard<std::mutex> BundleLock(m_BundleLock);
- std::vector<BundleFile> Files;
+ if (!m_BundlesCreated)
+ {
+ const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles";
+
+ std::vector<BundleFile> Files;
#if ZEN_PLATFORM_WINDOWS
- Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.exe", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.pdb", true);
#elif ZEN_PLATFORM_LINUX
- Files.emplace_back(m_BinariesPath / "zenserver", false);
- Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver.debug", true);
#elif ZEN_PLATFORM_MAC
- Files.emplace_back(m_BinariesPath / "zenserver", false);
+ Files.emplace_back(m_BinariesPath / "zenserver", false);
#endif
- BundleResult Result;
- if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
- {
- ZEN_WARN("failed to create bundle, cannot provision any agents!");
- m_AskForAgents.store(false);
- return;
- }
-
- m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
- m_BundlesCreated = true;
- }
-
- if (!m_HordeClient)
- {
- m_HordeClient = std::make_unique<HordeClient>(m_Config);
- if (!m_HordeClient->Initialize())
- {
- ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
- m_AskForAgents.store(false);
- return;
- }
- }
+ BundleResult Result;
+ if (!BundleCreator::CreateBundle(Files, OutputDir, Result))
+ {
+ ZEN_WARN("failed to create bundle, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ m_ShutdownEvent.Set();
+ return false;
}
- if (!m_AskForAgents.load())
+ m_Bundles.emplace_back(Result.Locator, Result.BundleDir);
+ m_BundlesCreated = true;
+ }
+
+ if (!m_HordeClient)
+ {
+ m_HordeClient = std::make_unique<HordeClient>(m_Config);
+ if (!m_HordeClient->Initialize())
{
- return;
+ ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!");
+ m_AskForAgents.store(false);
+ m_ShutdownEvent.Set();
+ return false;
}
+ }
- m_AgentsRequesting.fetch_add(1);
- auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
+ return true;
+}
- // Simple backoff: if the last machine request failed, wait up to 5 seconds
- // before trying again.
- //
- // Note however that it's possible that multiple threads enter this code at
- // the same time if multiple agents are requested at once, and they will all
- // see the same last failure time and back off accordingly. We might want to
- // use a semaphore or similar to limit the number of concurrent requests.
+void
+HordeProvisioner::RequestAgent()
+{
+ m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent);
- if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
- {
- auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
- const uint64_t ElapsedNs = Now - LastFail;
- const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
- if (ElapsedMs < 5000)
- {
- const uint64_t WaitMs = 5000 - ElapsedMs;
- for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100)
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
- }
+ if (m_PendingWorkItems.fetch_add(1) == 0)
+ {
+ m_AllWorkDone.Reset();
+ }
- if (Wrapper.ShouldExit.load())
+ GetSmallWorkerPool(EWorkloadType::Background)
+ .ScheduleWork(
+ [this] {
+ ProvisionAgent();
+ if (m_PendingWorkItems.fetch_sub(1) == 1)
{
- return;
+ m_AllWorkDone.Set();
}
- }
- }
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
+}
- if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
- {
- return;
- }
+void
+HordeProvisioner::ProvisionAgent()
+{
+ ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent");
+
+ // EstimatedCoreCount is incremented speculatively when the agent is requested
+ // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision.
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); });
- std::string RequestBody = m_HordeClient->BuildRequestBody();
+ if (!InitializeHordeClient())
+ {
+ return;
+ }
- // Resolve cluster if needed
- std::string ClusterId = m_Config.Cluster;
- if (ClusterId == HordeConfig::ClusterAuto)
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
+
+ m_AgentsRequesting.fetch_add(1);
+ auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); });
+
+ // Simple backoff: if the last machine request failed, wait up to 5 seconds
+ // before trying again.
+ //
+ // Note however that it's possible that multiple threads enter this code at
+ // the same time if multiple agents are requested at once, and they will all
+ // see the same last failure time and back off accordingly. We might want to
+ // use a semaphore or similar to limit the number of concurrent requests.
+
+ if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0)
+ {
+ auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count());
+ const uint64_t ElapsedNs = Now - LastFail;
+ const uint64_t ElapsedMs = ElapsedNs / 1'000'000;
+ if (ElapsedMs < 5000)
{
- ClusterInfo Cluster;
- if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
+ // Wait on m_ShutdownEvent so shutdown wakes this pool thread immediately instead
+ // of stalling for up to 5s in 100ms sleep chunks. Wait() returns true iff the
+ // event was signaled (shutdown); false means the backoff elapsed normally.
+ const uint64_t WaitMs = 5000 - ElapsedMs;
+ if (m_ShutdownEvent.Wait(static_cast<int>(WaitMs)))
{
- ZEN_WARN("failed to resolve cluster");
- m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
return;
}
- ClusterId = Cluster.ClusterId;
}
+ }
- MachineInfo Machine;
- if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
+ if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load())
+ {
+ return;
+ }
+
+ std::string RequestBody = m_HordeClient->BuildRequestBody();
+
+ // Resolve cluster if needed
+ std::string ClusterId = m_Config.Cluster;
+ if (ClusterId == HordeConfig::ClusterAuto)
+ {
+ ClusterInfo Cluster;
+ if (!m_HordeClient->ResolveCluster(RequestBody, Cluster))
{
+ ZEN_WARN("failed to resolve cluster");
m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
return;
}
+ ClusterId = Cluster.ClusterId;
+ }
- m_LastRequestFailTime.store(0);
+ ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})",
+ ClusterId.empty() ? "default" : ClusterId.c_str(),
+ m_ActiveCoreCount.load(),
+ m_TargetCoreCount.load());
- if (Wrapper.ShouldExit.load())
- {
- return;
- }
+ MachineInfo Machine;
+ if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid())
+ {
+ m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()));
+ return;
+ }
- // Connect to agent and perform handshake
- Agent = std::make_unique<HordeAgent>(Machine);
- if (!Agent->IsValid())
- {
- ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort());
- return;
- }
+ m_LastRequestFailTime.store(0);
- if (!Agent->BeginCommunication())
- {
- ZEN_WARN("BeginCommunication failed");
- return;
- }
+ if (!m_AskForAgents.load())
+ {
+ return;
+ }
- for (auto& [Locator, BundleDir] : m_Bundles)
+ AsyncAgentConfig AgentConfig;
+ AgentConfig.Machine = Machine;
+ AgentConfig.Bundles = m_Bundles;
+ AgentConfig.Args = BuildAgentArgs(Machine);
+
+#if ZEN_PLATFORM_WINDOWS
+ AgentConfig.UseWine = !Machine.IsWindows;
+ AgentConfig.Executable = "zenserver.exe";
+#else
+ AgentConfig.UseWine = false;
+ AgentConfig.Executable = "zenserver";
+#endif
+
+ auto AsyncAgent = std::make_shared<AsyncHordeAgent>(*m_IoContext);
+
+ AsyncAgentEntry Entry;
+ Entry.Agent = AsyncAgent;
+ Entry.LeaseId = Machine.LeaseId;
+ Entry.CoreCount = Machine.LogicalCores;
+
+ const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (EndpointAddr.find(':') != std::string::npos)
+ {
+ Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort);
+ }
+ else
+ {
+ Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort);
+ }
+
+ {
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ m_AsyncAgents.push_back(std::move(Entry));
+ }
+
+ AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) {
+ if (Result.CoreCount > 0)
{
- if (Wrapper.ShouldExit.load())
+ // Only subtract estimated cores if not already subtracted by DrainAsyncAgent
+ bool WasDraining = false;
{
- return;
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (const auto& Entry : m_AsyncAgents)
+ {
+ if (Entry.Agent == AsyncAgent)
+ {
+ WasDraining = Entry.Draining;
+ break;
+ }
+ }
}
- if (!Agent->UploadBinaries(BundleDir, Locator))
+ if (!WasDraining)
{
- ZEN_WARN("UploadBinaries failed");
- return;
+ m_EstimatedCoreCount.fetch_sub(Result.CoreCount);
}
+ m_ActiveCoreCount.fetch_sub(Result.CoreCount);
+ m_AgentsActive.fetch_sub(1);
}
+ OnAsyncAgentDone(AsyncAgent);
+ });
- if (Wrapper.ShouldExit.load())
- {
- return;
- }
-
- // Build command line for remote zenserver
- std::vector<std::string> ArgStrings;
- ArgStrings.push_back("compute");
- ArgStrings.push_back("--http=asio");
+ // Track active cores (estimated was already added by RequestAgent)
+ m_EstimatedCoreCount.fetch_add(Machine.LogicalCores);
+ m_ActiveCoreCount.fetch_add(Machine.LogicalCores);
+ m_AgentsActive.fetch_add(1);
+}
- // TEMP HACK - these should be made fully dynamic
- // these are currently here to allow spawning the compute agent locally
- // for debugging purposes (i.e with a local Horde Server+Agent setup)
- ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort));
- ArgStrings.push_back("--data-dir=c:\\temp\\123");
+void
+HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry)
+{
+ Entry.Draining = true;
+ m_EstimatedCoreCount.fetch_sub(Entry.CoreCount);
+ m_AgentsDraining.fetch_add(1);
- if (!m_OrchestratorEndpoint.empty())
- {
- ExtendableStringBuilder<256> CoordArg;
- CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint;
- ArgStrings.emplace_back(CoordArg.ToView());
- }
+ HttpClientSettings Settings;
+ Settings.LogCategory = "horde.drain";
+ Settings.ConnectTimeout = std::chrono::milliseconds{5000};
+ Settings.Timeout = std::chrono::milliseconds{10000};
- {
- ExtendableStringBuilder<128> IdArg;
- IdArg << "--instance-id=horde-" << Machine.LeaseId;
- ArgStrings.emplace_back(IdArg.ToView());
- }
+ try
+ {
+ HttpClient Client(Entry.RemoteEndpoint, Settings);
- std::vector<const char*> Args;
- Args.reserve(ArgStrings.size());
- for (const std::string& Arg : ArgStrings)
+ HttpClient::Response Response = Client.Post("/compute/session/drain");
+ if (!Response.IsSuccess())
{
- Args.push_back(Arg.c_str());
+ ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast<int>(Response.StatusCode));
+ return;
}
-#if ZEN_PLATFORM_WINDOWS
- const bool UseWine = !Machine.IsWindows;
- const char* AppName = "zenserver.exe";
-#else
- const bool UseWine = false;
- const char* AppName = "zenserver";
-#endif
-
- Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine);
-
- ZEN_INFO("remote execution started on [{}:{}] lease={}",
- Machine.GetConnectionAddress(),
- Machine.GetConnectionPort(),
- Machine.LeaseId);
-
- MachineCoreCount = Machine.LogicalCores;
- m_EstimatedCoreCount.fetch_add(MachineCoreCount);
- m_ActiveCoreCount.fetch_add(MachineCoreCount);
- m_AgentsActive.fetch_add(1);
+ ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId);
+ (void)Client.Post("/compute/session/sunset");
}
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what());
+ }
+}
- // Agent poll loop
-
- auto ActiveGuard = MakeGuard([&]() {
- m_EstimatedCoreCount.fetch_sub(MachineCoreCount);
- m_ActiveCoreCount.fetch_sub(MachineCoreCount);
- m_AgentsActive.fetch_sub(1);
- });
-
- while (Agent->IsValid() && !Wrapper.ShouldExit.load())
+void
+HordeProvisioner::OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent)
+{
+ std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock);
+ for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It)
{
- const bool LogOutput = false;
- if (!Agent->Poll(LogOutput))
+ if (It->Agent == Agent)
{
+ if (It->Draining)
+ {
+ m_AgentsDraining.fetch_sub(1);
+ std::string WorkerId = "horde-" + It->LeaseId;
+ if (m_RecentlyDrainedWorkerIds.insert(WorkerId).second)
+ {
+ m_RecentlyDrainedOrder.push_back(WorkerId);
+ while (m_RecentlyDrainedOrder.size() > RecentlyDrainedCapacity)
+ {
+ m_RecentlyDrainedWorkerIds.erase(m_RecentlyDrainedOrder.front());
+ m_RecentlyDrainedOrder.pop_front();
+ }
+ }
+ }
+ m_AsyncAgents.erase(It);
break;
}
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
}
diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp
index 69766e73e..65eaea477 100644
--- a/src/zenhorde/hordetransport.cpp
+++ b/src/zenhorde/hordetransport.cpp
@@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
-#if ZEN_PLATFORM_WINDOWS
-# undef SendMessage
-#endif
-
namespace zen::horde {
-// ComputeTransport base
+// --- AsyncTcpComputeTransport ---
-bool
-ComputeTransport::SendMessage(const void* Data, size_t Size)
+struct AsyncTcpComputeTransport::Impl
{
- const uint8_t* Ptr = static_cast<const uint8_t*>(Data);
- size_t Remaining = Size;
-
- while (Remaining > 0)
- {
- const size_t Sent = Send(Ptr, Remaining);
- if (Sent == 0)
- {
- return false;
- }
- Ptr += Sent;
- Remaining -= Sent;
- }
+ asio::io_context& IoContext;
+ asio::ip::tcp::socket Socket;
- return true;
-}
+ explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {}
+};
-bool
-ComputeTransport::RecvMessage(void* Data, size_t Size)
+AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext)
+: m_Impl(std::make_unique<Impl>(IoContext))
+, m_Log(zen::logging::Get("horde.transport.async"))
{
- uint8_t* Ptr = static_cast<uint8_t*>(Data);
- size_t Remaining = Size;
-
- while (Remaining > 0)
- {
- const size_t Received = Recv(Ptr, Remaining);
- if (Received == 0)
- {
- return false;
- }
- Ptr += Received;
- Remaining -= Received;
- }
-
- return true;
}
-// TcpComputeTransport - ASIO pimpl
-
-struct TcpComputeTransport::Impl
+AsyncTcpComputeTransport::~AsyncTcpComputeTransport()
{
- asio::io_context IoContext;
- asio::ip::tcp::socket Socket;
-
- Impl() : Socket(IoContext) {}
-};
+ Close();
+}
-// Uses ASIO in synchronous mode only — no async operations or io_context::run().
-// The io_context is only needed because ASIO sockets require one to be constructed.
-TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
-: m_Impl(std::make_unique<Impl>())
-, m_Log(zen::logging::Get("horde.transport"))
+void
+AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler)
{
- ZEN_TRACE_CPU("TcpComputeTransport::Connect");
+ ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect");
asio::error_code Ec;
@@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info)
{
ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message());
m_HasErrors = true;
+ asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); });
return;
}
const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort());
- m_Impl->Socket.connect(Endpoint, Ec);
- if (Ec)
- {
- ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message());
- m_HasErrors = true;
- return;
- }
+ // Copy the nonce so it survives past this scope into the async callback
+ auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize);
- // Disable Nagle's algorithm for lower latency
- m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec);
-}
+ m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("async connect failed: {}", Ec.message());
+ m_HasErrors = true;
+ Handler(Ec);
+ return;
+ }
-TcpComputeTransport::~TcpComputeTransport()
-{
- Close();
+ asio::error_code SetOptEc;
+ m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc);
+
+ // Send the 64-byte nonce as the first thing on the wire
+ asio::async_write(m_Impl->Socket,
+ asio::buffer(*NonceBuf),
+ [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) {
+ if (Ec)
+ {
+ ZEN_WARN("nonce write failed: {}", Ec.message());
+ m_HasErrors = true;
+ }
+ Handler(Ec);
+ });
+ });
}
bool
-TcpComputeTransport::IsValid() const
+AsyncTcpComputeTransport::IsValid() const
{
return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed;
}
-size_t
-TcpComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
- }
-
- asio::error_code Ec;
- const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec);
-
- if (Ec)
- {
- m_HasErrors = true;
- return 0;
+ asio::post(m_Impl->IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- return Sent;
+ asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}
-size_t
-TcpComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
- }
-
- asio::error_code Ec;
- const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec);
-
- if (Ec)
- {
- return 0;
+ asio::post(m_Impl->IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- return Received;
-}
-
-void
-TcpComputeTransport::MarkComplete()
-{
+ asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}
void
-TcpComputeTransport::Close()
+AsyncTcpComputeTransport::Close()
{
if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open())
{
diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h
index 1b178dc0f..b5e841d7a 100644
--- a/src/zenhorde/hordetransport.h
+++ b/src/zenhorde/hordetransport.h
@@ -8,55 +8,60 @@
#include <cstddef>
#include <cstdint>
+#include <functional>
#include <memory>
+#include <system_error>
-#if ZEN_PLATFORM_WINDOWS
-# undef SendMessage
-#endif
+namespace asio {
+class io_context;
+}
namespace zen::horde {
-/** Abstract base interface for compute transports.
+/** Handler types for async transport operations. */
+using AsyncConnectHandler = std::function<void(const std::error_code&)>;
+using AsyncIoHandler = std::function<void(const std::error_code&, size_t)>;
+
+/** Abstract base for asynchronous compute transports.
*
- * Matches the UE FComputeTransport pattern. Concrete implementations handle
- * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides
- * blocking message helpers on top.
+ * All callbacks are invoked on the io_context that was provided at construction.
+ * Callers are responsible for strand serialization if needed.
*/
-class ComputeTransport
+class AsyncComputeTransport
{
public:
- virtual ~ComputeTransport() = default;
+ virtual ~AsyncComputeTransport() = default;
+
+ virtual bool IsValid() const = 0;
- virtual bool IsValid() const = 0;
- virtual size_t Send(const void* Data, size_t Size) = 0;
- virtual size_t Recv(void* Data, size_t Size) = 0;
- virtual void MarkComplete() = 0;
- virtual void Close() = 0;
+ /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */
+ virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0;
- /** Blocking send that loops until all bytes are transferred. Returns false on error. */
- bool SendMessage(const void* Data, size_t Size);
+ /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */
+ virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0;
- /** Blocking receive that loops until all bytes are transferred. Returns false on error. */
- bool RecvMessage(void* Data, size_t Size);
+ virtual void Close() = 0;
};
-/** TCP socket transport using ASIO.
+/** Async TCP transport using ASIO.
*
- * Connects to the Horde compute endpoint specified by MachineInfo and provides
- * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the
- * header clean.
+ * Connects to the Horde compute endpoint and provides async send/receive.
+ * The socket is created on a caller-provided io_context (shared across agents).
*/
-class TcpComputeTransport final : public ComputeTransport
+class AsyncTcpComputeTransport final : public AsyncComputeTransport
{
public:
- explicit TcpComputeTransport(const MachineInfo& Info);
- ~TcpComputeTransport() override;
-
- bool IsValid() const override;
- size_t Send(const void* Data, size_t Size) override;
- size_t Recv(void* Data, size_t Size) override;
- void MarkComplete() override;
- void Close() override;
+ /** Construct a transport on the given io_context. Does not connect yet. */
+ explicit AsyncTcpComputeTransport(asio::io_context& IoContext);
+ ~AsyncTcpComputeTransport() override;
+
+ /** Asynchronously connect to the endpoint and send the nonce. */
+ void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler);
+
+ bool IsValid() const override;
+ void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void Close() override;
private:
LoggerRef Log() { return m_Log; }
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
index 505b6bde7..0b94a4397 100644
--- a/src/zenhorde/hordetransportaes.cpp
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -5,9 +5,12 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <algorithm>
#include <cstring>
-#include <random>
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
@@ -22,315 +25,410 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-struct AesComputeTransport::CryptoContext
-{
- uint8_t Key[KeySize] = {};
- uint8_t EncryptNonce[NonceBytes] = {};
- uint8_t DecryptNonce[NonceBytes] = {};
- bool HasErrors = false;
-
-#if !ZEN_PLATFORM_WINDOWS
- EVP_CIPHER_CTX* EncCtx = nullptr;
- EVP_CIPHER_CTX* DecCtx = nullptr;
-#endif
-
- CryptoContext(const uint8_t (&InKey)[KeySize])
- {
- memcpy(Key, InKey, KeySize);
-
- // The encrypt nonce is randomly initialized and then deterministically mutated
- // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
- // the wire (each received message carries its own nonce in the header).
- std::random_device Rd;
- std::mt19937 Gen(Rd());
- std::uniform_int_distribution<int> Dist(0, 255);
- for (auto& Byte : EncryptNonce)
- {
- Byte = static_cast<uint8_t>(Dist(Gen));
- }
-
-#if !ZEN_PLATFORM_WINDOWS
- // Drain any stale OpenSSL errors
- while (ERR_get_error() != 0)
- {
- }
-
- EncCtx = EVP_CIPHER_CTX_new();
- EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
-
- DecCtx = EVP_CIPHER_CTX_new();
- EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
-#endif
- }
+namespace {
- ~CryptoContext()
- {
-#if ZEN_PLATFORM_WINDOWS
- SecureZeroMemory(Key, sizeof(Key));
- SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
- SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
-#else
- OPENSSL_cleanse(Key, sizeof(Key));
- OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
- OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
-
- if (EncCtx)
- {
- EVP_CIPHER_CTX_free(EncCtx);
- }
- if (DecCtx)
- {
- EVP_CIPHER_CTX_free(DecCtx);
- }
-#endif
- }
+ static constexpr size_t AesNonceBytes = 12;
+ static constexpr size_t AesTagBytes = 16;
- void UpdateNonce()
+ /** AES-256-GCM crypto context. Not exposed outside this translation unit. */
+ struct AesCryptoContext
{
- uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
- N32[0]++;
- N32[1]--;
- N32[2] = N32[0] ^ N32[1];
- }
+ static constexpr size_t NonceBytes = AesNonceBytes;
+ static constexpr size_t TagBytes = AesTagBytes;
- // Returns total encrypted message size, or 0 on failure
- // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
- int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
- {
- UpdateNonce();
+ uint8_t Key[KeySize] = {};
+ uint8_t EncryptNonce[NonceBytes] = {};
+ uint8_t DecryptNonce[NonceBytes] = {};
+ uint64_t DecryptCounter = 0; ///< Sequence number of the next message to be decrypted (for diagnostics)
+ bool HasErrors = false;
- // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
- // caching but has some overhead. For our use case (relatively large, infrequent messages)
- // this is acceptable.
#if ZEN_PLATFORM_WINDOWS
BCRYPT_ALG_HANDLE hAlg = nullptr;
BCRYPT_KEY_HANDLE hKey = nullptr;
+#else
+ EVP_CIPHER_CTX* EncCtx = nullptr;
+ EVP_CIPHER_CTX* DecCtx = nullptr;
+#endif
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = EncryptNonce;
- AuthInfo.cbNonce = NonceBytes;
- uint8_t Tag[TagBytes] = {};
- AuthInfo.pbTag = Tag;
- AuthInfo.cbTag = TagBytes;
-
- ULONG CipherLen = 0;
- NTSTATUS Status =
- BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
-
- if (!BCRYPT_SUCCESS(Status))
+ AesCryptoContext(const uint8_t (&InKey)[KeySize])
{
- HasErrors = true;
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
- return 0;
- }
+ memcpy(Key, InKey, KeySize);
- // Write header: length + nonce
- memcpy(Out, &InLength, 4);
- memcpy(Out + 4, EncryptNonce, NonceBytes);
- // Write tag after ciphertext
- memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+ // EncryptNonce is zero-initialized (NIST SP 800-38D §8.2.1 deterministic
+ // construction): fixed_field = 0, counter starts at 0 and is incremented
+ // before each encryption by UpdateNonce(). No RNG is used here because
+ // std::random_device is not guaranteed to be a CSPRNG (historic MinGW,
+ // some WASI targets), and the deterministic construction does not need
+ // one as long as each session uses a unique key.
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+#if ZEN_PLATFORM_WINDOWS
+ NTSTATUS Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptOpenAlgorithmProvider failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ hAlg = nullptr;
+ HasErrors = true;
+ return;
+ }
+
+ Status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptSetProperty(BCRYPT_CHAIN_MODE_GCM) failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ HasErrors = true;
+ return;
+ }
+
+ Status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptGenerateSymmetricKey failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ hKey = nullptr;
+ HasErrors = true;
+ return;
+ }
#else
- if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
- {
- HasErrors = true;
- return 0;
+ while (ERR_get_error() != 0)
+ {
+ }
+
+ EncCtx = EVP_CIPHER_CTX_new();
+ DecCtx = EVP_CIPHER_CTX_new();
+ if (!EncCtx || !DecCtx)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_new failed");
+ HasErrors = true;
+ return;
+ }
+
+ if (EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return;
+ }
+
+ if (EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return;
+ }
+#endif
}
- int32_t Offset = 0;
- // Write length
- memcpy(Out + Offset, &InLength, 4);
- Offset += 4;
- // Write nonce
- memcpy(Out + Offset, EncryptNonce, NonceBytes);
- Offset += NonceBytes;
-
- // Encrypt
- int OutLen = 0;
- if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ ~AesCryptoContext()
{
- HasErrors = true;
- return 0;
+#if ZEN_PLATFORM_WINDOWS
+ if (hKey)
+ {
+ BCryptDestroyKey(hKey);
+ }
+ if (hAlg)
+ {
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+ }
+ SecureZeroMemory(Key, sizeof(Key));
+ SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+ SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+#else
+ OPENSSL_cleanse(Key, sizeof(Key));
+ OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+ OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+ if (EncCtx)
+ {
+ EVP_CIPHER_CTX_free(EncCtx);
+ }
+ if (DecCtx)
+ {
+ EVP_CIPHER_CTX_free(DecCtx);
+ }
+#endif
}
- Offset += OutLen;
- // Finalize
- int FinalLen = 0;
- if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ void UpdateNonce()
{
+ // NIST SP 800-38D §8.2.1 deterministic construction:
+ // nonce = [fixed_field (4 bytes) || invocation_counter (8 bytes, big-endian)]
+ // The low 8 bytes are a strict monotonic counter starting at zero. On 2^64
+ // exhaustion the session is torn down (HasErrors) - never wrap, since a repeated
+ // (key, nonce) pair catastrophically breaks AES-GCM confidentiality and integrity.
+ for (int i = 11; i >= 4; --i)
+ {
+ if (++EncryptNonce[i] != 0)
+ {
+ return;
+ }
+ }
HasErrors = true;
- return 0;
}
- Offset += FinalLen;
- // Get tag
- if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
{
- HasErrors = true;
- return 0;
- }
- Offset += TagBytes;
-
- return Offset;
-#endif
- }
+ UpdateNonce();
+ if (HasErrors)
+ {
+ return 0;
+ }
- // Decrypt a message. Returns decrypted data length, or 0 on failure.
- // Input must be [ciphertext][tag], with nonce provided separately.
- int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
- {
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
- AuthInfo.cbNonce = NonceBytes;
- AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
- AuthInfo.cbTag = TagBytes;
-
- ULONG PlainLen = 0;
- NTSTATUS Status = BCryptDecrypt(hKey,
- (PUCHAR)CipherAndTag,
- (ULONG)DataLength,
- &AuthInfo,
- nullptr,
- 0,
- (PUCHAR)Out,
- (ULONG)DataLength,
- &PlainLen,
- 0);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
- return 0;
- }
-
- return static_cast<int32_t>(PlainLen);
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = EncryptNonce;
+ AuthInfo.cbNonce = NonceBytes;
+ // Tag is output-only on encrypt; BCryptEncrypt writes TagBytes bytes into it, so skip zero-init.
+ uint8_t Tag[TagBytes];
+ AuthInfo.pbTag = Tag;
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG CipherLen = 0;
+ const NTSTATUS Status = BCryptEncrypt(hKey,
+ (PUCHAR)In,
+ (ULONG)InLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ Out + 4 + NonceBytes,
+ (ULONG)InLength,
+ &CipherLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptEncrypt failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ HasErrors = true;
+ return 0;
+ }
+
+ memcpy(Out, &InLength, 4);
+ memcpy(Out + 4, EncryptNonce, NonceBytes);
+ memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
+ return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
#else
- if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
-
- int OutLen = 0;
- if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
- {
- HasErrors = true;
- return 0;
+ // Reset per message so any stale state from a previous encrypt (e.g. partial
+ // completion after a prior error) cannot bleed into this operation. Re-bind
+ // the cipher/key; the IV is then set via the normal init call below.
+ if (EVP_CIPHER_CTX_reset(EncCtx) != 1 || EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_reset/EncryptInit failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptInit_ex(key+iv) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int32_t Offset = 0;
+ memcpy(Out + Offset, &InLength, 4);
+ Offset += 4;
+ memcpy(Out + Offset, EncryptNonce, NonceBytes);
+ Offset += NonceBytes;
+
+ int OutLen = 0;
+ if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptUpdate failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += OutLen;
+
+ int FinalLen = 0;
+ if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptFinal_ex failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += FinalLen;
+
+ if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ {
+ ZEN_ERROR("EVP_CTRL_GCM_GET_TAG failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += TagBytes;
+
+ return Offset;
+#endif
}
- // Set the tag for verification
- if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
{
- HasErrors = true;
- return 0;
+#if ZEN_PLATFORM_WINDOWS
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+ AuthInfo.cbNonce = NonceBytes;
+ AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG PlainLen = 0;
+ const NTSTATUS Status = BCryptDecrypt(hKey,
+ (PUCHAR)CipherAndTag,
+ (ULONG)DataLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out,
+ (ULONG)DataLength,
+ &PlainLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ // STATUS_AUTH_TAG_MISMATCH (0xC000A002) indicates GCM integrity failure -
+ // either in-flight corruption or active tampering. Log distinctly from
+ // other BCryptDecrypt failures so that tamper attempts are auditable.
+ static constexpr NTSTATUS STATUS_AUTH_TAG_MISMATCH_VAL = static_cast<NTSTATUS>(0xC000A002L);
+ if (Status == STATUS_AUTH_TAG_MISMATCH_VAL)
+ {
+ ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter);
+ }
+ else
+ {
+ ZEN_ERROR("BCryptDecrypt failed: 0x{:08x} (seq={})", static_cast<uint32_t>(Status), DecryptCounter);
+ }
+ HasErrors = true;
+ return 0;
+ }
+
+ ++DecryptCounter;
+ return static_cast<int32_t>(PlainLen);
+#else
+ // Same rationale as EncryptMessage: reset the context and re-bind the cipher
+ // before each decrypt to avoid stale state from a previous operation.
+ if (EVP_CIPHER_CTX_reset(DecCtx) != 1 || EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_reset/DecryptInit failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptInit_ex (seq={}) failed: {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int OutLen = 0;
+ if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptUpdate failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ {
+ ZEN_ERROR("EVP_CTRL_GCM_SET_TAG failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int FinalLen = 0;
+ if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+ {
+ // EVP_DecryptFinal_ex returns 0 specifically on GCM tag verification failure
+ // once the tag has been set. Log distinctly so tamper attempts are auditable.
+ ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter);
+ HasErrors = true;
+ return 0;
+ }
+
+ ++DecryptCounter;
+ return OutLen + FinalLen;
+#endif
}
+ };
- int FinalLen = 0;
- if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
- }
+} // anonymous namespace
- return OutLen + FinalLen;
-#endif
- }
+struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext
+{
+ using AesCryptoContext::AesCryptoContext;
};
-AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+// --- AsyncAesComputeTransport ---
+
+AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext)
: m_Crypto(std::make_unique<CryptoContext>(Key))
, m_Inner(std::move(InnerTransport))
+, m_IoContext(IoContext)
{
}
-AesComputeTransport::~AesComputeTransport()
+AsyncAesComputeTransport::~AsyncAesComputeTransport()
{
Close();
}
bool
-AesComputeTransport::IsValid() const
+AsyncAesComputeTransport::IsValid() const
{
return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed;
}
-size_t
-AesComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
- ZEN_TRACE_CPU("AesComputeTransport::Send");
-
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- std::lock_guard<std::mutex> Lock(m_Lock);
-
const int32_t DataLength = static_cast<int32_t>(Size);
- const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
+ const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes;
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
+ // Encrypt directly into the per-write buffer rather than a long-lived member. Using a
+ // member (plaintext + ciphertext share that buffer during encryption on the OpenSSL
+ // path) would leave plaintext on the heap indefinitely and would also make the
+ // transport unsafe if AsyncWrite were ever invoked concurrently. Size the shared_ptr
+ // exactly to EncryptedLen afterwards.
+ auto EncBuf = std::make_shared<std::vector<uint8_t>>(MessageLength);
- const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
+ const int32_t EncryptedLen = m_Crypto->EncryptMessage(EncBuf->data(), Data, DataLength);
if (EncryptedLen == 0)
{
- return 0;
+ asio::post(m_IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); });
+ return;
}
- if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
- {
- return 0;
- }
+ EncBuf->resize(static_cast<size_t>(EncryptedLen));
- return Size;
+ m_Inner->AsyncWrite(
+ EncBuf->data(),
+ EncBuf->size(),
+ [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); });
}
-size_t
-AesComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
- // than the decrypted message contains. Excess bytes are buffered in m_RemainingData
- // and returned on subsequent Recv calls without another decryption round-trip.
- ZEN_TRACE_CPU("AesComputeTransport::Recv");
-
- std::lock_guard<std::mutex> Lock(m_Lock);
+ uint8_t* Dest = static_cast<uint8_t*>(Data);
if (!m_RemainingData.empty())
{
const size_t Available = m_RemainingData.size() - m_RemainingOffset;
const size_t ToCopy = std::min(Available, Size);
- memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+ memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy);
m_RemainingOffset += ToCopy;
if (m_RemainingOffset >= m_RemainingData.size())
@@ -339,82 +437,104 @@ AesComputeTransport::Recv(void* Data, size_t Size)
m_RemainingOffset = 0;
}
- return ToCopy;
- }
-
- // Receive packet header: [length(4B)][nonce(12B)]
- struct PacketHeader
- {
- int32_t DataLength = 0;
- uint8_t Nonce[NonceBytes] = {};
- } Header;
-
- if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
- {
- return 0;
- }
-
- // Validate DataLength to prevent OOM from malicious/corrupt peers
- static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
-
- if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
- {
- ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
- return 0;
- }
-
- // Receive ciphertext + tag
- const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
-
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
-
- if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
- {
- return 0;
- }
-
- // Decrypt
- const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size);
-
- // We need a temporary buffer for decryption if we can't decrypt directly into output
- std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
-
- const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
- if (Decrypted == 0)
- {
- return 0;
- }
-
- memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+ if (ToCopy == Size)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); });
+ return;
+ }
- // Store remaining data if we couldn't return everything
- if (static_cast<size_t>(Header.DataLength) > BytesToReturn)
- {
- m_RemainingOffset = 0;
- m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength);
+ DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler));
+ return;
}
- return BytesToReturn;
+ DoRecvMessage(Dest, Size, std::move(Handler));
}
void
-AesComputeTransport::MarkComplete()
+AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler)
{
- if (IsValid())
- {
- m_Inner->MarkComplete();
- }
+ static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes;
+ auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>();
+
+ m_Inner->AsyncRead(HeaderBuf->data(),
+ HeaderSize,
+ [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ int32_t DataLength = 0;
+ memcpy(&DataLength, HeaderBuf->data(), 4);
+
+ static constexpr int32_t MaxDataLength = 64 * 1024 * 1024;
+ if (DataLength <= 0 || DataLength > MaxDataLength)
+ {
+ Handler(asio::error::make_error_code(asio::error::invalid_argument), 0);
+ return;
+ }
+
+ const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes;
+ if (m_DecryptBuffer.size() < MessageLength)
+ {
+ m_DecryptBuffer.resize(MessageLength);
+ }
+
+ auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>();
+ memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes);
+
+ m_Inner->AsyncRead(
+ m_DecryptBuffer.data(),
+ MessageLength,
+ [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec,
+ size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength));
+ const int32_t Decrypted =
+ m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength);
+ if (Decrypted == 0)
+ {
+ Handler(asio::error::make_error_code(asio::error::connection_aborted), 0);
+ return;
+ }
+
+ const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size);
+ memcpy(Dest, PlaintextBuf.data(), BytesToReturn);
+
+ if (static_cast<size_t>(Decrypted) > BytesToReturn)
+ {
+ m_RemainingOffset = 0;
+ m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted);
+ }
+
+ if (BytesToReturn < Size)
+ {
+ DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler));
+ }
+ else
+ {
+ Handler(std::error_code{}, Size);
+ }
+ });
+ });
}
void
-AesComputeTransport::Close()
+AsyncAesComputeTransport::Close()
{
if (!m_IsClosed)
{
- if (m_Inner && m_Inner->IsValid())
+ // Always forward Close() to the inner transport if we have one. Gating on
+ // IsValid() skipped cleanup when the inner transport was partially torn down
+ // (e.g. after a read/write error marked it non-valid but left its socket open),
+ // leaking OS handles. Close implementations are expected to be idempotent.
+ if (m_Inner)
{
m_Inner->Close();
}
diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h
index efcad9835..7846073dc 100644
--- a/src/zenhorde/hordetransportaes.h
+++ b/src/zenhorde/hordetransportaes.h
@@ -6,47 +6,54 @@
#include <cstdint>
#include <memory>
-#include <mutex>
#include <vector>
+namespace asio {
+class io_context;
+}
+
namespace zen::horde {
-/** AES-256-GCM encrypted transport wrapper.
+/** Async AES-256-GCM encrypted transport wrapper.
*
- * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting
- * all incoming data using AES-256-GCM. The nonce is mutated per message using
- * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1].
+ * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming
+ * data using AES-256-GCM. Outgoing nonces follow the NIST SP 800-38D §8.2.1
+ * deterministic construction: a 4-byte fixed field followed by an 8-byte
+ * big-endian monotonic counter. The session is torn down if the counter
+ * would wrap.
*
* Wire format per encrypted message:
* [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)]
*
* Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time).
+ *
+ * Thread safety: all operations must be serialized by the caller (e.g. via a strand).
*/
-class AesComputeTransport final : public ComputeTransport
+class AsyncAesComputeTransport final : public AsyncComputeTransport
{
public:
- AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport);
- ~AesComputeTransport() override;
+ AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext);
+ ~AsyncAesComputeTransport() override;
- bool IsValid() const override;
- size_t Send(const void* Data, size_t Size) override;
- size_t Recv(void* Data, size_t Size) override;
- void MarkComplete() override;
- void Close() override;
+ bool IsValid() const override;
+ void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override;
+ void Close() override;
private:
- static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size
- static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size
+ void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler);
struct CryptoContext;
- std::unique_ptr<CryptoContext> m_Crypto;
- std::unique_ptr<ComputeTransport> m_Inner;
- std::vector<uint8_t> m_EncryptBuffer;
- std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv
- size_t m_RemainingOffset = 0;
- std::mutex m_Lock;
- bool m_IsClosed = false;
+ std::unique_ptr<CryptoContext> m_Crypto;
+ std::unique_ptr<AsyncComputeTransport> m_Inner;
+ asio::io_context& m_IoContext;
+ std::vector<uint8_t> m_DecryptBuffer;
+ std::vector<uint8_t> m_RemainingData;
+ size_t m_RemainingOffset = 0;
+ bool m_IsClosed = false;
};
} // namespace zen::horde
diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h
index 201d68b83..87caec019 100644
--- a/src/zenhorde/include/zenhorde/hordeclient.h
+++ b/src/zenhorde/include/zenhorde/hordeclient.h
@@ -45,14 +45,15 @@ struct MachineInfo
uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES)
bool IsWindows = false;
std::string LeaseId;
+ std::string Pool;
std::map<std::string, PortInfo> Ports;
/** Return the address to connect to, accounting for connection mode. */
- const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
+ [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; }
/** Return the port to connect to, accounting for connection mode and port mapping. */
- uint16_t GetConnectionPort() const
+ [[nodiscard]] uint16_t GetConnectionPort() const
{
if (Mode == ConnectionMode::Relay)
{
@@ -65,7 +66,20 @@ struct MachineInfo
return Port;
}
- bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
+ /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */
+ [[nodiscard]] std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const
+ {
+ if (Mode == ConnectionMode::Relay)
+ {
+ if (auto It = Ports.find("ZenPort"); It != Ports.end())
+ {
+ return {ConnectionAddress, It->second.Port};
+ }
+ }
+ return {Ip, DefaultPort};
+ }
+
+ [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
};
/** Result of cluster auto-resolution via the Horde API. */
@@ -83,31 +97,29 @@ struct ClusterInfo
class HordeClient
{
public:
- explicit HordeClient(const HordeConfig& Config);
+ explicit HordeClient(HordeConfig Config);
~HordeClient();
HordeClient(const HordeClient&) = delete;
HordeClient& operator=(const HordeClient&) = delete;
/** Initialize the underlying HTTP client. Must be called before other methods. */
- bool Initialize();
+ [[nodiscard]] bool Initialize();
/** Build the JSON request body for cluster resolution and machine requests.
* Encodes pool, condition, connection mode, encryption, and port requirements. */
- std::string BuildRequestBody() const;
+ [[nodiscard]] std::string BuildRequestBody() const;
/** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */
- bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
+ [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster);
/** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}.
* On success, populates OutMachine with connection details and credentials. */
- bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
+ [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine);
LoggerRef Log() { return m_Log; }
private:
- bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize);
-
HordeConfig m_Config;
std::unique_ptr<zen::HttpClient> m_Http;
LoggerRef m_Log;
diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h
index dd70f9832..3a4dfb386 100644
--- a/src/zenhorde/include/zenhorde/hordeconfig.h
+++ b/src/zenhorde/include/zenhorde/hordeconfig.h
@@ -4,6 +4,10 @@
#include <zenhorde/zenhorde.h>
+#include <zenhttp/httpclient.h>
+
+#include <functional>
+#include <optional>
#include <string>
namespace zen::horde {
@@ -33,20 +37,25 @@ struct HordeConfig
static constexpr const char* ClusterDefault = "default";
static constexpr const char* ClusterAuto = "_auto";
- bool Enabled = false; ///< Whether Horde provisioning is active
- std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com")
- std::string AuthToken; ///< Authentication token for the Horde API
- std::string Pool; ///< Pool name to request machines from
- std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
- std::string Condition; ///< Agent filter expression for machine selection
- std::string HostAddress; ///< Address that provisioned agents use to connect back to us
- std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload
- uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
-
- int MaxCores = 2048;
- bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
- ConnectionMode Mode = ConnectionMode::Direct;
- Encryption EncryptionMode = Encryption::None;
+ bool Enabled = false; ///< Whether Horde provisioning is active
+ std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com")
+ std::string AuthToken; ///< Authentication token for the Horde API (static fallback)
+
+ /// Optional token provider with automatic refresh (e.g. from OidcToken executable).
+ /// When set, takes priority over the static AuthToken string.
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+ std::string Pool; ///< Pool name to request machines from
+ std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve
+ std::string Condition; ///< Agent filter expression for machine selection
+ std::string HostAddress; ///< Address that provisioned agents use to connect back to us
+ std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload
+ uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication
+
+ int MaxCores = 2048;
+ int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min)
+ bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents
+ ConnectionMode Mode = ConnectionMode::Direct;
+ Encryption EncryptionMode = Encryption::None;
/** Validate the configuration. Returns false if the configuration is invalid
* (e.g. Relay mode without AES encryption). */
diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h
index 4e2e63bbd..ea2fd7783 100644
--- a/src/zenhorde/include/zenhorde/hordeprovisioner.h
+++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h
@@ -2,21 +2,32 @@
#pragma once
+#include <zenhorde/hordeclient.h>
#include <zenhorde/hordeconfig.h>
+#include <zencompute/provisionerstate.h>
#include <zencore/logbase.h>
+#include <zencore/thread.h>
#include <atomic>
#include <cstdint>
+#include <deque>
#include <filesystem>
#include <memory>
#include <mutex>
#include <string>
+#include <thread>
+#include <unordered_set>
#include <vector>
+namespace asio {
+class io_context;
+}
+
namespace zen::horde {
class HordeClient;
+class AsyncHordeAgent;
/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */
struct ProvisioningStats
@@ -35,13 +46,12 @@ struct ProvisioningStats
* binary, and executing it remotely. Each provisioned machine runs zenserver
* in compute mode, which announces itself back to the orchestrator.
*
- * Spawns one thread per agent. Each thread handles the full lifecycle:
- * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption ->
- * channel setup -> binary upload -> remote execution -> poll until exit.
+ * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread
+ * pool rather than spawning a dedicated thread per agent.
*
* Thread safety: SetTargetCoreCount and GetStats may be called from any thread.
*/
-class HordeProvisioner
+class HordeProvisioner : public compute::IProvisionerStateProvider
{
public:
/** Construct a provisioner.
@@ -52,38 +62,48 @@ public:
HordeProvisioner(const HordeConfig& Config,
const std::filesystem::path& BinariesPath,
const std::filesystem::path& WorkingDir,
- std::string_view OrchestratorEndpoint);
+ std::string_view OrchestratorEndpoint,
+ std::string_view CoordinatorSession = {},
+ bool CleanStart = false,
+ std::string_view TraceHost = {});
- /** Signals all agent threads to exit and joins them. */
- ~HordeProvisioner();
+ /** Signals all agents to exit and waits for completion. */
+ ~HordeProvisioner() override;
HordeProvisioner(const HordeProvisioner&) = delete;
HordeProvisioner& operator=(const HordeProvisioner&) = delete;
/** Set the target number of cores to provision.
- * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the
- * estimated core count is below the target. Also joins any finished
- * agent threads. */
- void SetTargetCoreCount(uint32_t Count);
+ * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the
+ * estimated core count is below the target. Also removes finished agents. */
+ void SetTargetCoreCount(uint32_t Count) override;
/** Return a snapshot of the current provisioning counters. */
ProvisioningStats GetStats() const;
- uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); }
- uint32_t GetAgentCount() const;
+ // IProvisionerStateProvider
+ std::string_view GetName() const override { return "horde"; }
+ uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); }
+ uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); }
+ uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); }
+ uint32_t GetAgentCount() const override;
+ uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); }
+ compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override;
private:
LoggerRef Log() { return m_Log; }
- struct AgentWrapper;
-
void RequestAgent();
- void ThreadAgent(AgentWrapper& Wrapper);
+ void ProvisionAgent();
+ bool InitializeHordeClient();
HordeConfig m_Config;
std::filesystem::path m_BinariesPath;
std::filesystem::path m_WorkingDir;
std::string m_OrchestratorEndpoint;
+ std::string m_CoordinatorSession;
+ bool m_CleanStart = false;
+ std::string m_TraceHost;
std::unique_ptr<HordeClient> m_HordeClient;
@@ -91,20 +111,54 @@ private:
std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs
bool m_BundlesCreated = false;
- mutable std::mutex m_AgentsLock;
- std::vector<std::unique_ptr<AgentWrapper>> m_Agents;
-
std::atomic<uint64_t> m_LastRequestFailTime{0};
std::atomic<uint32_t> m_TargetCoreCount{0};
std::atomic<uint32_t> m_EstimatedCoreCount{0};
std::atomic<uint32_t> m_ActiveCoreCount{0};
std::atomic<uint32_t> m_AgentsActive{0};
+ std::atomic<uint32_t> m_AgentsDraining{0};
std::atomic<uint32_t> m_AgentsRequesting{0};
std::atomic<bool> m_AskForAgents{true};
+ std::atomic<uint32_t> m_PendingWorkItems{0};
+ Event m_AllWorkDone;
+ /** Manual-reset event set alongside m_AskForAgents=false so pool-thread backoff waits
+ * wake immediately on shutdown instead of polling a 100ms sleep. */
+ Event m_ShutdownEvent;
LoggerRef m_Log;
+ // Async I/O
+ std::unique_ptr<asio::io_context> m_IoContext;
+ std::vector<std::thread> m_IoThreads;
+
+ struct AsyncAgentEntry
+ {
+ std::shared_ptr<AsyncHordeAgent> Agent;
+ std::string RemoteEndpoint;
+ std::string LeaseId;
+ uint16_t CoreCount = 0;
+ bool Draining = false;
+ };
+
+ mutable std::mutex m_AsyncAgentsLock;
+ std::vector<AsyncAgentEntry> m_AsyncAgents;
+
+ /** Worker IDs of agents that completed after draining.
+ * GetAgentStatus() consumes entries when queried, but if no one queries, entries would
+ * otherwise accumulate unbounded across the lifetime of the provisioner. Cap the set
+ * at RecentlyDrainedCapacity by evicting the oldest entry (tracked in an insertion-order
+ * queue) whenever we insert past the limit. */
+ mutable std::unordered_set<std::string> m_RecentlyDrainedWorkerIds;
+ mutable std::deque<std::string> m_RecentlyDrainedOrder;
+ static constexpr size_t RecentlyDrainedCapacity = 256;
+
+ void OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent);
+ void DrainAsyncAgent(AsyncAgentEntry& Entry);
+
+ std::vector<std::string> BuildAgentArgs(const MachineInfo& Machine) const;
+
static constexpr uint32_t EstimatedCoresPerAgent = 32;
+ static constexpr int IoThreadCount = 3;
};
} // namespace zen::horde
diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua
index 48d028e86..0e69e9c5f 100644
--- a/src/zenhorde/xmake.lua
+++ b/src/zenhorde/xmake.lua
@@ -14,7 +14,7 @@ target('zenhorde')
end
if is_plat("linux") or is_plat("macosx") then
- add_packages("openssl")
+ add_packages("openssl3")
end
if is_os("macosx") then