diff options
Diffstat (limited to 'src/zenhorde')
23 files changed, 2492 insertions, 2127 deletions
diff --git a/src/zenhorde/README.md b/src/zenhorde/README.md new file mode 100644 index 000000000..13beaa968 --- /dev/null +++ b/src/zenhorde/README.md @@ -0,0 +1,17 @@ +# Horde Compute integration + +Zen compute can use Horde to provision runner nodes. + +## Launch a coordinator instance + +Coordinator instances provision compute resources (runners) from a compute provider such as Horde, and surface an interface which allows zenserver instances to discover endpoints which they can submit actions to. + +```bash +zenserver compute --horde-enabled --horde-server=https://horde.dev.net:13340/ --horde-max-cores=512 --horde-zen-service-port=25000 --http=asio +``` + +## Use a coordinator + +```bash +zen exec beacon --path=e:\lyra-recording --orch=http://localhost:8558 +``` diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp index 819b2d0cb..029b98e55 100644 --- a/src/zenhorde/hordeagent.cpp +++ b/src/zenhorde/hordeagent.cpp @@ -8,290 +8,479 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <cstring> -#include <unordered_map> namespace zen::horde { -HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) -{ - ZEN_TRACE_CPU("HordeAgent::Connect"); +// --- AsyncHordeAgent --- - auto Transport = std::make_unique<TcpComputeTransport>(Info); - if (!Transport->IsValid()) +static const char* +GetStateName(AsyncHordeAgent::State S) +{ + switch (S) { - ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); - return; + case AsyncHordeAgent::State::Idle: + return "idle"; + case AsyncHordeAgent::State::Connecting: + return "connect"; + case AsyncHordeAgent::State::WaitAgentAttach: + return "agent-attach"; + case AsyncHordeAgent::State::SentFork: + return "fork"; + case AsyncHordeAgent::State::WaitChildAttach: + return "child-attach"; + case AsyncHordeAgent::State::Uploading: + return "upload"; + case AsyncHordeAgent::State::Executing: + return "execute"; + case AsyncHordeAgent::State::Polling: + return "poll"; + case AsyncHordeAgent::State::Done: + return "done"; + default: + return "unknown"; } +} - // The 64-byte nonce is always sent unencrypted as the first thing on the wire. - // The Horde agent uses this to identify which lease this connection belongs to. - Transport->Send(Info.Nonce, sizeof(Info.Nonce)); +AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async")) +{ +} - std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); - if (Info.EncryptionMode == Encryption::AES) +AsyncHordeAgent::~AsyncHordeAgent() +{ + Cancel(); +} + +void +AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone) +{ + m_Config = std::move(Config); + m_OnDone = std::move(OnDone); + m_State = State::Connecting; + DoConnect(); +} + +void +AsyncHordeAgent::Cancel() +{ + m_Cancelled = true; + if (m_Socket) { - FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); - if (!FinalTransport->IsValid()) - { - ZEN_WARN("failed to create AES transport"); - return; - } + m_Socket->Close(); + } + else if (m_TcpTransport) + { + // Cancelled before handshake completed - tear down the pending TCP connect. + m_TcpTransport->Close(); } +} - // Create multiplexed socket and channels - m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); +void +AsyncHordeAgent::DoConnect() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect"); - // Channel 0 is the agent control channel (handles Attach/Fork handshake). - // Channel 100 is the child I/O channel (handles file upload and remote execution). - Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); - Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext); + + auto Self = shared_from_this(); + m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); }); +} - if (!AgentComputeChannel || !ChildComputeChannel) +void +AsyncHordeAgent::OnConnected(const std::error_code& Ec) +{ + if (Ec || m_Cancelled) { - ZEN_WARN("failed to create compute channels"); + if (Ec) + { + ZEN_WARN("connect failed: {}", Ec.message()); + } + Finish(false); return; } - m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); - m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + // Optionally wrap with AES encryption + std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport); + if (m_Config.Machine.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext); + } + + // Create the multiplexed socket and register channels. Ownership of the transport + // moves into the socket here - no need to retain a separate m_Transport field. + m_Socket = std::make_shared<AsyncComputeSocket>(std::move(FinalTransport), m_IoContext); + + m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext); + m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext); + + m_Socket->RegisterChannel( + 0, + [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); }, + [this]() { m_AgentChannel->OnDetach(); }); - m_IsValid = true; + m_Socket->RegisterChannel( + 100, + [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); }, + [this]() { m_ChildChannel->OnDetach(); }); + + m_Socket->StartRecvPump(); + + m_State = State::WaitAgentAttach; + DoWaitAgentAttach(); } -HordeAgent::~HordeAgent() +void +AsyncHordeAgent::DoWaitAgentAttach() { - CloseConnection(); + auto Self = shared_from_this(); + m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnAgentResponse(Type, Data, Size); + }); } -bool -HordeAgent::BeginCommunication() +void +AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) { - ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); - - if (!m_IsValid) + if (m_Cancelled) { - return false; + Finish(false); + return; } - // Start the send/recv pump threads - m_Socket->StartCommunication(); - - // Wait for Attach on agent channel - AgentMessageType Type = m_AgentChannel->ReadResponse(5000); if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on agent channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - // Fork tells the remote agent to create child channel 100 with a 4MB buffer. - // After this, the agent will send an Attach on the child channel. + m_State = State::SentFork; + DoSendFork(); +} + +void +AsyncHordeAgent::DoSendFork() +{ m_AgentChannel->Fork(100, 4 * 1024 * 1024); - // Wait for Attach on child channel - Type = m_ChildChannel->ReadResponse(5000); + m_State = State::WaitChildAttach; + DoWaitChildAttach(); +} + +void +AsyncHordeAgent::DoWaitChildAttach() +{ + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnChildAttachResponse(Type, Data, Size); + }); +} + +void +AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) +{ + if (m_Cancelled) + { + Finish(false); + return; + } + if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on child channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - return true; + m_State = State::Uploading; + m_CurrentBundleIndex = 0; + DoUploadNext(); } -bool -HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +void +AsyncHordeAgent::DoUploadNext() { - ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); + if (m_Cancelled) + { + Finish(false); + return; + } + + if (m_CurrentBundleIndex >= m_Config.Bundles.size()) + { + // All bundles uploaded - proceed to execute + m_State = State::Executing; + DoExecute(); + return; + } - m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + m_ChildChannel->UploadFiles("", Locator.c_str()); - std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + // Enter the ReadBlob/Blob upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { - std::string Key(Locator); +void +AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) +{ + if (m_Cancelled) + { + Finish(false); + return; + } - if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) + if (Type == AgentMessageType::None) + { + if (m_ChildChannel->IsDetached()) { - return It->second.get(); + ZEN_WARN("connection lost during upload"); + Finish(false); + return; } + // Timeout - retry read + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); + return; + } - const std::filesystem::path Path = BundleDir / (Key + ".blob"); - std::error_code Ec; - auto File = std::make_unique<BasicFile>(); - File->Open(Path, BasicFile::Mode::kRead, Ec); + if (Type == AgentMessageType::WriteFilesResponse) + { + // This bundle upload is done - move to next + ++m_CurrentBundleIndex; + DoUploadNext(); + return; + } - if (Ec) + if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + if (!AsyncAgentMessageChannel::ReadException(Data, Size, Ex)) { - ZEN_ERROR("cannot read blob file: '{}'", Path); - return nullptr; + ZEN_ERROR("malformed Exception message during upload (size={})", Size); + Finish(false); + return; } + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + return; + } - BasicFile* Ptr = File.get(); - BlobFiles.emplace(std::move(Key), std::move(File)); - return Ptr; - }; - - // The upload protocol is request-driven: we send WriteFiles, then the remote agent - // sends ReadBlob requests for each blob it needs. We respond with Blob data until - // the agent sends WriteFilesResponse indicating the upload is complete. - constexpr int32_t ReadResponseTimeoutMs = 1000; + if (Type != AgentMessageType::ReadBlob) + { + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + Finish(false); + return; + } - for (;;) + // Handle ReadBlob request + BlobRequest Req; + if (!AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req)) { - bool TimedOut = false; + ZEN_ERROR("malformed ReadBlob message during upload (size={})", Size); + Finish(false); + return; + } - if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) - { - if (TimedOut) - { - continue; - } - // End of stream - check if it was a successful upload - if (Type == AgentMessageType::WriteFilesResponse) - { - return true; - } - else if (Type == AgentMessageType::Exception) - { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); - } - else - { - ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); - } - return false; - } + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob"); - BlobRequest Req; - m_ChildChannel->ReadBlobRequest(Req); + std::error_code FsEc; + BasicFile File; + File.Open(BlobPath, BasicFile::Mode::kRead, FsEc); - BasicFile* File = FindOrOpenBlob(Req.Locator); - if (!File) - { - return false; - } + if (FsEc) + { + ZEN_ERROR("cannot read blob file: '{}'", BlobPath); + Finish(false); + return; + } - // Read from offset to end of file - const uint64_t TotalSize = File->FileSize(); - const uint64_t Offset = static_cast<uint64_t>(Req.Offset); - if (Offset >= TotalSize) - { - ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); - m_ChildChannel->Blob(nullptr, 0); - continue; - } + const uint64_t TotalSize = File.FileSize(); + const uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + } + else + { + const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize()); + } + + // Continue the upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); - m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); +void +AsyncHordeAgent::DoExecute() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute"); + + std::vector<const char*> ArgPtrs; + ArgPtrs.reserve(m_Config.Args.size()); + for (const std::string& Arg : m_Config.Args) + { + ArgPtrs.push_back(Arg.c_str()); } + + m_ChildChannel->Execute(m_Config.Executable.c_str(), + ArgPtrs.data(), + ArgPtrs.size(), + nullptr, + nullptr, + 0, + m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + m_Config.Machine.GetConnectionAddress(), + m_Config.Machine.GetConnectionPort(), + m_Config.Machine.LeaseId); + + m_State = State::Polling; + DoPoll(); } void -HordeAgent::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - bool UseWine) +AsyncHordeAgent::DoPoll() { - ZEN_TRACE_CPU("HordeAgent::Execute"); - m_ChildChannel - ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + if (m_Cancelled) + { + Finish(false); + return; + } + + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnPollResponse(Type, Data, Size); + }); } -bool -HordeAgent::Poll(bool LogOutput) +void +AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) { - constexpr int32_t ReadResponseTimeoutMs = 100; - AgentMessageType Type; + if (m_Cancelled) + { + Finish(false); + return; + } - while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + switch (Type) { - switch (Type) - { - case AgentMessageType::ExecuteOutput: + case AgentMessageType::None: + if (m_ChildChannel->IsDetached()) + { + ZEN_WARN("connection lost during execution"); + Finish(false); + } + else + { + // Timeout - poll again + DoPoll(); + } + break; + + case AgentMessageType::ExecuteOutput: + // Silently consume remote stdout (matching LogOutput=false in provisioner) + DoPoll(); + break; + + case AgentMessageType::ExecuteResult: + { + int32_t ExitCode = -1; + if (!AsyncAgentMessageChannel::ReadExecuteResult(Data, Size, ExitCode)) { - if (LogOutput && m_ChildChannel->GetResponseSize() > 0) - { - const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); - size_t ResponseSize = m_ChildChannel->GetResponseSize(); - - // Trim trailing newlines - while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) - { - --ResponseSize; - } - - if (ResponseSize > 0) - { - const std::string_view Output(ResponseData, ResponseSize); - ZEN_INFO("[remote] {}", Output); - } - } + // A remote with a malformed ExecuteResult cannot be trusted to report + // process outcome - treat as a protocol error and tear down rather than + // silently recording \"exited with -1\". + ZEN_ERROR("malformed ExecuteResult (size={}, lease={}) - disconnecting", Size, m_Config.Machine.LeaseId); + Finish(false); break; } + ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId); + Finish(ExitCode == 0, ExitCode); + } + break; - case AgentMessageType::ExecuteResult: + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + if (AsyncAgentMessageChannel::ReadException(Data, Size, Ex)) { - if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) - { - int32_t ExitCode; - memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); - ZEN_INFO("remote process exited with code {}", ExitCode); - } - m_IsValid = false; - return false; + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); } - - case AgentMessageType::Exception: + else { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); - m_HasErrors = true; - break; + ZEN_ERROR("malformed Exception message (size={})", Size); } + Finish(false); + } + break; - default: - break; - } + default: + DoPoll(); + break; } - - return m_IsValid && !m_HasErrors; } void -HordeAgent::CloseConnection() +AsyncHordeAgent::Finish(bool Success, int32_t ExitCode) { - if (m_ChildChannel) + if (m_State == State::Done) { - m_ChildChannel->Close(); + return; // Already finished } - if (m_AgentChannel) + + if (!Success) { - m_AgentChannel->Close(); + ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId); } -} -bool -HordeAgent::IsValid() const -{ - return m_IsValid && !m_HasErrors; + m_State = State::Done; + + if (m_Socket) + { + m_Socket->Close(); + } + + if (m_OnDone) + { + AsyncAgentResult Result; + Result.Success = Success; + Result.ExitCode = ExitCode; + Result.CoreCount = m_Config.Machine.LogicalCores; + + auto Handler = std::move(m_OnDone); + m_OnDone = nullptr; + Handler(Result); + } } } // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h index e0ae89ead..a5b3248ab 100644 --- a/src/zenhorde/hordeagent.h +++ b/src/zenhorde/hordeagent.h @@ -10,68 +10,107 @@ #include <zencore/logbase.h> #include <filesystem> +#include <functional> #include <memory> #include <string> +#include <vector> + +namespace asio { +class io_context; +} namespace zen::horde { -/** Manages the lifecycle of a single Horde compute agent. +class AsyncComputeTransport; + +/** Result passed to the completion handler when an async agent finishes. */ +struct AsyncAgentResult +{ + bool Success = false; + int32_t ExitCode = -1; + uint16_t CoreCount = 0; ///< Logical cores on the provisioned machine +}; + +/** Completion handler for async agent lifecycle. */ +using AsyncAgentCompletionHandler = std::function<void(const AsyncAgentResult&)>; + +/** Configuration for launching a remote zenserver instance via an async agent. */ +struct AsyncAgentConfig +{ + MachineInfo Machine; + std::vector<std::pair<std::string, std::filesystem::path>> Bundles; ///< (locator, bundleDir) pairs + std::string Executable; + std::vector<std::string> Args; + bool UseWine = false; +}; + +/** Async agent that manages the full lifecycle of a single Horde compute connection. * - * Handles the full connection sequence for one provisioned machine: - * 1. Connect via TCP transport (with optional AES encryption wrapping) - * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) - * 3. Perform the Attach/Fork handshake to establish the child channel - * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol - * 5. Execute zenserver remotely via ExecuteV2 - * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + * Driven by a state machine using callbacks on a shared io_context - no dedicated + * threads. Call Start() to begin the connection/handshake/upload/execute/poll + * sequence. The completion handler is invoked when the remote process exits or + * an error occurs. */ -class HordeAgent +class AsyncHordeAgent : public std::enable_shared_from_this<AsyncHordeAgent> { public: - explicit HordeAgent(const MachineInfo& Info); - ~HordeAgent(); + AsyncHordeAgent(asio::io_context& IoContext); + ~AsyncHordeAgent(); - HordeAgent(const HordeAgent&) = delete; - HordeAgent& operator=(const HordeAgent&) = delete; + AsyncHordeAgent(const AsyncHordeAgent&) = delete; + AsyncHordeAgent& operator=(const AsyncHordeAgent&) = delete; - /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). - * Returns false if the handshake times out or receives an unexpected message. */ - bool BeginCommunication(); + /** Start the full agent lifecycle. The completion handler is called exactly once. */ + void Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone); - /** Upload binary files to the remote agent. - * @param BundleDir Directory containing .blob files. - * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ - bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + /** Cancel in-flight operations. The completion handler is still called (with Success=false). */ + void Cancel(); - /** Execute a command on the remote machine. */ - void Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir = nullptr, - const char* const* EnvVars = nullptr, - size_t NumEnvVars = 0, - bool UseWine = false); + const MachineInfo& GetMachineInfo() const { return m_Config.Machine; } - /** Poll for output and results. Returns true if the agent is still running. - * When LogOutput is true, remote stdout is logged via ZEN_INFO. */ - bool Poll(bool LogOutput = true); - - void CloseConnection(); - bool IsValid() const; - - const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + enum class State + { + Idle, + Connecting, + WaitAgentAttach, + SentFork, + WaitChildAttach, + Uploading, + Executing, + Polling, + Done + }; private: LoggerRef Log() { return m_Log; } - std::unique_ptr<ComputeSocket> m_Socket; - std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control - std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O - - LoggerRef m_Log; - bool m_IsValid = false; - bool m_HasErrors = false; - MachineInfo m_MachineInfo; + void DoConnect(); + void OnConnected(const std::error_code& Ec); + void DoWaitAgentAttach(); + void OnAgentResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoSendFork(); + void DoWaitChildAttach(); + void OnChildAttachResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoUploadNext(); + void OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void DoExecute(); + void DoPoll(); + void OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size); + void Finish(bool Success, int32_t ExitCode = -1); + + asio::io_context& m_IoContext; + LoggerRef m_Log; + State m_State = State::Idle; + bool m_Cancelled = false; + + AsyncAgentConfig m_Config; + AsyncAgentCompletionHandler m_OnDone; + size_t m_CurrentBundleIndex = 0; + + std::unique_ptr<AsyncTcpComputeTransport> m_TcpTransport; + std::shared_ptr<AsyncComputeSocket> m_Socket; + std::unique_ptr<AsyncAgentMessageChannel> m_AgentChannel; + std::unique_ptr<AsyncAgentMessageChannel> m_ChildChannel; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp index 998134a96..bef1bdda8 100644 --- a/src/zenhorde/hordeagentmessage.cpp +++ b/src/zenhorde/hordeagentmessage.cpp @@ -4,337 +4,496 @@ #include <zencore/intmath.h> -#include <cassert> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <zencore/except_fmt.h> +#include <zencore/logging.h> + #include <cstring> +#include <limits> namespace zen::horde { -AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) -{ -} - -AgentMessageChannel::~AgentMessageChannel() = default; +// --- AsyncAgentMessageChannel --- -void -AgentMessageChannel::Close() +AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext) +: m_Socket(std::move(Socket)) +, m_ChannelId(ChannelId) +, m_IoContext(IoContext) +, m_TimeoutTimer(std::make_unique<asio::steady_timer>(m_Socket->GetStrand())) { - CreateMessage(AgentMessageType::None, 0); - FlushMessage(); } -void -AgentMessageChannel::Ping() +AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { - CreateMessage(AgentMessageType::Ping, 0); - FlushMessage(); + if (m_TimeoutTimer) + { + m_TimeoutTimer->cancel(); + } } -void -AgentMessageChannel::Fork(int ChannelId, int BufferSize) +// --- Message building helpers --- + +std::vector<uint8_t> +AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { - CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); - WriteInt32(ChannelId); - WriteInt32(BufferSize); - FlushMessage(); + std::vector<uint8_t> Buf; + Buf.reserve(MessageHeaderLength + ReservePayload); + Buf.push_back(static_cast<uint8_t>(Type)); + Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder + return Buf; } void -AgentMessageChannel::Attach() +AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg) { - CreateMessage(AgentMessageType::Attach, 0); - FlushMessage(); + const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength); + memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); + m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void -AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value) { - CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); - WriteString(Path); - WriteString(Locator); - FlushMessage(); + const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value); + Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } -void -AgentMessageChannel::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - ExecuteProcessFlags Flags) +int +AsyncAgentMessageChannel::ReadInt32(ReadCursor& C) { - size_t RequiredSize = 50 + strlen(Exe); - for (size_t i = 0; i < NumArgs; ++i) - { - RequiredSize += strlen(Args[i]) + 10; - } - if (WorkingDir) - { - RequiredSize += strlen(WorkingDir) + 10; - } - for (size_t i = 0; i < NumEnvVars; ++i) + if (!C.CheckAvailable(sizeof(int32_t))) { - RequiredSize += strlen(EnvVars[i]) + 20; + return 0; } + int32_t Value; + memcpy(&Value, C.Pos, sizeof(int32_t)); + C.Pos += sizeof(int32_t); + return Value; +} - CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); - WriteString(Exe); +void +AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length) +{ + Buf.insert(Buf.end(), Data, Data + Length); +} - WriteUnsignedVarInt(NumArgs); - for (size_t i = 0; i < NumArgs; ++i) +const uint8_t* +AsyncAgentMessageChannel::ReadFixedLengthBytes(ReadCursor& C, size_t Length) +{ + if (!C.CheckAvailable(Length)) { - WriteString(Args[i]); + return nullptr; } + const uint8_t* Data = C.Pos; + C.Pos += Length; + return Data; +} - WriteOptionalString(WorkingDir); - - // ExecuteV2 protocol requires env vars as separate key/value pairs. - // Callers pass "KEY=VALUE" strings; we split on the first '=' here. - WriteUnsignedVarInt(NumEnvVars); - for (size_t i = 0; i < NumEnvVars; ++i) +size_t +AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +{ + if (Value == 0) { - const char* Eq = strchr(EnvVars[i], '='); - assert(Eq != nullptr); - - WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); - if (*(Eq + 1) == '\0') - { - WriteOptionalString(nullptr); - } - else - { - WriteOptionalString(Eq + 1); - } + return 1; } - - WriteInt32(static_cast<int>(Flags)); - FlushMessage(); + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } void -AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value) { - // Blob responses are chunked to fit within the compute buffer's chunk size. - // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). - const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; - for (size_t ChunkOffset = 0; ChunkOffset < Length;) - { - const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); - - CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); - WriteInt32(static_cast<int>(ChunkOffset)); - WriteInt32(static_cast<int>(Length)); - WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); - FlushMessage(); + const size_t ByteCount = MeasureUnsignedVarInt(Value); + const size_t StartPos = Buf.size(); + Buf.resize(StartPos + ByteCount); - ChunkOffset += ChunkLength; + uint8_t* Output = Buf.data() + StartPos; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); } -AgentMessageType -AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +size_t +AsyncAgentMessageChannel::ReadUnsignedVarInt(ReadCursor& C) { - // Deferred advance: the previous response's buffer is only released when the next - // ReadResponse is called. This allows callers to read response data between calls - // without copying, since the pointer comes directly from the ring buffer. - if (m_ResponseData) + // Need at least the leading byte to determine the encoded length. + if (!C.CheckAvailable(1)) { - m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); - m_ResponseData = nullptr; - m_ResponseLength = 0; + return 0; } - const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); - if (!Header) + const uint8_t FirstByte = C.Pos[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + + // The encoded length implied by the leading 0xFF-run may be 1..9 bytes; ensure the remaining bytes are in-bounds. + if (!C.CheckAvailable(NumBytes)) { - return AgentMessageType::None; + return 0; } - uint32_t Length; - memcpy(&Length, Header + 1, sizeof(uint32_t)); - - Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); - if (!Header) + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) { - return AgentMessageType::None; + Value <<= 8; + Value |= C.Pos[i]; } - m_ResponseType = static_cast<AgentMessageType>(Header[0]); - m_ResponseData = Header + MessageHeaderLength; - m_ResponseLength = Length; - - return m_ResponseType; + C.Pos += NumBytes; + return Value; } void -AgentMessageChannel::ReadException(ExceptionInfo& Ex) +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text) { - assert(m_ResponseType == AgentMessageType::Exception); - const uint8_t* Pos = m_ResponseData; - Ex.Message = ReadString(&Pos); - Ex.Description = ReadString(&Pos); + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); } -int -AgentMessageChannel::ReadExecuteResult() +void +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text) { - assert(m_ResponseType == AgentMessageType::ExecuteResult); - const uint8_t* Pos = m_ResponseData; - return ReadInt32(&Pos); + WriteUnsignedVarInt(Buf, Text.size()); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); } -void -AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +std::string_view +AsyncAgentMessageChannel::ReadString(ReadCursor& C) { - assert(m_ResponseType == AgentMessageType::ReadBlob); - const uint8_t* Pos = m_ResponseData; - Req.Locator = ReadString(&Pos); - Req.Offset = ReadUnsignedVarInt(&Pos); - Req.Length = ReadUnsignedVarInt(&Pos); + const size_t Length = ReadUnsignedVarInt(C); + const uint8_t* Start = ReadFixedLengthBytes(C, Length); + if (C.ParseError || !Start) + { + return {}; + } + return std::string_view(reinterpret_cast<const char*>(Start), Length); } void -AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text) { - m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); - m_RequestData[0] = static_cast<uint8_t>(Type); - m_MaxRequestSize = MaxLength; - m_RequestSize = 0; + if (!Text) + { + WriteUnsignedVarInt(Buf, 0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length + 1); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); + } } +// --- Send methods --- + void -AgentMessageChannel::FlushMessage() +AsyncAgentMessageChannel::Close() { - const uint32_t Size = static_cast<uint32_t>(m_RequestSize); - memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); - m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); - m_RequestSize = 0; - m_MaxRequestSize = 0; - m_RequestData = nullptr; + auto Msg = BeginMessage(AgentMessageType::None, 0); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteInt32(int Value) +AsyncAgentMessageChannel::Ping() { - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); + auto Msg = BeginMessage(AgentMessageType::Ping, 0); + FinalizeAndSend(std::move(Msg)); } -int -AgentMessageChannel::ReadInt32(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { - int Value; - memcpy(&Value, *Pos, sizeof(int)); - *Pos += sizeof(int); - return Value; + auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(Msg, ChannelId); + WriteInt32(Msg, BufferSize); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::Attach() { - assert(m_RequestSize + Length <= m_MaxRequestSize); - memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); - m_RequestSize += Length; + auto Msg = BeginMessage(AgentMessageType::Attach, 0); + FinalizeAndSend(std::move(Msg)); } -const uint8_t* -AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +void +AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { - const uint8_t* Data = *Pos; - *Pos += Length; - return Data; + auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Msg, Path); + WriteString(Msg, Locator); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +void +AsyncAgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) { - if (Value == 0) + size_t ReserveSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) { - return 1; + ReserveSize += strlen(Args[i]) + 10; + } + if (WorkingDir) + { + ReserveSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + ReserveSize += strlen(EnvVars[i]) + 20; } - return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; -} -void -AgentMessageChannel::WriteUnsignedVarInt(size_t Value) -{ - const size_t ByteCount = MeasureUnsignedVarInt(Value); - assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); + WriteString(Msg, Exe); - uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; - for (size_t i = 1; i < ByteCount; ++i) + WriteUnsignedVarInt(Msg, NumArgs); + for (size_t i = 0; i < NumArgs; ++i) { - Output[ByteCount - i] = static_cast<uint8_t>(Value); - Value >>= 8; + WriteString(Msg, Args[i]); } - Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); - m_RequestSize += ByteCount; + WriteOptionalString(Msg, WorkingDir); + + WriteUnsignedVarInt(Msg, NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + if (Eq == nullptr) + { + // assert() would be compiled out in release and leave *(Eq+1) as UB - + // refuse to build the message for a malformed KEY=VALUE string instead. + throw zen::runtime_error("horde agent env var at index {} missing '=' separator", i); + } + + WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(Msg, nullptr); + } + else + { + WriteOptionalString(Msg, Eq + 1); + } + } + + WriteInt32(Msg, static_cast<int>(Flags)); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { - const uint8_t* Data = *Pos; - const uint8_t FirstByte = Data[0]; - const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + static constexpr size_t MaxBlobChunkSize = 512 * 1024; - size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); - for (size_t i = 1; i < NumBytes; ++i) + // The Horde ReadBlobResponse wire format encodes both the chunk Offset and the total + // Length as int32. Lengths of 2 GiB or more would wrap to negative and confuse the + // remote parser. Refuse the send rather than produce a protocol violation. + if (Length > static_cast<size_t>(std::numeric_limits<int32_t>::max())) { - Value <<= 8; - Value |= Data[i]; + throw zen::runtime_error("horde ReadBlobResponse length {} exceeds int32 wire limit", Length); } - *Pos += NumBytes; - return Value; + for (size_t ChunkOffset = 0; ChunkOffset < Length;) + { + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); + + auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(Msg, static_cast<int32_t>(ChunkOffset)); + WriteInt32(Msg, static_cast<int32_t>(Length)); + WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); + FinalizeAndSend(std::move(Msg)); + + ChunkOffset += ChunkLength; + } } -size_t -AgentMessageChannel::MeasureString(const char* Text) const +// --- Async response reading --- + +void +AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { - const size_t Length = strlen(Text); - return MeasureUnsignedVarInt(Length) + Length; + // Serialize all access to m_IncomingFrames / m_PendingHandler / m_TimeoutTimer onto + // the socket's strand; OnFrame/OnDetach also run on that strand. Without this, the + // timer wait completion would run on a bare io_context thread (3 concurrent run() + // loops in the provisioner) and race with OnFrame on m_PendingHandler. + asio::dispatch(m_Socket->GetStrand(), [this, TimeoutMs, Handler = std::move(Handler)]() mutable { + if (!m_IncomingFrames.empty()) + { + std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front()); + m_IncomingFrames.pop_front(); + + if (Frame.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]); + const uint8_t* Data = Frame.data() + MessageHeaderLength; + size_t Size = Frame.size() - MessageHeaderLength; + asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { + // The Frame is captured to keep Data pointer valid + Handler(Type, Data, Size); + }); + } + else + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + } + return; + } + + if (m_Detached) + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + return; + } + + // No frames queued - store pending handler and arm timeout + m_PendingHandler = std::move(Handler); + + if (TimeoutMs >= 0) + { + m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); + m_TimeoutTimer->async_wait(asio::bind_executor(m_Socket->GetStrand(), [this](const asio::error_code& Ec) { + if (Ec) + { + return; // Cancelled - frame arrived before timeout + } + + // Already on the strand: safe to mutate m_PendingHandler. + if (m_PendingHandler) + { + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } + })); + } + }); } void -AgentMessageChannel::WriteString(const char* Text) +AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data) { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + if (m_PendingHandler) + { + // Cancel the timeout timer + m_TimeoutTimer->cancel(); + + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + + if (Data.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Data[0]); + const uint8_t* Payload = Data.data() + MessageHeaderLength; + size_t PayloadSize = Data.size() - MessageHeaderLength; + Handler(Type, Payload, PayloadSize); + } + else + { + Handler(AgentMessageType::None, nullptr, 0); + } + } + else + { + m_IncomingFrames.push_back(std::move(Data)); + } } void -AgentMessageChannel::WriteString(std::string_view Text) +AsyncAgentMessageChannel::OnDetach() { - WriteUnsignedVarInt(Text.size()); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + m_Detached = true; + + if (m_PendingHandler) + { + m_TimeoutTimer->cancel(); + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } } -std::string_view -AgentMessageChannel::ReadString(const uint8_t** Pos) +// --- Response parsing helpers --- + +bool +AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex) { - const size_t Length = ReadUnsignedVarInt(Pos); - const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); - return std::string_view(Start, Length); + ReadCursor C{Data, Data + Size, false}; + Ex.Message = ReadString(C); + Ex.Description = ReadString(C); + if (C.ParseError) + { + Ex = {}; + return false; + } + return true; } -void -AgentMessageChannel::WriteOptionalString(const char* Text) +bool +AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode) { - // Optional strings use length+1 encoding: 0 means null/absent, - // N>0 means a string of length N-1 follows. This matches the UE - // FAgentMessageChannel serialization convention. - if (!Text) + ReadCursor C{Data, Data + Size, false}; + OutExitCode = ReadInt32(C); + return !C.ParseError; +} + +static bool +IsSafeLocator(std::string_view Locator) +{ + // Reject empty, overlong, path-separator-containing, parent-relative, absolute, or + // control-character-containing locators. The locator is used as a filename component + // joined with a trusted BundleDir, so the only safe characters are a restricted + // filename alphabet. + if (Locator.empty() || Locator.size() > 255) { - WriteUnsignedVarInt(0); + return false; } - else + if (Locator == "." || Locator == "..") { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length + 1); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + return false; + } + for (char Ch : Locator) + { + const unsigned char U = static_cast<unsigned char>(Ch); + if (U < 0x20 || U == 0x7F) + { + return false; // control / NUL / DEL + } + if (Ch == '/' || Ch == '\\' || Ch == ':') + { + return false; // path separators / drive letters + } + } + // Disallow leading/trailing dot or whitespace (Windows quirks + hidden-file dodges) + if (Locator.front() == '.' || Locator.front() == ' ' || Locator.back() == '.' || Locator.back() == ' ') + { + return false; + } + return true; +} + +bool +AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req) +{ + ReadCursor C{Data, Data + Size, false}; + Req.Locator = ReadString(C); + Req.Offset = ReadUnsignedVarInt(C); + Req.Length = ReadUnsignedVarInt(C); + if (C.ParseError || !IsSafeLocator(Req.Locator)) + { + Req = {}; + return false; } + return true; } } // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h index 38c4375fd..fb7c5ed29 100644 --- a/src/zenhorde/hordeagentmessage.h +++ b/src/zenhorde/hordeagentmessage.h @@ -4,14 +4,22 @@ #include <zenbase/zenbase.h> -#include "hordecomputechannel.h" +#include "hordecomputesocket.h" #include <cstddef> #include <cstdint> +#include <deque> +#include <functional> +#include <memory> #include <string> #include <string_view> +#include <system_error> #include <vector> +namespace asio { +class io_context; +} // namespace asio + namespace zen::horde { /** Agent message types matching the UE EAgentMessageType byte values. @@ -55,45 +63,34 @@ struct BlobRequest size_t Length = 0; }; -/** Channel for sending and receiving agent messages over a ComputeChannel. +/** Handler for async response reads. Receives the message type and a view of the payload data. + * The payload vector is valid until the next AsyncReadResponse call. */ +using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>; + +/** Async channel for sending and receiving agent messages over an AsyncComputeSocket. * - * Implements the Horde agent message protocol, matching the UE - * FAgentMessageChannel serialization format exactly. Messages are framed as - * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8; - * integers use variable-length encoding. + * Send methods build messages into vectors and submit them via AsyncComputeSocket. + * Receives are delivered via the socket's FrameHandler callback and queued internally. + * AsyncReadResponse checks the queue and invokes the handler, with optional timeout. * - * The protocol has two directions: - * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob - * - Responses (remote -> initiator): ReadResponse returns the type, then call the - * appropriate Read* method to parse the payload. + * All operations must be externally serialized (e.g. via the socket's strand). */ -class AgentMessageChannel +class AsyncAgentMessageChannel { public: - explicit AgentMessageChannel(Ref<ComputeChannel> Channel); - ~AgentMessageChannel(); + AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext); + ~AsyncAgentMessageChannel(); - AgentMessageChannel(const AgentMessageChannel&) = delete; - AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete; + AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete; - // --- Requests (Initiator -> Remote) --- + // --- Requests (fire-and-forget sends) --- - /** Close the channel. */ void Close(); - - /** Send a keepalive ping. */ void Ping(); - - /** Fork communication to a new channel with the given ID and buffer size. */ void Fork(int ChannelId, int BufferSize); - - /** Send an attach request (used during channel setup handshake). */ void Attach(); - - /** Request the remote agent to write files from the given bundle locator. */ void UploadFiles(const char* Path, const char* Locator); - - /** Execute a process on the remote machine. */ void Execute(const char* Exe, const char* const* Args, size_t NumArgs, @@ -101,61 +98,85 @@ public: const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags = ExecuteProcessFlags::None); - - /** Send blob data in response to a ReadBlob request. */ void Blob(const uint8_t* Data, size_t Length); - // --- Responses (Remote -> Initiator) --- - - /** Read the next response message. Returns the message type, or None on timeout. - * After this returns, use GetResponseData()/GetResponseSize() or the typed - * Read* methods to access the payload. */ - AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + // --- Async response reading --- - const void* GetResponseData() const { return m_ResponseData; } - size_t GetResponseSize() const { return m_ResponseLength; } + /** Read the next response. If a frame is already queued, the handler is posted immediately. + * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler + * with AgentMessageType::None. */ + void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler); - /** Parse an Exception response payload. */ - void ReadException(ExceptionInfo& Ex); + /** Called by the socket's FrameHandler when a frame arrives for this channel. */ + void OnFrame(std::vector<uint8_t> Data); - /** Parse an ExecuteResult response payload. Returns the exit code. */ - int ReadExecuteResult(); + /** Called by the socket's DetachHandler. */ + void OnDetach(); - /** Parse a ReadBlob response payload into a BlobRequest. */ - void ReadBlobRequest(BlobRequest& Req); + /** Returns true if the channel has been detached (connection lost). */ + bool IsDetached() const { return m_Detached; } -private: - static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + // --- Response parsing helpers --- - Ref<ComputeChannel> m_Channel; + /** Parse an Exception message payload. Returns false on malformed/truncated input. */ + [[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex); - uint8_t* m_RequestData = nullptr; - size_t m_RequestSize = 0; - size_t m_MaxRequestSize = 0; + /** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. */ + [[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode); - AgentMessageType m_ResponseType = AgentMessageType::None; - const uint8_t* m_ResponseData = nullptr; - size_t m_ResponseLength = 0; + /** Parse a ReadBlob message payload. Returns false on malformed/truncated input or + * if the Locator contains characters that would not be safe to use as a path component. */ + [[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req); - void CreateMessage(AgentMessageType Type, size_t MaxLength); - void FlushMessage(); +private: + static constexpr size_t MessageHeaderLength = 5; + + // Message building helpers + std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload); + void FinalizeAndSend(std::vector<uint8_t> Msg); + + /** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */ + struct ReadCursor + { + const uint8_t* Pos = nullptr; + const uint8_t* End = nullptr; + bool ParseError = false; + + [[nodiscard]] bool CheckAvailable(size_t N) + { + if (ParseError || static_cast<size_t>(End - Pos) < N) + { + ParseError = true; + return false; + } + return true; + } + }; + + static void WriteInt32(std::vector<uint8_t>& Buf, int Value); + static int ReadInt32(ReadCursor& C); + + static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length); + static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length); - void WriteInt32(int Value); - static int ReadInt32(const uint8_t** Pos); + static size_t MeasureUnsignedVarInt(size_t Value); + static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value); + static size_t ReadUnsignedVarInt(ReadCursor& C); - void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); - static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); + static void WriteString(std::vector<uint8_t>& Buf, const char* Text); + static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text); + static std::string_view ReadString(ReadCursor& C); - static size_t MeasureUnsignedVarInt(size_t Value); - void WriteUnsignedVarInt(size_t Value); - static size_t ReadUnsignedVarInt(const uint8_t** Pos); + static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text); - size_t MeasureString(const char* Text) const; - void WriteString(const char* Text); - void WriteString(std::string_view Text); - static std::string_view ReadString(const uint8_t** Pos); + std::shared_ptr<AsyncComputeSocket> m_Socket; + int m_ChannelId; + asio::io_context& m_IoContext; - void WriteOptionalString(const char* Text); + std::deque<std::vector<uint8_t>> m_IncomingFrames; + AsyncResponseHandler m_PendingHandler; + std::unique_ptr<asio::steady_timer> m_TimeoutTimer; + bool m_Detached = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp index d3974bc28..8493a9456 100644 --- a/src/zenhorde/hordebundle.cpp +++ b/src/zenhorde/hordebundle.cpp @@ -10,6 +10,7 @@ #include <zencore/logging.h> #include <zencore/process.h> #include <zencore/trace.h> +#include <zencore/uid.h> #include <algorithm> #include <chrono> @@ -48,7 +49,7 @@ static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1 static constexpr size_t BlobTypeSize = 20; -// ─── VarInt helpers (UE format) ───────────────────────────────────────────── +// --- VarInt helpers (UE format) --------------------------------------------- static size_t MeasureVarInt(size_t Value) @@ -57,7 +58,7 @@ MeasureVarInt(size_t Value) { return 1; } - return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1; + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } static void @@ -76,7 +77,7 @@ WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value) Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); } -// ─── Binary helpers ───────────────────────────────────────────────────────── +// --- Binary helpers --------------------------------------------------------- static void WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value) @@ -121,7 +122,7 @@ PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value) memcpy(Buffer.data() + Offset, &Value, 4); } -// ─── Packet builder ───────────────────────────────────────────────────────── +// --- Packet builder --------------------------------------------------------- // Builds a single uncompressed Horde V2 packet. Layout: // [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header) @@ -229,7 +230,7 @@ struct PacketBuilder { AlignTo4(Data); - // ── Type table: count(int32) + count * BlobTypeSize bytes ── + // -- Type table: count(int32) + count * BlobTypeSize bytes -- const int32_t TypeTableOffset = static_cast<int32_t>(Data.size()); WriteLE32(Data, static_cast<int32_t>(Types.size())); for (const uint8_t* TypeEntry : Types) @@ -237,12 +238,12 @@ struct PacketBuilder WriteBytes(Data, TypeEntry, BlobTypeSize); } - // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ── + // -- Import table: count(int32) + (count+1) offsets(int32 each) + import data -- const int32_t ImportTableOffset = static_cast<int32_t>(Data.size()); const int32_t ImportCount = static_cast<int32_t>(Imports.size()); WriteLE32(Data, ImportCount); - // Reserve space for (count+1) offset entries — will be patched below + // Reserve space for (count+1) offset entries - will be patched below const size_t ImportOffsetsStart = Data.size(); for (int32_t i = 0; i <= ImportCount; ++i) { @@ -266,7 +267,7 @@ struct PacketBuilder // Sentinel offset (points past the last import's data) PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size())); - // ── Export table: count(int32) + (count+1) offsets(int32 each) ── + // -- Export table: count(int32) + (count+1) offsets(int32 each) -- const int32_t ExportTableOffset = static_cast<int32_t>(Data.size()); const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size()); WriteLE32(Data, ExportCount); @@ -278,7 +279,7 @@ struct PacketBuilder // Sentinel: points to the start of the type table (end of export data region) WriteLE32(Data, TypeTableOffset); - // ── Patch header ── + // -- Patch header -- // PacketLength = total packet size including the 8-byte header const int32_t PacketLength = static_cast<int32_t>(Data.size()); PatchLE32(Data, 4, PacketLength); @@ -290,7 +291,7 @@ struct PacketBuilder } }; -// ─── Encoded packet wrapper ───────────────────────────────────────────────── +// --- Encoded packet wrapper ------------------------------------------------- // Wraps an uncompressed packet with the encoded header: // [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes @@ -327,24 +328,22 @@ EncodePacket(std::vector<uint8_t> UncompressedPacket) return Encoded; } -// ─── Bundle blob name generation ──────────────────────────────────────────── +// --- Bundle blob name generation -------------------------------------------- static std::string GenerateBlobName() { - static std::atomic<uint32_t> s_Counter{0}; - - const int Pid = GetCurrentProcessId(); - - auto Now = std::chrono::steady_clock::now().time_since_epoch(); - auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count(); - - ExtendableStringBuilder<64> Name; - Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1); - return std::string(Name.ToView()); + // Oid is a 12-byte identifier built from a timestamp, a monotonic serial number + // initialised from std::random_device, and a per-process run id also drawn from + // std::random_device. The 24-hex-char rendering gives ~80 bits of effective + // name-prediction entropy, so a local attacker cannot race-create the blob + // path before we open it. Previously the name was pid+ms+counter, which two + // zenserver processes with the same PID could collide on and which was + // entirely predictable. + return zen::Oid::NewOid().ToString(); } -// ─── File info for bundling ───────────────────────────────────────────────── +// --- File info for bundling ------------------------------------------------- struct FileInfo { @@ -357,7 +356,7 @@ struct FileInfo IoHash RootExportHash; // IoHash of the root export for this file }; -// ─── CreateBundle implementation ──────────────────────────────────────────── +// --- CreateBundle implementation -------------------------------------------- bool BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult) @@ -534,7 +533,7 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil FileInfo& Info = ValidFiles[i]; DirImports.push_back(Info.DirectoryExportImportIndex); - // IoHash of target (20 bytes) — import is consumed sequentially from the + // IoHash of target (20 bytes) - import is consumed sequentially from the // export's import list by ReadBlobRef, not encoded in the payload WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash)); // name (string) @@ -557,8 +556,16 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil std::vector<uint8_t> UncompressedPacket = Packet.Finish(); std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket)); - // Write .blob file + // Write .blob file. Refuse to proceed if a file with this name already exists - + // the Oid-based BlobName should make collisions astronomically unlikely, so an + // existing file implies either an extraordinary collision or an attacker having + // pre-seeded the path; either way, we do not want to overwrite it. const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob"); + if (std::filesystem::exists(BlobFilePath, Ec)) + { + ZEN_ERROR("blob file already exists at {} - refusing to overwrite", BlobFilePath.string()); + return false; + } { BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec); if (Ec) @@ -574,8 +581,10 @@ BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::fil Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex; const std::string LocatorStr(Locator.ToView()); - // Write .ref file (use first file's name as the ref base) - const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref"); + // Write .ref file. Include the Oid-based BlobName so that two concurrent + // CreateBundle() calls into the same OutputDir that happen to share the first + // filename don't clobber each other's ref file. + const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + "." + BlobName + ".Bundle.ref"); { BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec); if (Ec) diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp index fb981f0ba..762edce06 100644 --- a/src/zenhorde/hordeclient.cpp +++ b/src/zenhorde/hordeclient.cpp @@ -4,6 +4,7 @@ #include <zencore/iobuffer.h> #include <zencore/logging.h> #include <zencore/memoryview.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhorde/hordeclient.h> #include <zenhttp/httpclient.h> @@ -14,7 +15,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client")) +HordeClient::HordeClient(HordeConfig Config) : m_Config(std::move(Config)), m_Log("horde.client") { } @@ -32,19 +33,24 @@ HordeClient::Initialize() Settings.RetryCount = 1; Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests}; - if (!m_Config.AuthToken.empty()) + if (m_Config.AccessTokenProvider) { + Settings.AccessTokenProvider = m_Config.AccessTokenProvider; + } + else if (!m_Config.AuthToken.empty()) + { + // Static tokens have no wire-provided expiry. Synthesising \"now + 24h\" is wrong + // in both directions: if the real token expires before 24h we keep sending it after + // it dies; if it's long-lived we force unnecessary re-auth churn every day. Use the + // never-expires sentinel, matching zenhttp's CreateFromStaticToken. Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken { - HttpClientAccessToken Token; - Token.Value = token; - Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24}; - return Token; + return HttpClientAccessToken(token, HttpClientAccessToken::TimePoint::max()); }; } m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings); - if (!m_Config.AuthToken.empty()) + if (Settings.AccessTokenProvider) { if (!m_Http->Authenticate()) { @@ -66,24 +72,21 @@ HordeClient::BuildRequestBody() const Requirements["pool"] = m_Config.Pool; } - std::string Condition; -#if ZEN_PLATFORM_WINDOWS ExtendableStringBuilder<256> CondBuf; +#if ZEN_PLATFORM_WINDOWS CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')"; - Condition = std::string(CondBuf); #elif ZEN_PLATFORM_MAC - Condition = "OSFamily == 'MacOS'"; + CondBuf << "OSFamily == 'MacOS'"; #else - Condition = "OSFamily == 'Linux'"; + CondBuf << "OSFamily == 'Linux'"; #endif if (!m_Config.Condition.empty()) { - Condition += " "; - Condition += m_Config.Condition; + CondBuf << " " << m_Config.Condition; } - Requirements["condition"] = Condition; + Requirements["condition"] = std::string(CondBuf); Requirements["exclusive"] = true; json11::Json::object Connection; @@ -159,39 +162,27 @@ HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutClus return false; } - OutCluster.ClusterId = ClusterIdVal.string_value(); - return true; -} - -bool -HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize) -{ - if (Hex.size() != OutSize * 2) + // A server-returned ClusterId is interpolated directly into the request URL below + // (api/v2/compute/<ClusterId>), so a compromised or MITM'd Horde server could + // otherwise inject additional path segments or query strings. Constrain to a + // conservative identifier alphabet. + const std::string& ClusterIdStr = ClusterIdVal.string_value(); + if (ClusterIdStr.size() > 64) { + ZEN_WARN("rejecting overlong clusterId ({} bytes) in cluster resolution response", ClusterIdStr.size()); return false; } - - for (size_t i = 0; i < OutSize; ++i) + static constexpr AsciiSet ValidClusterIdCharactersSet{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-"}; + if (!AsciiSet::HasOnly(ClusterIdStr, ValidClusterIdCharactersSet)) { - auto HexToByte = [](char c) -> int { - if (c >= '0' && c <= '9') - return c - '0'; - if (c >= 'a' && c <= 'f') - return c - 'a' + 10; - if (c >= 'A' && c <= 'F') - return c - 'A' + 10; - return -1; - }; - - const int Hi = HexToByte(Hex[i * 2]); - const int Lo = HexToByte(Hex[i * 2 + 1]); - if (Hi < 0 || Lo < 0) - { - return false; - } - Out[i] = static_cast<uint8_t>((Hi << 4) | Lo); + ZEN_WARN("rejecting clusterId with unsafe character in cluster resolution response"); + return false; } + OutCluster.ClusterId = ClusterIdStr; + + ZEN_DEBUG("cluster resolution succeeded: clusterId='{}'", OutCluster.ClusterId); + return true; } @@ -200,8 +191,6 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C { ZEN_TRACE_CPU("HordeClient::RequestMachine"); - ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str()); - ExtendableStringBuilder<128> ResourcePath; ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str()); @@ -321,11 +310,15 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C } else if (Prop.starts_with("LogicalCores=")) { - LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13)); + LogicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(13)).value_or(0); } else if (Prop.starts_with("PhysicalCores=")) { - PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14)); + PhysicalCores = ParseInt<uint16_t>(std::string_view(Prop).substr(14)).value_or(0); + } + else if (Prop.starts_with("Pool=")) + { + OutMachine.Pool = Prop.substr(5); } } } @@ -370,10 +363,12 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C OutMachine.LeaseId = LeaseIdVal.string_value(); } - ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}", + ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}", OutMachine.GetConnectionAddress(), OutMachine.GetConnectionPort(), + ToString(OutMachine.Mode), OutMachine.LogicalCores, + OutMachine.Pool, OutMachine.LeaseId); return true; diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp deleted file mode 100644 index 0d032b5d5..000000000 --- a/src/zenhorde/hordecomputebuffer.cpp +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "hordecomputebuffer.h" - -#include <algorithm> -#include <cassert> -#include <chrono> -#include <condition_variable> -#include <cstring> - -namespace zen::horde { - -// Simplified ring buffer implementation for in-process use only. -// Uses a single contiguous buffer with write/read cursors and -// mutex+condvar for synchronization. This is simpler than the UE version -// which uses lock-free atomics and shared memory, but sufficient for our -// use case where we're the initiator side of the compute protocol. - -struct ComputeBuffer::Detail : TRefCounted<Detail> -{ - std::vector<uint8_t> Data; - size_t NumChunks = 0; - size_t ChunkLength = 0; - - // Current write state - size_t WriteChunkIdx = 0; - size_t WriteOffset = 0; - bool WriteComplete = false; - - // Current read state - size_t ReadChunkIdx = 0; - size_t ReadOffset = 0; - bool Detached = false; - - // Per-chunk written length - std::vector<size_t> ChunkWrittenLength; - std::vector<bool> ChunkFinished; // Writer moved to next chunk - - std::mutex Mutex; - std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes - std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space - - bool HasWriter = false; - bool HasReader = false; - - uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } - const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } -}; - -// ComputeBuffer - -ComputeBuffer::ComputeBuffer() -{ -} -ComputeBuffer::~ComputeBuffer() -{ -} - -bool -ComputeBuffer::CreateNew(const Params& InParams) -{ - auto* NewDetail = new Detail(); - NewDetail->NumChunks = InParams.NumChunks; - NewDetail->ChunkLength = InParams.ChunkLength; - NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); - NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); - NewDetail->ChunkFinished.resize(InParams.NumChunks, false); - - m_Detail = NewDetail; - return true; -} - -void -ComputeBuffer::Close() -{ - m_Detail = nullptr; -} - -bool -ComputeBuffer::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -ComputeBufferReader -ComputeBuffer::CreateReader() -{ - assert(m_Detail); - m_Detail->HasReader = true; - return ComputeBufferReader(m_Detail); -} - -ComputeBufferWriter -ComputeBuffer::CreateWriter() -{ - assert(m_Detail); - m_Detail->HasWriter = true; - return ComputeBufferWriter(m_Detail); -} - -// ComputeBufferReader - -ComputeBufferReader::ComputeBufferReader() -{ -} -ComputeBufferReader::~ComputeBufferReader() -{ -} - -ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; -ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; -ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; -ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; - -ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferReader::Close() -{ - m_Detail = nullptr; -} - -void -ComputeBufferReader::Detach() -{ - if (m_Detail) - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - m_Detail->Detached = true; - m_Detail->ReadCV.notify_all(); - } -} - -bool -ComputeBufferReader::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -bool -ComputeBufferReader::IsComplete() const -{ - if (!m_Detail) - { - return true; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - if (m_Detail->Detached) - { - return true; - } - return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && - m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; -} - -void -ComputeBufferReader::AdvanceReadPosition(size_t Size) -{ - if (!m_Detail) - { - return; - } - - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - - m_Detail->ReadOffset += Size; - - // Check if we need to move to next chunk - const size_t ReadChunk = m_Detail->ReadChunkIdx; - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - } - - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferReader::GetMaxReadSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t ReadChunk = m_Detail->ReadChunkIdx; - return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; -} - -const uint8_t* -ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock<std::mutex> Lock(m_Detail->Mutex); - - auto Predicate = [&]() -> bool { - if (m_Detail->Detached) - { - return true; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available >= MinSize) - { - return true; - } - - // If chunk is finished and we've read everything, try to move to next - if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) - { - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - // Move to next chunk - const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; - m_Detail->ReadChunkIdx = NextChunk; - m_Detail->ReadOffset = 0; - m_Detail->WriteCV.notify_all(); - return false; // Re-check with new chunk - } - - if (m_Detail->WriteComplete) - { - return true; // End of stream - } - - return false; - }; - - if (TimeoutMs < 0) - { - m_Detail->ReadCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - if (OutTimedOut) - { - *OutTimedOut = true; - } - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - const size_t ReadChunk = m_Detail->ReadChunkIdx; - const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; - - if (Available < MinSize) - { - return nullptr; // End of stream - } - - return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; -} - -size_t -ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) -{ - const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); - if (!Data) - { - return 0; - } - - const size_t Available = GetMaxReadSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Buffer, Data, ToCopy); - AdvanceReadPosition(ToCopy); - return ToCopy; -} - -// ComputeBufferWriter - -ComputeBufferWriter::ComputeBufferWriter() = default; -ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; -ComputeBufferWriter::~ComputeBufferWriter() = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; -ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; - -ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) -{ -} - -void -ComputeBufferWriter::Close() -{ - if (m_Detail) - { - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - if (!m_Detail->WriteComplete) - { - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } - } - m_Detail = nullptr; - } -} - -bool -ComputeBufferWriter::IsValid() const -{ - return static_cast<bool>(m_Detail); -} - -void -ComputeBufferWriter::MarkComplete() -{ - if (m_Detail) - { - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - m_Detail->WriteComplete = true; - m_Detail->ReadCV.notify_all(); - } -} - -void -ComputeBufferWriter::AdvanceWritePosition(size_t Size) -{ - if (!m_Detail || Size == 0) - { - return; - } - - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - m_Detail->ChunkWrittenLength[WriteChunk] += Size; - m_Detail->WriteOffset += Size; - m_Detail->ReadCV.notify_all(); -} - -size_t -ComputeBufferWriter::GetMaxWriteSize() const -{ - if (!m_Detail) - { - return 0; - } - std::lock_guard<std::mutex> Lock(m_Detail->Mutex); - const size_t WriteChunk = m_Detail->WriteChunkIdx; - return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; -} - -size_t -ComputeBufferWriter::GetChunkMaxLength() const -{ - if (!m_Detail) - { - return 0; - } - return m_Detail->ChunkLength; -} - -size_t -ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) -{ - uint8_t* Dest = WaitToWrite(1, TimeoutMs); - if (!Dest) - { - return 0; - } - - const size_t Available = GetMaxWriteSize(); - const size_t ToCopy = std::min(Available, MaxSize); - memcpy(Dest, Buffer, ToCopy); - AdvanceWritePosition(ToCopy); - return ToCopy; -} - -uint8_t* -ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) -{ - if (!m_Detail) - { - return nullptr; - } - - std::unique_lock<std::mutex> Lock(m_Detail->Mutex); - - if (m_Detail->WriteComplete) - { - return nullptr; - } - - const size_t WriteChunk = m_Detail->WriteChunkIdx; - const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; - - // If current chunk has enough space, return pointer - if (Available >= MinSize) - { - return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; - } - - // Current chunk is full - mark it as finished and move to next. - // The writer cannot advance until the reader has fully consumed the next chunk, - // preventing the writer from overwriting data the reader hasn't processed yet. - m_Detail->ChunkFinished[WriteChunk] = true; - m_Detail->ReadCV.notify_all(); - - const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; - - // Wait until reader has consumed the next chunk - auto Predicate = [&]() -> bool { - // Check if read has moved past this chunk - return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; - }; - - if (TimeoutMs < 0) - { - m_Detail->WriteCV.wait(Lock, Predicate); - } - else - { - if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) - { - return nullptr; - } - } - - if (m_Detail->Detached) - { - return nullptr; - } - - // Reset next chunk - m_Detail->ChunkWrittenLength[NextChunk] = 0; - m_Detail->ChunkFinished[NextChunk] = false; - m_Detail->WriteChunkIdx = NextChunk; - m_Detail->WriteOffset = 0; - - return m_Detail->ChunkPtr(NextChunk); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h deleted file mode 100644 index 64ef91b7a..000000000 --- a/src/zenhorde/hordecomputebuffer.h +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zenbase/refcount.h> - -#include <cstddef> -#include <cstdint> -#include <mutex> -#include <vector> - -namespace zen::horde { - -class ComputeBufferReader; -class ComputeBufferWriter; - -/** Simplified in-process ring buffer for the Horde compute protocol. - * - * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files, - * this implementation uses plain heap-allocated memory since we only need in-process - * communication between channel and transport threads. The buffer is divided into - * fixed-size chunks; readers and writers block when no space is available. - */ -class ComputeBuffer -{ -public: - struct Params - { - size_t NumChunks = 2; - size_t ChunkLength = 512 * 1024; - }; - - ComputeBuffer(); - ~ComputeBuffer(); - - ComputeBuffer(const ComputeBuffer&) = delete; - ComputeBuffer& operator=(const ComputeBuffer&) = delete; - - bool CreateNew(const Params& InParams); - void Close(); - - bool IsValid() const; - - ComputeBufferReader CreateReader(); - ComputeBufferWriter CreateWriter(); - -private: - struct Detail; - Ref<Detail> m_Detail; - - friend class ComputeBufferReader; - friend class ComputeBufferWriter; -}; - -/** Read endpoint for a ComputeBuffer. - * - * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceReadPosition() after consuming the data. - */ -class ComputeBufferReader -{ -public: - ComputeBufferReader(); - ComputeBufferReader(const ComputeBufferReader&); - ComputeBufferReader(ComputeBufferReader&&) noexcept; - ~ComputeBufferReader(); - - ComputeBufferReader& operator=(const ComputeBufferReader&); - ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; - - void Close(); - void Detach(); - bool IsValid() const; - bool IsComplete() const; - - void AdvanceReadPosition(size_t Size); - size_t GetMaxReadSize() const; - - /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ - size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - - /** Wait until at least MinSize bytes are available and return a direct pointer. - * Returns nullptr on timeout or if the writer has completed. */ - const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); - -private: - friend class ComputeBuffer; - explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail); - - Ref<ComputeBuffer::Detail> m_Detail; -}; - -/** Write endpoint for a ComputeBuffer. - * - * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer - * directly into the buffer memory (zero-copy); the caller must call - * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal - * that no more data will be written. - */ -class ComputeBufferWriter -{ -public: - ComputeBufferWriter(); - ComputeBufferWriter(const ComputeBufferWriter&); - ComputeBufferWriter(ComputeBufferWriter&&) noexcept; - ~ComputeBufferWriter(); - - ComputeBufferWriter& operator=(const ComputeBufferWriter&); - ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept; - - void Close(); - bool IsValid() const; - - /** Signal that no more data will be written. Unblocks any waiting readers. */ - void MarkComplete(); - - void AdvanceWritePosition(size_t Size); - size_t GetMaxWriteSize() const; - size_t GetChunkMaxLength() const; - - /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */ - size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1); - - /** Wait until at least MinSize bytes of write space are available and return a direct pointer. - * Returns nullptr on timeout. */ - uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1); - -private: - friend class ComputeBuffer; - explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail); - - Ref<ComputeBuffer::Detail> m_Detail; -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp deleted file mode 100644 index ee2a6f327..000000000 --- a/src/zenhorde/hordecomputechannel.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "hordecomputechannel.h" - -namespace zen::horde { - -ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) -: Reader(std::move(InReader)) -, Writer(std::move(InWriter)) -{ -} - -bool -ComputeChannel::IsValid() const -{ - return Reader.IsValid() && Writer.IsValid(); -} - -size_t -ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) -{ - return Writer.Write(Data, Size, TimeoutMs); -} - -size_t -ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) -{ - return Reader.Read(Data, Size, TimeoutMs); -} - -void -ComputeChannel::MarkComplete() -{ - Writer.MarkComplete(); -} - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h deleted file mode 100644 index c1dff20e4..000000000 --- a/src/zenhorde/hordecomputechannel.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "hordecomputebuffer.h" - -namespace zen::horde { - -/** Bidirectional communication channel using a pair of compute buffers. - * - * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter - * (for sending data). Used by ComputeSocket to represent one logical channel - * within a multiplexed connection. - */ -class ComputeChannel : public TRefCounted<ComputeChannel> -{ -public: - ComputeBufferReader Reader; - ComputeBufferWriter Writer; - - ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); - - bool IsValid() const; - - size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); - size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); - - /** Signal that no more data will be sent on this channel. */ - void MarkComplete(); -}; - -} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp index 6ef67760c..92a56c077 100644 --- a/src/zenhorde/hordecomputesocket.cpp +++ b/src/zenhorde/hordecomputesocket.cpp @@ -6,198 +6,326 @@ namespace zen::horde { -ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) -: m_Log(zen::logging::Get("horde.socket")) +AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext) +: m_Log(zen::logging::Get("horde.socket.async")) , m_Transport(std::move(Transport)) +, m_Strand(asio::make_strand(IoContext)) +, m_PingTimer(m_Strand) { } -ComputeSocket::~ComputeSocket() +AsyncComputeSocket::~AsyncComputeSocket() { - // Shutdown order matters: first stop the ping thread, then unblock send threads - // by detaching readers, then join send threads, and finally close the transport - // to unblock the recv thread (which is blocked on RecvMessage). - { - std::lock_guard<std::mutex> Lock(m_PingMutex); - m_PingShouldStop = true; - m_PingCV.notify_all(); - } - - for (auto& Reader : m_Readers) - { - Reader.Detach(); - } - - for (auto& [Id, Thread] : m_SendThreads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } + Close(); +} - m_Transport->Close(); +void +AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach) +{ + m_FrameHandlers[ChannelId] = std::move(OnFrame); + m_DetachHandlers[ChannelId] = std::move(OnDetach); +} - if (m_RecvThread.joinable()) - { - m_RecvThread.join(); - } - if (m_PingThread.joinable()) - { - m_PingThread.join(); - } +void +AsyncComputeSocket::StartRecvPump() +{ + StartPingTimer(); + DoRecvHeader(); } -Ref<ComputeChannel> -ComputeSocket::CreateChannel(int ChannelId) +void +AsyncComputeSocket::DoRecvHeader() { - ComputeBuffer::Params Params; + auto Self = shared_from_this(); + m_Transport->AsyncRead(&m_RecvHeader, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv header error: {}", Ec.message()); + HandleError(); + } + return; + } - ComputeBuffer RecvBuffer; - if (!RecvBuffer.CreateNew(Params)) - { - return {}; - } + if (m_Closed) + { + return; + } - ComputeBuffer SendBuffer; - if (!SendBuffer.CreateNew(Params)) - { - return {}; - } + if (m_RecvHeader.Size >= 0) + { + DoRecvPayload(m_RecvHeader); + } + else if (m_RecvHeader.Size == ControlDetach) + { + if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second) + { + It->second(); + } + DoRecvHeader(); + } + else if (m_RecvHeader.Size == ControlPing) + { + DoRecvHeader(); + } + else + { + ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size); + } + })); +} - Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); +void +AsyncComputeSocket::DoRecvPayload(FrameHeader Header) +{ + auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size)); + auto Self = shared_from_this(); - // Attach recv buffer writer (transport recv thread writes into this) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); - } + m_Transport->AsyncRead(PayloadBuf->data(), + PayloadBuf->size(), + asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message()); + HandleError(); + } + return; + } - // Attach send buffer reader (send thread reads from this) - { - ComputeBufferReader Reader = SendBuffer.CreateReader(); - m_Readers.push_back(Reader); - m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); - } + if (m_Closed) + { + return; + } + + if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second) + { + It->second(std::move(*PayloadBuf)); + } + else + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + } - return Channel; + DoRecvHeader(); + })); } void -ComputeSocket::StartCommunication() +AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler) { - m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); - m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable { + if (m_Closed) + { + if (Handler) + { + Handler(asio::error::make_error_code(asio::error::operation_aborted)); + } + return; + } + + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = static_cast<int32_t>(Data.size()); + Write.Data = std::move(Data); + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::PingThreadProc() +AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler) { - while (true) - { + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable { + if (m_Closed) { - std::unique_lock<std::mutex> Lock(m_PingMutex); - if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + if (Handler) { - break; + Handler(asio::error::make_error_code(asio::error::operation_aborted)); } + return; } - std::lock_guard<std::mutex> Lock(m_SendMutex); - FrameHeader Header; - Header.Channel = 0; - Header.Size = ControlPing; - m_Transport->SendMessage(&Header, sizeof(Header)); - } + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = ControlDetach; + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::RecvThreadProc() +AsyncComputeSocket::FlushNextSend() { - // Writers are cached locally to avoid taking m_WritersMutex on every frame. - // The shared m_Writers map is only accessed when a channel is seen for the first time. - std::unordered_map<int, ComputeBufferWriter> CachedWriters; + if (m_SendQueue.empty() || m_Closed) + { + return; + } - FrameHeader Header; - while (m_Transport->RecvMessage(&Header, sizeof(Header))) + PendingWrite& Front = m_SendQueue.front(); + + if (Front.Data.empty()) { - if (Header.Size >= 0) - { - // Data frame - auto It = CachedWriters.find(Header.Channel); - if (It == CachedWriters.end()) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto WIt = m_Writers.find(Header.Channel); - if (WIt == m_Writers.end()) - { - ZEN_WARN("recv frame for unknown channel {}", Header.Channel); - // Skip the data - std::vector<uint8_t> Discard(Header.Size); - m_Transport->RecvMessage(Discard.data(), Header.Size); - continue; - } - It = CachedWriters.emplace(Header.Channel, WIt->second).first; - } + // Control frame - header only + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); - ComputeBufferWriter& Writer = It->second; - uint8_t* Dest = Writer.WaitToWrite(Header.Size); - if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) - { - ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); - return; - } - Writer.AdvanceWritePosition(Header.Size); - } - else if (Header.Size == ControlDetach) - { - // Detach the recv buffer for this channel - CachedWriters.erase(Header.Channel); + if (Handler) + { + Handler(Ec); + } - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto It = m_Writers.find(Header.Channel); - if (It != m_Writers.end()) - { - It->second.MarkComplete(); - m_Writers.erase(It); - } - } - else if (Header.Size == ControlPing) + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + } + else + { + // Data frame - write header first, then payload + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + if (Handler) + { + Handler(Ec); + } + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send header error: {}", Ec.message()); + HandleError(); + } + return; + } + + PendingWrite& Payload = m_SendQueue.front(); + m_Transport->AsyncWrite( + Payload.Data.data(), + Payload.Data.size(), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + + if (Handler) + { + Handler(Ec); + } + + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send payload error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + })); + } +} + +void +AsyncComputeSocket::StartPingTimer() +{ + if (m_Closed) + { + return; + } + + m_PingTimer.expires_after(std::chrono::seconds(2)); + + auto Self = shared_from_this(); + m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) { + if (Ec || m_Closed) { - // Ping response - ignore + return; } - else + + // Enqueue a ping control frame + PendingWrite Write; + Write.Header.Channel = 0; + Write.Header.Size = ControlPing; + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) { - ZEN_WARN("invalid frame header size: {}", Header.Size); - return; + FlushNextSend(); } - } + + StartPingTimer(); + })); } void -ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +AsyncComputeSocket::HandleError() { - // Each channel has its own send thread. All send threads share m_SendMutex - // to serialize writes to the transport, since TCP requires atomic frame writes. - FrameHeader Header; - Header.Channel = Channel; + if (m_Closed) + { + return; + } + + Close(); - const uint8_t* Data; - while ((Data = Reader.WaitToRead(1)) != nullptr) + // Notify all channels that the connection is gone so agents can clean up + for (auto& [ChannelId, Handler] : m_DetachHandlers) { - std::lock_guard<std::mutex> Lock(m_SendMutex); + if (Handler) + { + Handler(); + } + } +} - Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); - m_Transport->SendMessage(&Header, sizeof(Header)); - m_Transport->SendMessage(Data, Header.Size); - Reader.AdvanceReadPosition(Header.Size); +void +AsyncComputeSocket::Close() +{ + if (m_Closed) + { + return; } - if (Reader.IsComplete()) + m_Closed = true; + m_PingTimer.cancel(); + + if (m_Transport) { - std::lock_guard<std::mutex> Lock(m_SendMutex); - Header.Size = ControlDetach; - m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->Close(); } } diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h index 0c3cb4195..6c494603a 100644 --- a/src/zenhorde/hordecomputesocket.h +++ b/src/zenhorde/hordecomputesocket.h @@ -2,45 +2,74 @@ #pragma once -#include "hordecomputebuffer.h" -#include "hordecomputechannel.h" #include "hordetransport.h" #include <zencore/logbase.h> -#include <condition_variable> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +#include <deque> +#include <functional> #include <memory> -#include <mutex> -#include <thread> +#include <system_error> #include <unordered_map> #include <vector> namespace zen::horde { -/** Multiplexed socket that routes data between multiple channels over a single transport. +class AsyncComputeTransport; + +/** Handler called when a data frame arrives for a channel. */ +using FrameHandler = std::function<void(std::vector<uint8_t> Data)>; + +/** Handler called when a channel is detached by the remote peer. */ +using DetachHandler = std::function<void()>; + +/** Handler for async send completion. */ +using SendHandler = std::function<void(const std::error_code&)>; + +/** Async multiplexed socket that routes data between channels over a single transport. * - * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. - * A recv thread demultiplexes incoming frames to channel-specific buffers, while - * per-channel send threads multiplex outgoing data onto the shared transport. + * Uses an async recv pump, a serialized send queue, and a periodic ping timer - + * all running on a shared io_context. * - * Wire format per frame: [channelId (4B)][size (4B)][data] - * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. + * Wire format per frame: [channelId(4B)][size(4B)][data]. + * Control messages use negative sizes: -2 = detach, -3 = ping. */ -class ComputeSocket +class AsyncComputeSocket : public std::enable_shared_from_this<AsyncComputeSocket> { public: - explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport); - ~ComputeSocket(); + AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext); + ~AsyncComputeSocket(); + + AsyncComputeSocket(const AsyncComputeSocket&) = delete; + AsyncComputeSocket& operator=(const AsyncComputeSocket&) = delete; + + /** Register callbacks for a channel. Must be called before StartRecvPump(). */ + void RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach); - ComputeSocket(const ComputeSocket&) = delete; - ComputeSocket& operator=(const ComputeSocket&) = delete; + /** Begin the async recv pump and ping timer. */ + void StartRecvPump(); - /** Create a channel with the given ID. - * Allocates anonymous in-process buffers and spawns a send thread for the channel. */ - Ref<ComputeChannel> CreateChannel(int ChannelId); + /** Enqueue a data frame for async transmission. */ + void AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler = {}); - /** Start the recv pump and ping threads. Must be called after all channels are created. */ - void StartCommunication(); + /** Send a control frame (detach) for a channel. */ + void AsyncSendDetach(int ChannelId, SendHandler Handler = {}); + + /** Close the transport and cancel all pending operations. */ + void Close(); + + /** The strand on which all socket I/O callbacks run. Channels that need to serialize + * their own state with OnFrame/OnDetach (which are invoked from this strand) should + * bind their timers and async operations to it as well. */ + asio::strand<asio::any_io_executor>& GetStrand() { return m_Strand; } private: struct FrameHeader @@ -49,31 +78,35 @@ private: int32_t Size = 0; }; + struct PendingWrite + { + FrameHeader Header; + std::vector<uint8_t> Data; + SendHandler Handler; + }; + static constexpr int32_t ControlDetach = -2; static constexpr int32_t ControlPing = -3; LoggerRef Log() { return m_Log; } - void RecvThreadProc(); - void SendThreadProc(int Channel, ComputeBufferReader Reader); - void PingThreadProc(); - - LoggerRef m_Log; - std::unique_ptr<ComputeTransport> m_Transport; - std::mutex m_SendMutex; ///< Serializes writes to the transport - - std::mutex m_WritersMutex; - std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID + void DoRecvHeader(); + void DoRecvPayload(FrameHeader Header); + void FlushNextSend(); + void StartPingTimer(); + void HandleError(); - std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction - std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel + LoggerRef m_Log; + std::unique_ptr<AsyncComputeTransport> m_Transport; + asio::strand<asio::any_io_executor> m_Strand; + asio::steady_timer m_PingTimer; - std::thread m_RecvThread; - std::thread m_PingThread; + std::unordered_map<int, FrameHandler> m_FrameHandlers; + std::unordered_map<int, DetachHandler> m_DetachHandlers; - bool m_PingShouldStop = false; - std::mutex m_PingMutex; - std::condition_variable m_PingCV; + FrameHeader m_RecvHeader; + std::deque<PendingWrite> m_SendQueue; + bool m_Closed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp index 2dca228d9..9f6125c64 100644 --- a/src/zenhorde/hordeconfig.cpp +++ b/src/zenhorde/hordeconfig.cpp @@ -1,5 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <zencore/logging.h> +#include <zencore/string.h> #include <zenhorde/hordeconfig.h> namespace zen::horde { @@ -9,12 +11,14 @@ HordeConfig::Validate() const { if (ServerUrl.empty()) { + ZEN_WARN("Horde server URL is not configured"); return false; } // Relay mode implies AES encryption if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) { + ZEN_WARN("Horde relay mode requires AES encryption, but encryption is set to '{}'", ToString(EncryptionMode)); return false; } @@ -52,37 +56,39 @@ ToString(Encryption Enc) bool FromString(ConnectionMode& OutMode, std::string_view Str) { - if (Str == "direct") + if (StrCaseCompare(Str, "direct") == 0) { OutMode = ConnectionMode::Direct; return true; } - if (Str == "tunnel") + if (StrCaseCompare(Str, "tunnel") == 0) { OutMode = ConnectionMode::Tunnel; return true; } - if (Str == "relay") + if (StrCaseCompare(Str, "relay") == 0) { OutMode = ConnectionMode::Relay; return true; } + ZEN_WARN("unrecognized Horde connection mode: '{}'", Str); return false; } bool FromString(Encryption& OutEnc, std::string_view Str) { - if (Str == "none") + if (StrCaseCompare(Str, "none") == 0) { OutEnc = Encryption::None; return true; } - if (Str == "aes") + if (StrCaseCompare(Str, "aes") == 0) { OutEnc = Encryption::AES; return true; } + ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str); return false; } diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp index f88c95da2..ea0ea1e83 100644 --- a/src/zenhorde/hordeprovisioner.cpp +++ b/src/zenhorde/hordeprovisioner.cpp @@ -6,49 +6,83 @@ #include "hordeagent.h" #include "hordebundle.h" +#include <zencore/compactbinary.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/scopeguard.h> #include <zencore/thread.h> #include <zencore/trace.h> +#include <zenhttp/httpclient.h> +#include <zenutil/workerpools.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> #include <chrono> #include <thread> namespace zen::horde { -struct HordeProvisioner::AgentWrapper -{ - std::thread Thread; - std::atomic<bool> ShouldExit{false}; -}; - HordeProvisioner::HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint) + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession, + bool CleanStart, + std::string_view TraceHost) : m_Config(Config) , m_BinariesPath(BinariesPath) , m_WorkingDir(WorkingDir) , m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_CoordinatorSession(CoordinatorSession) +, m_CleanStart(CleanStart) +, m_TraceHost(TraceHost) , m_Log(zen::logging::Get("horde.provisioner")) { + m_IoContext = std::make_unique<asio::io_context>(); + + auto Work = asio::make_work_guard(*m_IoContext); + for (int i = 0; i < IoThreadCount; ++i) + { + m_IoThreads.emplace_back([this, i, Work] { + zen::SetCurrentThreadName(fmt::format("horde_io_{}", i)); + m_IoContext->run(); + }); + } } HordeProvisioner::~HordeProvisioner() { - std::lock_guard<std::mutex> Lock(m_AgentsLock); - for (auto& Agent : m_Agents) + m_AskForAgents.store(false); + m_ShutdownEvent.Set(); + + // Shut down async agents and io_context { - Agent->ShouldExit.store(true); + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (auto& Entry : m_AsyncAgents) + { + Entry.Agent->Cancel(); + } + m_AsyncAgents.clear(); } - for (auto& Agent : m_Agents) + + m_IoContext->stop(); + + for (auto& Thread : m_IoThreads) { - if (Agent->Thread.joinable()) + if (Thread.joinable()) { - Agent->Thread.join(); + Thread.join(); } } + + // Wait for all pool work items to finish before destroying members they reference + if (m_PendingWorkItems.load() > 0) + { + m_AllWorkDone.Wait(); + } } void @@ -56,9 +90,23 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) { ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); - m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores))); + const uint32_t ClampedCount = std::min(Count, static_cast<uint32_t>(m_Config.MaxCores)); + const uint32_t PreviousTarget = m_TargetCoreCount.exchange(ClampedCount); + + if (ClampedCount != PreviousTarget) + { + ZEN_INFO("target core count changed: {} -> {} (active={}, estimated={})", + PreviousTarget, + ClampedCount, + m_ActiveCoreCount.load(), + m_EstimatedCoreCount.load()); + } - while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + // Only provision if the gap is at least one agent-sized chunk. Without + // this, draining a 32-core agent to cover a 28-core excess would leave a + // 4-core gap that triggers a 32-core provision, which triggers another + // drain, ad infinitum. + while (m_EstimatedCoreCount.load() + EstimatedCoresPerAgent <= m_TargetCoreCount.load()) { if (!m_AskForAgents.load()) { @@ -67,21 +115,108 @@ HordeProvisioner::SetTargetCoreCount(uint32_t Count) RequestAgent(); } - // Clean up finished agent threads - std::lock_guard<std::mutex> Lock(m_AgentsLock); - for (auto It = m_Agents.begin(); It != m_Agents.end();) + // Scale down async agents { - if ((*It)->ShouldExit.load()) + std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock); + + uint32_t AsyncActive = m_ActiveCoreCount.load(); + uint32_t AsyncTarget = m_TargetCoreCount.load(); + + uint32_t AlreadyDrainingCores = 0; + for (const auto& Entry : m_AsyncAgents) { - if ((*It)->Thread.joinable()) + if (Entry.Draining) { - (*It)->Thread.join(); + AlreadyDrainingCores += Entry.CoreCount; } - It = m_Agents.erase(It); } - else + + uint32_t EffectiveAsync = (AsyncActive > AlreadyDrainingCores) ? AsyncActive - AlreadyDrainingCores : 0; + + if (EffectiveAsync > AsyncTarget) { - ++It; + struct Candidate + { + AsyncAgentEntry* Entry; + int Workload; + }; + std::vector<Candidate> Candidates; + + for (auto& Entry : m_AsyncAgents) + { + if (Entry.Draining || Entry.RemoteEndpoint.empty()) + { + continue; + } + + int Workload = 0; + bool Reachable = false; + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{2000}; + Settings.Timeout = std::chrono::milliseconds{3000}; + try + { + HttpClient Client(Entry.RemoteEndpoint, Settings); + HttpClient::Response Resp = Client.Get("/compute/session/status"); + if (Resp.IsSuccess()) + { + CbObject Status = Resp.AsObject(); + Workload = Status["actions_pending"].AsInt32(0) + Status["actions_running"].AsInt32(0); + Reachable = true; + } + } + catch (const std::exception& Ex) + { + ZEN_DEBUG("agent lease={} not yet reachable for drain: {}", Entry.LeaseId, Ex.what()); + } + + if (Reachable) + { + Candidates.push_back({&Entry, Workload}); + } + } + + const uint32_t ExcessCores = EffectiveAsync - AsyncTarget; + uint32_t CoresDrained = 0; + + while (CoresDrained < ExcessCores && !Candidates.empty()) + { + const uint32_t Remaining = ExcessCores - CoresDrained; + + Candidates.erase(std::remove_if(Candidates.begin(), + Candidates.end(), + [Remaining](const Candidate& C) { return C.Entry->CoreCount > Remaining; }), + Candidates.end()); + + if (Candidates.empty()) + { + break; + } + + Candidate* Best = &Candidates[0]; + for (auto& C : Candidates) + { + if (C.Entry->CoreCount > Best->Entry->CoreCount || + (C.Entry->CoreCount == Best->Entry->CoreCount && C.Workload < Best->Workload)) + { + Best = &C; + } + } + + ZEN_INFO("draining async agent lease={} ({} cores, workload={})", + Best->Entry->LeaseId, + Best->Entry->CoreCount, + Best->Workload); + + DrainAsyncAgent(*Best->Entry); + CoresDrained += Best->Entry->CoreCount; + + AsyncAgentEntry* Drained = Best->Entry; + Candidates.erase( + std::remove_if(Candidates.begin(), Candidates.end(), [Drained](const Candidate& C) { return C.Entry == Drained; }), + Candidates.end()); + } } } } @@ -101,266 +236,395 @@ HordeProvisioner::GetStats() const uint32_t HordeProvisioner::GetAgentCount() const { - std::lock_guard<std::mutex> Lock(m_AgentsLock); - return static_cast<uint32_t>(m_Agents.size()); + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + return static_cast<uint32_t>(m_AsyncAgents.size()); } -void -HordeProvisioner::RequestAgent() +compute::AgentProvisioningStatus +HordeProvisioner::GetAgentStatus(std::string_view WorkerId) const { - m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + // Worker IDs are "horde-{LeaseId}" - strip the prefix to match lease ID + constexpr std::string_view Prefix = "horde-"; + if (!WorkerId.starts_with(Prefix)) + { + return compute::AgentProvisioningStatus::Unknown; + } + std::string_view LeaseId = WorkerId.substr(Prefix.size()); - std::lock_guard<std::mutex> Lock(m_AgentsLock); + std::lock_guard<std::mutex> AsyncLock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.LeaseId == LeaseId) + { + if (Entry.Draining) + { + return compute::AgentProvisioningStatus::Draining; + } + return compute::AgentProvisioningStatus::Active; + } + } - auto Wrapper = std::make_unique<AgentWrapper>(); - AgentWrapper& Ref = *Wrapper; - Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + // Check recently-drained agents that have already been cleaned up + std::string WorkerIdStr(WorkerId); + if (m_RecentlyDrainedWorkerIds.erase(WorkerIdStr) > 0) + { + // Also remove from the ordering queue so size accounting stays consistent. + auto It = std::find(m_RecentlyDrainedOrder.begin(), m_RecentlyDrainedOrder.end(), WorkerIdStr); + if (It != m_RecentlyDrainedOrder.end()) + { + m_RecentlyDrainedOrder.erase(It); + } + return compute::AgentProvisioningStatus::Draining; + } - m_Agents.push_back(std::move(Wrapper)); + return compute::AgentProvisioningStatus::Unknown; } -void -HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +std::vector<std::string> +HordeProvisioner::BuildAgentArgs(const MachineInfo& Machine) const { - ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + std::vector<std::string> Args; + Args.emplace_back("compute"); + Args.emplace_back("--http=asio"); + Args.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + Args.emplace_back("--data-dir=%UE_HORDE_SHARED_DIR%\\zen"); - static std::atomic<uint32_t> ThreadIndex{0}; - const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + if (m_CleanStart) + { + Args.emplace_back("--clean"); + } - zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + Args.emplace_back(CoordArg.ToView()); + } - std::unique_ptr<HordeAgent> Agent; - uint32_t MachineCoreCount = 0; + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + Args.emplace_back(IdArg.ToView()); + } - auto _ = MakeGuard([&] { - if (Agent) - { - Agent->CloseConnection(); - } - Wrapper.ShouldExit.store(true); - }); + if (!m_CoordinatorSession.empty()) + { + ExtendableStringBuilder<128> SessionArg; + SessionArg << "--coordinator-session=" << m_CoordinatorSession; + Args.emplace_back(SessionArg.ToView()); + } + if (!m_TraceHost.empty()) { - // EstimatedCoreCount is incremented speculatively when the agent is requested - // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. - auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + ExtendableStringBuilder<128> TraceArg; + TraceArg << "--tracehost=" << m_TraceHost; + Args.emplace_back(TraceArg.ToView()); + } + // In relay mode, the remote zenserver's local address is not reachable from the + // orchestrator. Pass the relay-visible endpoint so it announces the correct URL. + if (Machine.Mode == ConnectionMode::Relay) + { + const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (Addr.find(':') != std::string::npos) + { + Args.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port)); + } + else { - ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + Args.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port)); + } + } - std::lock_guard<std::mutex> BundleLock(m_BundleLock); + return Args; +} - if (!m_BundlesCreated) - { - const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; +bool +HordeProvisioner::InitializeHordeClient() +{ + ZEN_TRACE_CPU("HordeProvisioner::InitializeHordeClient"); + + std::lock_guard<std::mutex> BundleLock(m_BundleLock); - std::vector<BundleFile> Files; + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector<BundleFile> Files; #if ZEN_PLATFORM_WINDOWS - Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); + Files.emplace_back(m_BinariesPath / "zenserver.pdb", true); #elif ZEN_PLATFORM_LINUX - Files.emplace_back(m_BinariesPath / "zenserver", false); - Files.emplace_back(m_BinariesPath / "zenserver.debug", true); + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); #elif ZEN_PLATFORM_MAC - Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver", false); #endif - BundleResult Result; - if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) - { - ZEN_WARN("failed to create bundle, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - - m_Bundles.emplace_back(Result.Locator, Result.BundleDir); - m_BundlesCreated = true; - } - - if (!m_HordeClient) - { - m_HordeClient = std::make_unique<HordeClient>(m_Config); - if (!m_HordeClient->Initialize()) - { - ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); - m_AskForAgents.store(false); - return; - } - } + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + m_ShutdownEvent.Set(); + return false; } - if (!m_AskForAgents.load()) + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique<HordeClient>(m_Config); + if (!m_HordeClient->Initialize()) { - return; + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + m_ShutdownEvent.Set(); + return false; } + } - m_AgentsRequesting.fetch_add(1); - auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + return true; +} - // Simple backoff: if the last machine request failed, wait up to 5 seconds - // before trying again. - // - // Note however that it's possible that multiple threads enter this code at - // the same time if multiple agents are requested at once, and they will all - // see the same last failure time and back off accordingly. We might want to - // use a semaphore or similar to limit the number of concurrent requests. +void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); - if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) - { - auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); - const uint64_t ElapsedNs = Now - LastFail; - const uint64_t ElapsedMs = ElapsedNs / 1'000'000; - if (ElapsedMs < 5000) - { - const uint64_t WaitMs = 5000 - ElapsedMs; - for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } + if (m_PendingWorkItems.fetch_add(1) == 0) + { + m_AllWorkDone.Reset(); + } - if (Wrapper.ShouldExit.load()) + GetSmallWorkerPool(EWorkloadType::Background) + .ScheduleWork( + [this] { + ProvisionAgent(); + if (m_PendingWorkItems.fetch_sub(1) == 1) { - return; + m_AllWorkDone.Set(); } - } - } + }, + WorkerThreadPool::EMode::EnableBacklog); +} - if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) - { - return; - } +void +HordeProvisioner::ProvisionAgent() +{ + ZEN_TRACE_CPU("HordeProvisioner::ProvisionAgent"); + + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. + auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); - std::string RequestBody = m_HordeClient->BuildRequestBody(); + if (!InitializeHordeClient()) + { + return; + } - // Resolve cluster if needed - std::string ClusterId = m_Config.Cluster; - if (ClusterId == HordeConfig::ClusterAuto) + if (!m_AskForAgents.load()) + { + return; + } + + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. + + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) { - ClusterInfo Cluster; - if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) + // Wait on m_ShutdownEvent so shutdown wakes this pool thread immediately instead + // of stalling for up to 5s in 100ms sleep chunks. Wait() returns true iff the + // event was signaled (shutdown); false means the backoff elapsed normally. + const uint64_t WaitMs = 5000 - ElapsedMs; + if (m_ShutdownEvent.Wait(static_cast<int>(WaitMs))) { - ZEN_WARN("failed to resolve cluster"); - m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); return; } - ClusterId = Cluster.ClusterId; } + } - MachineInfo Machine; - if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) { + ZEN_WARN("failed to resolve cluster"); m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); return; } + ClusterId = Cluster.ClusterId; + } - m_LastRequestFailTime.store(0); + ZEN_INFO("requesting machine from Horde (cluster='{}', cores={}/{})", + ClusterId.empty() ? "default" : ClusterId.c_str(), + m_ActiveCoreCount.load(), + m_TargetCoreCount.load()); - if (Wrapper.ShouldExit.load()) - { - return; - } + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } - // Connect to agent and perform handshake - Agent = std::make_unique<HordeAgent>(Machine); - if (!Agent->IsValid()) - { - ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); - return; - } + m_LastRequestFailTime.store(0); - if (!Agent->BeginCommunication()) - { - ZEN_WARN("BeginCommunication failed"); - return; - } + if (!m_AskForAgents.load()) + { + return; + } - for (auto& [Locator, BundleDir] : m_Bundles) + AsyncAgentConfig AgentConfig; + AgentConfig.Machine = Machine; + AgentConfig.Bundles = m_Bundles; + AgentConfig.Args = BuildAgentArgs(Machine); + +#if ZEN_PLATFORM_WINDOWS + AgentConfig.UseWine = !Machine.IsWindows; + AgentConfig.Executable = "zenserver.exe"; +#else + AgentConfig.UseWine = false; + AgentConfig.Executable = "zenserver"; +#endif + + auto AsyncAgent = std::make_shared<AsyncHordeAgent>(*m_IoContext); + + AsyncAgentEntry Entry; + Entry.Agent = AsyncAgent; + Entry.LeaseId = Machine.LeaseId; + Entry.CoreCount = Machine.LogicalCores; + + const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (EndpointAddr.find(':') != std::string::npos) + { + Entry.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort); + } + else + { + Entry.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort); + } + + { + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + m_AsyncAgents.push_back(std::move(Entry)); + } + + AsyncAgent->Start(std::move(AgentConfig), [this, AsyncAgent](const AsyncAgentResult& Result) { + if (Result.CoreCount > 0) { - if (Wrapper.ShouldExit.load()) + // Only subtract estimated cores if not already subtracted by DrainAsyncAgent + bool WasDraining = false; { - return; + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (const auto& Entry : m_AsyncAgents) + { + if (Entry.Agent == AsyncAgent) + { + WasDraining = Entry.Draining; + break; + } + } } - if (!Agent->UploadBinaries(BundleDir, Locator)) + if (!WasDraining) { - ZEN_WARN("UploadBinaries failed"); - return; + m_EstimatedCoreCount.fetch_sub(Result.CoreCount); } + m_ActiveCoreCount.fetch_sub(Result.CoreCount); + m_AgentsActive.fetch_sub(1); } + OnAsyncAgentDone(AsyncAgent); + }); - if (Wrapper.ShouldExit.load()) - { - return; - } - - // Build command line for remote zenserver - std::vector<std::string> ArgStrings; - ArgStrings.push_back("compute"); - ArgStrings.push_back("--http=asio"); + // Track active cores (estimated was already added by RequestAgent) + m_EstimatedCoreCount.fetch_add(Machine.LogicalCores); + m_ActiveCoreCount.fetch_add(Machine.LogicalCores); + m_AgentsActive.fetch_add(1); +} - // TEMP HACK - these should be made fully dynamic - // these are currently here to allow spawning the compute agent locally - // for debugging purposes (i.e with a local Horde Server+Agent setup) - ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); - ArgStrings.push_back("--data-dir=c:\\temp\\123"); +void +HordeProvisioner::DrainAsyncAgent(AsyncAgentEntry& Entry) +{ + Entry.Draining = true; + m_EstimatedCoreCount.fetch_sub(Entry.CoreCount); + m_AgentsDraining.fetch_add(1); - if (!m_OrchestratorEndpoint.empty()) - { - ExtendableStringBuilder<256> CoordArg; - CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; - ArgStrings.emplace_back(CoordArg.ToView()); - } + HttpClientSettings Settings; + Settings.LogCategory = "horde.drain"; + Settings.ConnectTimeout = std::chrono::milliseconds{5000}; + Settings.Timeout = std::chrono::milliseconds{10000}; - { - ExtendableStringBuilder<128> IdArg; - IdArg << "--instance-id=horde-" << Machine.LeaseId; - ArgStrings.emplace_back(IdArg.ToView()); - } + try + { + HttpClient Client(Entry.RemoteEndpoint, Settings); - std::vector<const char*> Args; - Args.reserve(ArgStrings.size()); - for (const std::string& Arg : ArgStrings) + HttpClient::Response Response = Client.Post("/compute/session/drain"); + if (!Response.IsSuccess()) { - Args.push_back(Arg.c_str()); + ZEN_WARN("drain[{}]: POST session/drain failed: HTTP {}", Entry.LeaseId, static_cast<int>(Response.StatusCode)); + return; } -#if ZEN_PLATFORM_WINDOWS - const bool UseWine = !Machine.IsWindows; - const char* AppName = "zenserver.exe"; -#else - const bool UseWine = false; - const char* AppName = "zenserver"; -#endif - - Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); - - ZEN_INFO("remote execution started on [{}:{}] lease={}", - Machine.GetConnectionAddress(), - Machine.GetConnectionPort(), - Machine.LeaseId); - - MachineCoreCount = Machine.LogicalCores; - m_EstimatedCoreCount.fetch_add(MachineCoreCount); - m_ActiveCoreCount.fetch_add(MachineCoreCount); - m_AgentsActive.fetch_add(1); + ZEN_INFO("drain[{}]: session/drain accepted, sending sunset", Entry.LeaseId); + (void)Client.Post("/compute/session/sunset"); } + catch (const std::exception& Ex) + { + ZEN_WARN("drain[{}]: exception: {}", Entry.LeaseId, Ex.what()); + } +} - // Agent poll loop - - auto ActiveGuard = MakeGuard([&]() { - m_EstimatedCoreCount.fetch_sub(MachineCoreCount); - m_ActiveCoreCount.fetch_sub(MachineCoreCount); - m_AgentsActive.fetch_sub(1); - }); - - while (Agent->IsValid() && !Wrapper.ShouldExit.load()) +void +HordeProvisioner::OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent) +{ + std::lock_guard<std::mutex> Lock(m_AsyncAgentsLock); + for (auto It = m_AsyncAgents.begin(); It != m_AsyncAgents.end(); ++It) { - const bool LogOutput = false; - if (!Agent->Poll(LogOutput)) + if (It->Agent == Agent) { + if (It->Draining) + { + m_AgentsDraining.fetch_sub(1); + std::string WorkerId = "horde-" + It->LeaseId; + if (m_RecentlyDrainedWorkerIds.insert(WorkerId).second) + { + m_RecentlyDrainedOrder.push_back(WorkerId); + while (m_RecentlyDrainedOrder.size() > RecentlyDrainedCapacity) + { + m_RecentlyDrainedWorkerIds.erase(m_RecentlyDrainedOrder.front()); + m_RecentlyDrainedOrder.pop_front(); + } + } + } + m_AsyncAgents.erase(It); break; } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp index 69766e73e..65eaea477 100644 --- a/src/zenhorde/hordetransport.cpp +++ b/src/zenhorde/hordetransport.cpp @@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif - namespace zen::horde { -// ComputeTransport base +// --- AsyncTcpComputeTransport --- -bool -ComputeTransport::SendMessage(const void* Data, size_t Size) +struct AsyncTcpComputeTransport::Impl { - const uint8_t* Ptr = static_cast<const uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Sent = Send(Ptr, Remaining); - if (Sent == 0) - { - return false; - } - Ptr += Sent; - Remaining -= Sent; - } + asio::io_context& IoContext; + asio::ip::tcp::socket Socket; - return true; -} + explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {} +}; -bool -ComputeTransport::RecvMessage(void* Data, size_t Size) +AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext) +: m_Impl(std::make_unique<Impl>(IoContext)) +, m_Log(zen::logging::Get("horde.transport.async")) { - uint8_t* Ptr = static_cast<uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Received = Recv(Ptr, Remaining); - if (Received == 0) - { - return false; - } - Ptr += Received; - Remaining -= Received; - } - - return true; } -// TcpComputeTransport - ASIO pimpl - -struct TcpComputeTransport::Impl +AsyncTcpComputeTransport::~AsyncTcpComputeTransport() { - asio::io_context IoContext; - asio::ip::tcp::socket Socket; - - Impl() : Socket(IoContext) {} -}; + Close(); +} -// Uses ASIO in synchronous mode only — no async operations or io_context::run(). -// The io_context is only needed because ASIO sockets require one to be constructed. -TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) -: m_Impl(std::make_unique<Impl>()) -, m_Log(zen::logging::Get("horde.transport")) +void +AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler) { - ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect"); asio::error_code Ec; @@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) { ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); m_HasErrors = true; + asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); }); return; } const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); - m_Impl->Socket.connect(Endpoint, Ec); - if (Ec) - { - ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); - m_HasErrors = true; - return; - } + // Copy the nonce so it survives past this scope into the async callback + auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize); - // Disable Nagle's algorithm for lower latency - m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); -} + m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("async connect failed: {}", Ec.message()); + m_HasErrors = true; + Handler(Ec); + return; + } -TcpComputeTransport::~TcpComputeTransport() -{ - Close(); + asio::error_code SetOptEc; + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc); + + // Send the 64-byte nonce as the first thing on the wire + asio::async_write(m_Impl->Socket, + asio::buffer(*NonceBuf), + [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) { + if (Ec) + { + ZEN_WARN("nonce write failed: {}", Ec.message()); + m_HasErrors = true; + } + Handler(Ec); + }); + }); } bool -TcpComputeTransport::IsValid() const +AsyncTcpComputeTransport::IsValid() const { return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; } -size_t -TcpComputeTransport::Send(const void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - m_HasErrors = true; - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Sent; + asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } -size_t -TcpComputeTransport::Recv(void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Received; -} - -void -TcpComputeTransport::MarkComplete() -{ + asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } void -TcpComputeTransport::Close() +AsyncTcpComputeTransport::Close() { if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) { diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h index 1b178dc0f..b5e841d7a 100644 --- a/src/zenhorde/hordetransport.h +++ b/src/zenhorde/hordetransport.h @@ -8,55 +8,60 @@ #include <cstddef> #include <cstdint> +#include <functional> #include <memory> +#include <system_error> -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif +namespace asio { +class io_context; +} namespace zen::horde { -/** Abstract base interface for compute transports. +/** Handler types for async transport operations. */ +using AsyncConnectHandler = std::function<void(const std::error_code&)>; +using AsyncIoHandler = std::function<void(const std::error_code&, size_t)>; + +/** Abstract base for asynchronous compute transports. * - * Matches the UE FComputeTransport pattern. Concrete implementations handle - * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides - * blocking message helpers on top. + * All callbacks are invoked on the io_context that was provided at construction. + * Callers are responsible for strand serialization if needed. */ -class ComputeTransport +class AsyncComputeTransport { public: - virtual ~ComputeTransport() = default; + virtual ~AsyncComputeTransport() = default; + + virtual bool IsValid() const = 0; - virtual bool IsValid() const = 0; - virtual size_t Send(const void* Data, size_t Size) = 0; - virtual size_t Recv(void* Data, size_t Size) = 0; - virtual void MarkComplete() = 0; - virtual void Close() = 0; + /** Asynchronous write of exactly Size bytes. Handler called on completion or error. */ + virtual void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking send that loops until all bytes are transferred. Returns false on error. */ - bool SendMessage(const void* Data, size_t Size); + /** Asynchronous read of exactly Size bytes into Data. Handler called on completion or error. */ + virtual void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) = 0; - /** Blocking receive that loops until all bytes are transferred. Returns false on error. */ - bool RecvMessage(void* Data, size_t Size); + virtual void Close() = 0; }; -/** TCP socket transport using ASIO. +/** Async TCP transport using ASIO. * - * Connects to the Horde compute endpoint specified by MachineInfo and provides - * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the - * header clean. + * Connects to the Horde compute endpoint and provides async send/receive. + * The socket is created on a caller-provided io_context (shared across agents). */ -class TcpComputeTransport final : public ComputeTransport +class AsyncTcpComputeTransport final : public AsyncComputeTransport { public: - explicit TcpComputeTransport(const MachineInfo& Info); - ~TcpComputeTransport() override; - - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + /** Construct a transport on the given io_context. Does not connect yet. */ + explicit AsyncTcpComputeTransport(asio::io_context& IoContext); + ~AsyncTcpComputeTransport() override; + + /** Asynchronously connect to the endpoint and send the nonce. */ + void AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler); + + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: LoggerRef Log() { return m_Log; } diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..0b94a4397 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -5,9 +5,12 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <algorithm> #include <cstring> -#include <random> #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -22,315 +25,410 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -struct AesComputeTransport::CryptoContext -{ - uint8_t Key[KeySize] = {}; - uint8_t EncryptNonce[NonceBytes] = {}; - uint8_t DecryptNonce[NonceBytes] = {}; - bool HasErrors = false; - -#if !ZEN_PLATFORM_WINDOWS - EVP_CIPHER_CTX* EncCtx = nullptr; - EVP_CIPHER_CTX* DecCtx = nullptr; -#endif - - CryptoContext(const uint8_t (&InKey)[KeySize]) - { - memcpy(Key, InKey, KeySize); - - // The encrypt nonce is randomly initialized and then deterministically mutated - // per message via UpdateNonce(). The decrypt nonce is not used — it comes from - // the wire (each received message carries its own nonce in the header). - std::random_device Rd; - std::mt19937 Gen(Rd()); - std::uniform_int_distribution<int> Dist(0, 255); - for (auto& Byte : EncryptNonce) - { - Byte = static_cast<uint8_t>(Dist(Gen)); - } - -#if !ZEN_PLATFORM_WINDOWS - // Drain any stale OpenSSL errors - while (ERR_get_error() != 0) - { - } - - EncCtx = EVP_CIPHER_CTX_new(); - EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); - - DecCtx = EVP_CIPHER_CTX_new(); - EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); -#endif - } +namespace { - ~CryptoContext() - { -#if ZEN_PLATFORM_WINDOWS - SecureZeroMemory(Key, sizeof(Key)); - SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); - SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); -#else - OPENSSL_cleanse(Key, sizeof(Key)); - OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); - OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); - - if (EncCtx) - { - EVP_CIPHER_CTX_free(EncCtx); - } - if (DecCtx) - { - EVP_CIPHER_CTX_free(DecCtx); - } -#endif - } + static constexpr size_t AesNonceBytes = 12; + static constexpr size_t AesTagBytes = 16; - void UpdateNonce() + /** AES-256-GCM crypto context. Not exposed outside this translation unit. */ + struct AesCryptoContext { - uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); - N32[0]++; - N32[1]--; - N32[2] = N32[0] ^ N32[1]; - } + static constexpr size_t NonceBytes = AesNonceBytes; + static constexpr size_t TagBytes = AesTagBytes; - // Returns total encrypted message size, or 0 on failure - // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] - int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) - { - UpdateNonce(); + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + uint64_t DecryptCounter = 0; ///< Sequence number of the next message to be decrypted (for diagnostics) + bool HasErrors = false; - // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than - // caching but has some overhead. For our use case (relatively large, infrequent messages) - // this is acceptable. #if ZEN_PLATFORM_WINDOWS BCRYPT_ALG_HANDLE hAlg = nullptr; BCRYPT_KEY_HANDLE hKey = nullptr; +#else + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = EncryptNonce; - AuthInfo.cbNonce = NonceBytes; - uint8_t Tag[TagBytes] = {}; - AuthInfo.pbTag = Tag; - AuthInfo.cbTag = TagBytes; - - ULONG CipherLen = 0; - NTSTATUS Status = - BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); - - if (!BCRYPT_SUCCESS(Status)) + AesCryptoContext(const uint8_t (&InKey)[KeySize]) { - HasErrors = true; - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - return 0; - } + memcpy(Key, InKey, KeySize); - // Write header: length + nonce - memcpy(Out, &InLength, 4); - memcpy(Out + 4, EncryptNonce, NonceBytes); - // Write tag after ciphertext - memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + // EncryptNonce is zero-initialized (NIST SP 800-38D §8.2.1 deterministic + // construction): fixed_field = 0, counter starts at 0 and is incremented + // before each encryption by UpdateNonce(). No RNG is used here because + // std::random_device is not guaranteed to be a CSPRNG (historic MinGW, + // some WASI targets), and the deterministic construction does not need + // one as long as each session uses a unique key. - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; +#if ZEN_PLATFORM_WINDOWS + NTSTATUS Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptOpenAlgorithmProvider failed: 0x{:08x}", static_cast<uint32_t>(Status)); + hAlg = nullptr; + HasErrors = true; + return; + } + + Status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptSetProperty(BCRYPT_CHAIN_MODE_GCM) failed: 0x{:08x}", static_cast<uint32_t>(Status)); + HasErrors = true; + return; + } + + Status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptGenerateSymmetricKey failed: 0x{:08x}", static_cast<uint32_t>(Status)); + hKey = nullptr; + HasErrors = true; + return; + } #else - if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) - { - HasErrors = true; - return 0; + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + DecCtx = EVP_CIPHER_CTX_new(); + if (!EncCtx || !DecCtx) + { + ZEN_ERROR("EVP_CIPHER_CTX_new failed"); + HasErrors = true; + return; + } + + if (EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_EncryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error()); + HasErrors = true; + return; + } + + if (EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_DecryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error()); + HasErrors = true; + return; + } +#endif } - int32_t Offset = 0; - // Write length - memcpy(Out + Offset, &InLength, 4); - Offset += 4; - // Write nonce - memcpy(Out + Offset, EncryptNonce, NonceBytes); - Offset += NonceBytes; - - // Encrypt - int OutLen = 0; - if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + ~AesCryptoContext() { - HasErrors = true; - return 0; +#if ZEN_PLATFORM_WINDOWS + if (hKey) + { + BCryptDestroyKey(hKey); + } + if (hAlg) + { + BCryptCloseAlgorithmProvider(hAlg, 0); + } + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif } - Offset += OutLen; - // Finalize - int FinalLen = 0; - if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + void UpdateNonce() { + // NIST SP 800-38D §8.2.1 deterministic construction: + // nonce = [fixed_field (4 bytes) || invocation_counter (8 bytes, big-endian)] + // The low 8 bytes are a strict monotonic counter starting at zero. On 2^64 + // exhaustion the session is torn down (HasErrors) - never wrap, since a repeated + // (key, nonce) pair catastrophically breaks AES-GCM confidentiality and integrity. + for (int i = 11; i >= 4; --i) + { + if (++EncryptNonce[i] != 0) + { + return; + } + } HasErrors = true; - return 0; } - Offset += FinalLen; - // Get tag - if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) { - HasErrors = true; - return 0; - } - Offset += TagBytes; - - return Offset; -#endif - } + UpdateNonce(); + if (HasErrors) + { + return 0; + } - // Decrypt a message. Returns decrypted data length, or 0 on failure. - // Input must be [ciphertext][tag], with nonce provided separately. - int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) - { #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); - AuthInfo.cbNonce = NonceBytes; - AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); - AuthInfo.cbTag = TagBytes; - - ULONG PlainLen = 0; - NTSTATUS Status = BCryptDecrypt(hKey, - (PUCHAR)CipherAndTag, - (ULONG)DataLength, - &AuthInfo, - nullptr, - 0, - (PUCHAR)Out, - (ULONG)DataLength, - &PlainLen, - 0); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; - return 0; - } - - return static_cast<int32_t>(PlainLen); + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + // Tag is output-only on encrypt; BCryptEncrypt writes TagBytes bytes into it, so skip zero-init. + uint8_t Tag[TagBytes]; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + const NTSTATUS Status = BCryptEncrypt(hKey, + (PUCHAR)In, + (ULONG)InLength, + &AuthInfo, + nullptr, + 0, + Out + 4 + NonceBytes, + (ULONG)InLength, + &CipherLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptEncrypt failed: 0x{:08x}", static_cast<uint32_t>(Status)); + HasErrors = true; + return 0; + } + + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; #else - if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) - { - HasErrors = true; - return 0; - } - - int OutLen = 0; - if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) - { - HasErrors = true; - return 0; + // Reset per message so any stale state from a previous encrypt (e.g. partial + // completion after a prior error) cannot bleed into this operation. Re-bind + // the cipher/key; the IV is then set via the normal init call below. + if (EVP_CIPHER_CTX_reset(EncCtx) != 1 || EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_CIPHER_CTX_reset/EncryptInit failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + ZEN_ERROR("EVP_EncryptInit_ex(key+iv) failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + ZEN_ERROR("EVP_EncryptUpdate failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += OutLen; + + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + ZEN_ERROR("EVP_EncryptFinal_ex failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += FinalLen; + + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + ZEN_ERROR("EVP_CTRL_GCM_GET_TAG failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif } - // Set the tag for verification - if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { - HasErrors = true; - return 0; +#if ZEN_PLATFORM_WINDOWS + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + const NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + // STATUS_AUTH_TAG_MISMATCH (0xC000A002) indicates GCM integrity failure - + // either in-flight corruption or active tampering. Log distinctly from + // other BCryptDecrypt failures so that tamper attempts are auditable. + static constexpr NTSTATUS STATUS_AUTH_TAG_MISMATCH_VAL = static_cast<NTSTATUS>(0xC000A002L); + if (Status == STATUS_AUTH_TAG_MISMATCH_VAL) + { + ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter); + } + else + { + ZEN_ERROR("BCryptDecrypt failed: 0x{:08x} (seq={})", static_cast<uint32_t>(Status), DecryptCounter); + } + HasErrors = true; + return 0; + } + + ++DecryptCounter; + return static_cast<int32_t>(PlainLen); +#else + // Same rationale as EncryptMessage: reset the context and re-bind the cipher + // before each decrypt to avoid stale state from a previous operation. + if (EVP_CIPHER_CTX_reset(DecCtx) != 1 || EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_CIPHER_CTX_reset/DecryptInit failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + ZEN_ERROR("EVP_DecryptInit_ex (seq={}) failed: {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + ZEN_ERROR("EVP_DecryptUpdate failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + ZEN_ERROR("EVP_CTRL_GCM_SET_TAG failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + // EVP_DecryptFinal_ex returns 0 specifically on GCM tag verification failure + // once the tag has been set. Log distinctly so tamper attempts are auditable. + ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter); + HasErrors = true; + return 0; + } + + ++DecryptCounter; + return OutLen + FinalLen; +#endif } + }; - int FinalLen = 0; - if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) - { - HasErrors = true; - return 0; - } +} // anonymous namespace - return OutLen + FinalLen; -#endif - } +struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext +{ + using AesCryptoContext::AesCryptoContext; }; -AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +// --- AsyncAesComputeTransport --- + +AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext) : m_Crypto(std::make_unique<CryptoContext>(Key)) , m_Inner(std::move(InnerTransport)) +, m_IoContext(IoContext) { } -AesComputeTransport::~AesComputeTransport() +AsyncAesComputeTransport::~AsyncAesComputeTransport() { Close(); } bool -AesComputeTransport::IsValid() const +AsyncAesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } -size_t -AesComputeTransport::Send(const void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { - ZEN_TRACE_CPU("AesComputeTransport::Send"); - if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - std::lock_guard<std::mutex> Lock(m_Lock); - const int32_t DataLength = static_cast<int32_t>(Size); - const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes; - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } + // Encrypt directly into the per-write buffer rather than a long-lived member. Using a + // member (plaintext + ciphertext share that buffer during encryption on the OpenSSL + // path) would leave plaintext on the heap indefinitely and would also make the + // transport unsafe if AsyncWrite were ever invoked concurrently. Size the shared_ptr + // exactly to EncryptedLen afterwards. + auto EncBuf = std::make_shared<std::vector<uint8_t>>(MessageLength); - const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + const int32_t EncryptedLen = m_Crypto->EncryptMessage(EncBuf->data(), Data, DataLength); if (EncryptedLen == 0) { - return 0; + asio::post(m_IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); }); + return; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) - { - return 0; - } + EncBuf->resize(static_cast<size_t>(EncryptedLen)); - return Size; + m_Inner->AsyncWrite( + EncBuf->data(), + EncBuf->size(), + [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); }); } -size_t -AesComputeTransport::Recv(void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes - // than the decrypted message contains. Excess bytes are buffered in m_RemainingData - // and returned on subsequent Recv calls without another decryption round-trip. - ZEN_TRACE_CPU("AesComputeTransport::Recv"); - - std::lock_guard<std::mutex> Lock(m_Lock); + uint8_t* Dest = static_cast<uint8_t*>(Data); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); - memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) @@ -339,82 +437,104 @@ AesComputeTransport::Recv(void* Data, size_t Size) m_RemainingOffset = 0; } - return ToCopy; - } - - // Receive packet header: [length(4B)][nonce(12B)] - struct PacketHeader - { - int32_t DataLength = 0; - uint8_t Nonce[NonceBytes] = {}; - } Header; - - if (!m_Inner->RecvMessage(&Header, sizeof(Header))) - { - return 0; - } - - // Validate DataLength to prevent OOM from malicious/corrupt peers - static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB - - if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) - { - ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); - return 0; - } - - // Receive ciphertext + tag - const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; - - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } - - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) - { - return 0; - } - - // Decrypt - const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); - - // We need a temporary buffer for decryption if we can't decrypt directly into output - std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); - - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); - if (Decrypted == 0) - { - return 0; - } - - memcpy(Data, DecryptedBuf.data(), BytesToReturn); + if (ToCopy == Size) + { + asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); }); + return; + } - // Store remaining data if we couldn't return everything - if (static_cast<size_t>(Header.DataLength) > BytesToReturn) - { - m_RemainingOffset = 0; - m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler)); + return; } - return BytesToReturn; + DoRecvMessage(Dest, Size, std::move(Handler)); } void -AesComputeTransport::MarkComplete() +AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler) { - if (IsValid()) - { - m_Inner->MarkComplete(); - } + static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes; + auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>(); + + m_Inner->AsyncRead(HeaderBuf->data(), + HeaderSize, + [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + int32_t DataLength = 0; + memcpy(&DataLength, HeaderBuf->data(), 4); + + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; + if (DataLength <= 0 || DataLength > MaxDataLength) + { + Handler(asio::error::make_error_code(asio::error::invalid_argument), 0); + return; + } + + const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes; + if (m_DecryptBuffer.size() < MessageLength) + { + m_DecryptBuffer.resize(MessageLength); + } + + auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>(); + memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes); + + m_Inner->AsyncRead( + m_DecryptBuffer.data(), + MessageLength, + [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec, + size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength)); + const int32_t Decrypted = + m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength); + if (Decrypted == 0) + { + Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); + return; + } + + const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size); + memcpy(Dest, PlaintextBuf.data(), BytesToReturn); + + if (static_cast<size_t>(Decrypted) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted); + } + + if (BytesToReturn < Size) + { + DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler)); + } + else + { + Handler(std::error_code{}, Size); + } + }); + }); } void -AesComputeTransport::Close() +AsyncAesComputeTransport::Close() { if (!m_IsClosed) { - if (m_Inner && m_Inner->IsValid()) + // Always forward Close() to the inner transport if we have one. Gating on + // IsValid() skipped cleanup when the inner transport was partially torn down + // (e.g. after a read/write error marked it non-valid but left its socket open), + // leaking OS handles. Close implementations are expected to be idempotent. + if (m_Inner) { m_Inner->Close(); } diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h index efcad9835..7846073dc 100644 --- a/src/zenhorde/hordetransportaes.h +++ b/src/zenhorde/hordetransportaes.h @@ -6,47 +6,54 @@ #include <cstdint> #include <memory> -#include <mutex> #include <vector> +namespace asio { +class io_context; +} + namespace zen::horde { -/** AES-256-GCM encrypted transport wrapper. +/** Async AES-256-GCM encrypted transport wrapper. * - * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting - * all incoming data using AES-256-GCM. The nonce is mutated per message using - * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * Wraps an AsyncComputeTransport, encrypting outgoing and decrypting incoming + * data using AES-256-GCM. Outgoing nonces follow the NIST SP 800-38D §8.2.1 + * deterministic construction: a 4-byte fixed field followed by an 8-byte + * big-endian monotonic counter. The session is torn down if the counter + * would wrap. * * Wire format per encrypted message: * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] * * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). + * + * Thread safety: all operations must be serialized by the caller (e.g. via a strand). */ -class AesComputeTransport final : public ComputeTransport +class AsyncAesComputeTransport final : public AsyncComputeTransport { public: - AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport); - ~AesComputeTransport() override; + AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext); + ~AsyncAesComputeTransport() override; - bool IsValid() const override; - size_t Send(const void* Data, size_t Size) override; - size_t Recv(void* Data, size_t Size) override; - void MarkComplete() override; - void Close() override; + bool IsValid() const override; + void AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) override; + void AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) override; + void Close() override; private: - static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size - static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + void DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler); struct CryptoContext; - std::unique_ptr<CryptoContext> m_Crypto; - std::unique_ptr<ComputeTransport> m_Inner; - std::vector<uint8_t> m_EncryptBuffer; - std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv - size_t m_RemainingOffset = 0; - std::mutex m_Lock; - bool m_IsClosed = false; + std::unique_ptr<CryptoContext> m_Crypto; + std::unique_ptr<AsyncComputeTransport> m_Inner; + asio::io_context& m_IoContext; + std::vector<uint8_t> m_DecryptBuffer; + std::vector<uint8_t> m_RemainingData; + size_t m_RemainingOffset = 0; + bool m_IsClosed = false; }; } // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h index 201d68b83..87caec019 100644 --- a/src/zenhorde/include/zenhorde/hordeclient.h +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -45,14 +45,15 @@ struct MachineInfo uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) bool IsWindows = false; std::string LeaseId; + std::string Pool; std::map<std::string, PortInfo> Ports; /** Return the address to connect to, accounting for connection mode. */ - const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + [[nodiscard]] const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } /** Return the port to connect to, accounting for connection mode and port mapping. */ - uint16_t GetConnectionPort() const + [[nodiscard]] uint16_t GetConnectionPort() const { if (Mode == ConnectionMode::Relay) { @@ -65,7 +66,20 @@ struct MachineInfo return Port; } - bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } + /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */ + [[nodiscard]] std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const + { + if (Mode == ConnectionMode::Relay) + { + if (auto It = Ports.find("ZenPort"); It != Ports.end()) + { + return {ConnectionAddress, It->second.Port}; + } + } + return {Ip, DefaultPort}; + } + + [[nodiscard]] bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } }; /** Result of cluster auto-resolution via the Horde API. */ @@ -83,31 +97,29 @@ struct ClusterInfo class HordeClient { public: - explicit HordeClient(const HordeConfig& Config); + explicit HordeClient(HordeConfig Config); ~HordeClient(); HordeClient(const HordeClient&) = delete; HordeClient& operator=(const HordeClient&) = delete; /** Initialize the underlying HTTP client. Must be called before other methods. */ - bool Initialize(); + [[nodiscard]] bool Initialize(); /** Build the JSON request body for cluster resolution and machine requests. * Encodes pool, condition, connection mode, encryption, and port requirements. */ - std::string BuildRequestBody() const; + [[nodiscard]] std::string BuildRequestBody() const; /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ - bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + [[nodiscard]] bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. * On success, populates OutMachine with connection details and credentials. */ - bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + [[nodiscard]] bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); LoggerRef Log() { return m_Log; } private: - bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); - HordeConfig m_Config; std::unique_ptr<zen::HttpClient> m_Http; LoggerRef m_Log; diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h index dd70f9832..3a4dfb386 100644 --- a/src/zenhorde/include/zenhorde/hordeconfig.h +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -4,6 +4,10 @@ #include <zenhorde/zenhorde.h> +#include <zenhttp/httpclient.h> + +#include <functional> +#include <optional> #include <string> namespace zen::horde { @@ -33,20 +37,25 @@ struct HordeConfig static constexpr const char* ClusterDefault = "default"; static constexpr const char* ClusterAuto = "_auto"; - bool Enabled = false; ///< Whether Horde provisioning is active - std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") - std::string AuthToken; ///< Authentication token for the Horde API - std::string Pool; ///< Pool name to request machines from - std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve - std::string Condition; ///< Agent filter expression for machine selection - std::string HostAddress; ///< Address that provisioned agents use to connect back to us - std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload - uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication - - int MaxCores = 2048; - bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents - ConnectionMode Mode = ConnectionMode::Direct; - Encryption EncryptionMode = Encryption::None; + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API (static fallback) + + /// Optional token provider with automatic refresh (e.g. from OidcToken executable). + /// When set, takes priority over the static AuthToken string. + std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + int DrainGracePeriodSeconds = 300; ///< Grace period for draining agents before force-kill (default 5 min) + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; /** Validate the configuration. Returns false if the configuration is invalid * (e.g. Relay mode without AES encryption). */ diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h index 4e2e63bbd..ea2fd7783 100644 --- a/src/zenhorde/include/zenhorde/hordeprovisioner.h +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -2,21 +2,32 @@ #pragma once +#include <zenhorde/hordeclient.h> #include <zenhorde/hordeconfig.h> +#include <zencompute/provisionerstate.h> #include <zencore/logbase.h> +#include <zencore/thread.h> #include <atomic> #include <cstdint> +#include <deque> #include <filesystem> #include <memory> #include <mutex> #include <string> +#include <thread> +#include <unordered_set> #include <vector> +namespace asio { +class io_context; +} + namespace zen::horde { class HordeClient; +class AsyncHordeAgent; /** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */ struct ProvisioningStats @@ -35,13 +46,12 @@ struct ProvisioningStats * binary, and executing it remotely. Each provisioned machine runs zenserver * in compute mode, which announces itself back to the orchestrator. * - * Spawns one thread per agent. Each thread handles the full lifecycle: - * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> - * channel setup -> binary upload -> remote execution -> poll until exit. + * Agent work (HTTP request, connect, upload, poll) is dispatched to a thread + * pool rather than spawning a dedicated thread per agent. * * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. */ -class HordeProvisioner +class HordeProvisioner : public compute::IProvisionerStateProvider { public: /** Construct a provisioner. @@ -52,38 +62,48 @@ public: HordeProvisioner(const HordeConfig& Config, const std::filesystem::path& BinariesPath, const std::filesystem::path& WorkingDir, - std::string_view OrchestratorEndpoint); + std::string_view OrchestratorEndpoint, + std::string_view CoordinatorSession = {}, + bool CleanStart = false, + std::string_view TraceHost = {}); - /** Signals all agent threads to exit and joins them. */ - ~HordeProvisioner(); + /** Signals all agents to exit and waits for completion. */ + ~HordeProvisioner() override; HordeProvisioner(const HordeProvisioner&) = delete; HordeProvisioner& operator=(const HordeProvisioner&) = delete; /** Set the target number of cores to provision. - * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the - * estimated core count is below the target. Also joins any finished - * agent threads. */ - void SetTargetCoreCount(uint32_t Count); + * Clamped to HordeConfig::MaxCores. Dispatches new agent work if the + * estimated core count is below the target. Also removes finished agents. */ + void SetTargetCoreCount(uint32_t Count) override; /** Return a snapshot of the current provisioning counters. */ ProvisioningStats GetStats() const; - uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } - uint32_t GetAgentCount() const; + // IProvisionerStateProvider + std::string_view GetName() const override { return "horde"; } + uint32_t GetTargetCoreCount() const override { return m_TargetCoreCount.load(); } + uint32_t GetEstimatedCoreCount() const override { return m_EstimatedCoreCount.load(); } + uint32_t GetActiveCoreCount() const override { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const override; + uint32_t GetDrainingAgentCount() const override { return m_AgentsDraining.load(); } + compute::AgentProvisioningStatus GetAgentStatus(std::string_view WorkerId) const override; private: LoggerRef Log() { return m_Log; } - struct AgentWrapper; - void RequestAgent(); - void ThreadAgent(AgentWrapper& Wrapper); + void ProvisionAgent(); + bool InitializeHordeClient(); HordeConfig m_Config; std::filesystem::path m_BinariesPath; std::filesystem::path m_WorkingDir; std::string m_OrchestratorEndpoint; + std::string m_CoordinatorSession; + bool m_CleanStart = false; + std::string m_TraceHost; std::unique_ptr<HordeClient> m_HordeClient; @@ -91,20 +111,54 @@ private: std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs bool m_BundlesCreated = false; - mutable std::mutex m_AgentsLock; - std::vector<std::unique_ptr<AgentWrapper>> m_Agents; - std::atomic<uint64_t> m_LastRequestFailTime{0}; std::atomic<uint32_t> m_TargetCoreCount{0}; std::atomic<uint32_t> m_EstimatedCoreCount{0}; std::atomic<uint32_t> m_ActiveCoreCount{0}; std::atomic<uint32_t> m_AgentsActive{0}; + std::atomic<uint32_t> m_AgentsDraining{0}; std::atomic<uint32_t> m_AgentsRequesting{0}; std::atomic<bool> m_AskForAgents{true}; + std::atomic<uint32_t> m_PendingWorkItems{0}; + Event m_AllWorkDone; + /** Manual-reset event set alongside m_AskForAgents=false so pool-thread backoff waits + * wake immediately on shutdown instead of polling a 100ms sleep. */ + Event m_ShutdownEvent; LoggerRef m_Log; + // Async I/O + std::unique_ptr<asio::io_context> m_IoContext; + std::vector<std::thread> m_IoThreads; + + struct AsyncAgentEntry + { + std::shared_ptr<AsyncHordeAgent> Agent; + std::string RemoteEndpoint; + std::string LeaseId; + uint16_t CoreCount = 0; + bool Draining = false; + }; + + mutable std::mutex m_AsyncAgentsLock; + std::vector<AsyncAgentEntry> m_AsyncAgents; + + /** Worker IDs of agents that completed after draining. + * GetAgentStatus() consumes entries when queried, but if no one queries, entries would + * otherwise accumulate unbounded across the lifetime of the provisioner. Cap the set + * at RecentlyDrainedCapacity by evicting the oldest entry (tracked in an insertion-order + * queue) whenever we insert past the limit. */ + mutable std::unordered_set<std::string> m_RecentlyDrainedWorkerIds; + mutable std::deque<std::string> m_RecentlyDrainedOrder; + static constexpr size_t RecentlyDrainedCapacity = 256; + + void OnAsyncAgentDone(std::shared_ptr<AsyncHordeAgent> Agent); + void DrainAsyncAgent(AsyncAgentEntry& Entry); + + std::vector<std::string> BuildAgentArgs(const MachineInfo& Machine) const; + static constexpr uint32_t EstimatedCoresPerAgent = 32; + static constexpr int IoThreadCount = 3; }; } // namespace zen::horde diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua index 48d028e86..0e69e9c5f 100644 --- a/src/zenhorde/xmake.lua +++ b/src/zenhorde/xmake.lua @@ -14,7 +14,7 @@ target('zenhorde') end if is_plat("linux") or is_plat("macosx") then - add_packages("openssl") + add_packages("openssl3") end if is_os("macosx") then |