diff options
| author | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
|---|---|---|
| committer | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
| commit | d1abc50ee9d4fb72efc646e17decafea741caa34 (patch) | |
| tree | e4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zenhorde | |
| parent | Allow requests with invalid content-types unless specified in command line or... (diff) | |
| parent | updated chunk–block analyser (#818) (diff) | |
| download | zen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip | |
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zenhorde')
24 files changed, 4359 insertions, 0 deletions
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp new file mode 100644 index 000000000..819b2d0cb --- /dev/null +++ b/src/zenhorde/hordeagent.cpp @@ -0,0 +1,297 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordeagent.h" +#include "hordetransportaes.h" + +#include <zencore/basicfile.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/trace.h> + +#include <cstring> +#include <unordered_map> + +namespace zen::horde { + +HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) +{ + ZEN_TRACE_CPU("HordeAgent::Connect"); + + auto Transport = std::make_unique<TcpComputeTransport>(Info); + if (!Transport->IsValid()) + { + ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); + return; + } + + // The 64-byte nonce is always sent unencrypted as the first thing on the wire. + // The Horde agent uses this to identify which lease this connection belongs to. + Transport->Send(Info.Nonce, sizeof(Info.Nonce)); + + std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); + if (Info.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); + if (!FinalTransport->IsValid()) + { + ZEN_WARN("failed to create AES transport"); + return; + } + } + + // Create multiplexed socket and channels + m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); + + // Channel 0 is the agent control channel (handles Attach/Fork handshake). + // Channel 100 is the child I/O channel (handles file upload and remote execution). + Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); + Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + + if (!AgentComputeChannel || !ChildComputeChannel) + { + ZEN_WARN("failed to create compute channels"); + return; + } + + m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); + m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + + m_IsValid = true; +} + +HordeAgent::~HordeAgent() +{ + CloseConnection(); +} + +bool +HordeAgent::BeginCommunication() +{ + ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); + + if (!m_IsValid) + { + return false; + } + + // Start the send/recv pump threads + m_Socket->StartCommunication(); + + // Wait for Attach on agent channel + AgentMessageType Type = m_AgentChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on agent channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); + return false; + } + + // Fork tells the remote agent to create child channel 100 with a 4MB buffer. + // After this, the agent will send an Attach on the child channel. + m_AgentChannel->Fork(100, 4 * 1024 * 1024); + + // Wait for Attach on child channel + Type = m_ChildChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on child channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); + return false; + } + + return true; +} + +bool +HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +{ + ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); + + m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + + std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + + auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { + std::string Key(Locator); + + if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) + { + return It->second.get(); + } + + const std::filesystem::path Path = BundleDir / (Key + ".blob"); + std::error_code Ec; + auto File = std::make_unique<BasicFile>(); + File->Open(Path, BasicFile::Mode::kRead, Ec); + + if (Ec) + { + ZEN_ERROR("cannot read blob file: '{}'", Path); + return nullptr; + } + + BasicFile* Ptr = File.get(); + BlobFiles.emplace(std::move(Key), std::move(File)); + return Ptr; + }; + + // The upload protocol is request-driven: we send WriteFiles, then the remote agent + // sends ReadBlob requests for each blob it needs. We respond with Blob data until + // the agent sends WriteFilesResponse indicating the upload is complete. + constexpr int32_t ReadResponseTimeoutMs = 1000; + + for (;;) + { + bool TimedOut = false; + + if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) + { + if (TimedOut) + { + continue; + } + // End of stream - check if it was a successful upload + if (Type == AgentMessageType::WriteFilesResponse) + { + return true; + } + else if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + } + else + { + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + } + return false; + } + + BlobRequest Req; + m_ChildChannel->ReadBlobRequest(Req); + + BasicFile* File = FindOrOpenBlob(Req.Locator); + if (!File) + { + return false; + } + + // Read from offset to end of file + const uint64_t TotalSize = File->FileSize(); + const uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + continue; + } + + const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); + } +} + +void +HordeAgent::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + bool UseWine) +{ + ZEN_TRACE_CPU("HordeAgent::Execute"); + m_ChildChannel + ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); +} + +bool +HordeAgent::Poll(bool LogOutput) +{ + constexpr int32_t ReadResponseTimeoutMs = 100; + AgentMessageType Type; + + while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + { + switch (Type) + { + case AgentMessageType::ExecuteOutput: + { + if (LogOutput && m_ChildChannel->GetResponseSize() > 0) + { + const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); + size_t ResponseSize = m_ChildChannel->GetResponseSize(); + + // Trim trailing newlines + while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) + { + --ResponseSize; + } + + if (ResponseSize > 0) + { + const std::string_view Output(ResponseData, ResponseSize); + ZEN_INFO("[remote] {}", Output); + } + } + break; + } + + case AgentMessageType::ExecuteResult: + { + if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) + { + int32_t ExitCode; + memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); + ZEN_INFO("remote process exited with code {}", ExitCode); + } + m_IsValid = false; + return false; + } + + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + m_HasErrors = true; + break; + } + + default: + break; + } + } + + return m_IsValid && !m_HasErrors; +} + +void +HordeAgent::CloseConnection() +{ + if (m_ChildChannel) + { + m_ChildChannel->Close(); + } + if (m_AgentChannel) + { + m_AgentChannel->Close(); + } +} + +bool +HordeAgent::IsValid() const +{ + return m_IsValid && !m_HasErrors; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h new file mode 100644 index 000000000..e0ae89ead --- /dev/null +++ b/src/zenhorde/hordeagent.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordeagentmessage.h" +#include "hordecomputesocket.h" + +#include <zenhorde/hordeclient.h> + +#include <zencore/logbase.h> + +#include <filesystem> +#include <memory> +#include <string> + +namespace zen::horde { + +/** Manages the lifecycle of a single Horde compute agent. + * + * Handles the full connection sequence for one provisioned machine: + * 1. Connect via TCP transport (with optional AES encryption wrapping) + * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) + * 3. Perform the Attach/Fork handshake to establish the child channel + * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol + * 5. Execute zenserver remotely via ExecuteV2 + * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + */ +class HordeAgent +{ +public: + explicit HordeAgent(const MachineInfo& Info); + ~HordeAgent(); + + HordeAgent(const HordeAgent&) = delete; + HordeAgent& operator=(const HordeAgent&) = delete; + + /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). + * Returns false if the handshake times out or receives an unexpected message. */ + bool BeginCommunication(); + + /** Upload binary files to the remote agent. + * @param BundleDir Directory containing .blob files. + * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ + bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + + /** Execute a command on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir = nullptr, + const char* const* EnvVars = nullptr, + size_t NumEnvVars = 0, + bool UseWine = false); + + /** Poll for output and results. Returns true if the agent is still running. + * When LogOutput is true, remote stdout is logged via ZEN_INFO. */ + bool Poll(bool LogOutput = true); + + void CloseConnection(); + bool IsValid() const; + + const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + +private: + LoggerRef Log() { return m_Log; } + + std::unique_ptr<ComputeSocket> m_Socket; + std::unique_ptr<AgentMessageChannel> m_AgentChannel; ///< Channel 0: agent control + std::unique_ptr<AgentMessageChannel> m_ChildChannel; ///< Channel 100: child I/O + + LoggerRef m_Log; + bool m_IsValid = false; + bool m_HasErrors = false; + MachineInfo m_MachineInfo; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp new file mode 100644 index 000000000..998134a96 --- /dev/null +++ b/src/zenhorde/hordeagentmessage.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordeagentmessage.h" + +#include <zencore/intmath.h> + +#include <cassert> +#include <cstring> + +namespace zen::horde { + +AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) +{ +} + +AgentMessageChannel::~AgentMessageChannel() = default; + +void +AgentMessageChannel::Close() +{ + CreateMessage(AgentMessageType::None, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Ping() +{ + CreateMessage(AgentMessageType::Ping, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Fork(int ChannelId, int BufferSize) +{ + CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(ChannelId); + WriteInt32(BufferSize); + FlushMessage(); +} + +void +AgentMessageChannel::Attach() +{ + CreateMessage(AgentMessageType::Attach, 0); + FlushMessage(); +} + +void +AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +{ + CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Path); + WriteString(Locator); + FlushMessage(); +} + +void +AgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) +{ + size_t RequiredSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) + { + RequiredSize += strlen(Args[i]) + 10; + } + if (WorkingDir) + { + RequiredSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + RequiredSize += strlen(EnvVars[i]) + 20; + } + + CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); + WriteString(Exe); + + WriteUnsignedVarInt(NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Args[i]); + } + + WriteOptionalString(WorkingDir); + + // ExecuteV2 protocol requires env vars as separate key/value pairs. + // Callers pass "KEY=VALUE" strings; we split on the first '=' here. + WriteUnsignedVarInt(NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(nullptr); + } + else + { + WriteOptionalString(Eq + 1); + } + } + + WriteInt32(static_cast<int>(Flags)); + FlushMessage(); +} + +void +AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +{ + // Blob responses are chunked to fit within the compute buffer's chunk size. + // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). + const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; + for (size_t ChunkOffset = 0; ChunkOffset < Length;) + { + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); + + CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(static_cast<int>(ChunkOffset)); + WriteInt32(static_cast<int>(Length)); + WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); + FlushMessage(); + + ChunkOffset += ChunkLength; + } +} + +AgentMessageType +AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +{ + // Deferred advance: the previous response's buffer is only released when the next + // ReadResponse is called. This allows callers to read response data between calls + // without copying, since the pointer comes directly from the ring buffer. + if (m_ResponseData) + { + m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); + m_ResponseData = nullptr; + m_ResponseLength = 0; + } + + const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + uint32_t Length; + memcpy(&Length, Header + 1, sizeof(uint32_t)); + + Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + m_ResponseType = static_cast<AgentMessageType>(Header[0]); + m_ResponseData = Header + MessageHeaderLength; + m_ResponseLength = Length; + + return m_ResponseType; +} + +void +AgentMessageChannel::ReadException(ExceptionInfo& Ex) +{ + assert(m_ResponseType == AgentMessageType::Exception); + const uint8_t* Pos = m_ResponseData; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); +} + +int +AgentMessageChannel::ReadExecuteResult() +{ + assert(m_ResponseType == AgentMessageType::ExecuteResult); + const uint8_t* Pos = m_ResponseData; + return ReadInt32(&Pos); +} + +void +AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +{ + assert(m_ResponseType == AgentMessageType::ReadBlob); + const uint8_t* Pos = m_ResponseData; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); +} + +void +AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +{ + m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); + m_RequestData[0] = static_cast<uint8_t>(Type); + m_MaxRequestSize = MaxLength; + m_RequestSize = 0; +} + +void +AgentMessageChannel::FlushMessage() +{ + const uint32_t Size = static_cast<uint32_t>(m_RequestSize); + memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); + m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); + m_RequestSize = 0; + m_MaxRequestSize = 0; + m_RequestData = nullptr; +} + +void +AgentMessageChannel::WriteInt32(int Value) +{ + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); +} + +int +AgentMessageChannel::ReadInt32(const uint8_t** Pos) +{ + int Value; + memcpy(&Value, *Pos, sizeof(int)); + *Pos += sizeof(int); + return Value; +} + +void +AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +{ + assert(m_RequestSize + Length <= m_MaxRequestSize); + memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); + m_RequestSize += Length; +} + +const uint8_t* +AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +{ + const uint8_t* Data = *Pos; + *Pos += Length; + return Data; +} + +size_t +AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +{ + if (Value == 0) + { + return 1; + } + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; +} + +void +AgentMessageChannel::WriteUnsignedVarInt(size_t Value) +{ + const size_t ByteCount = MeasureUnsignedVarInt(Value); + assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + + uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; + } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); + + m_RequestSize += ByteCount; +} + +size_t +AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +{ + const uint8_t* Data = *Pos; + const uint8_t FirstByte = Data[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) + { + Value <<= 8; + Value |= Data[i]; + } + + *Pos += NumBytes; + return Value; +} + +size_t +AgentMessageChannel::MeasureString(const char* Text) const +{ + const size_t Length = strlen(Text); + return MeasureUnsignedVarInt(Length) + Length; +} + +void +AgentMessageChannel::WriteString(const char* Text) +{ + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Length); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); +} + +void +AgentMessageChannel::WriteString(std::string_view Text) +{ + WriteUnsignedVarInt(Text.size()); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); +} + +std::string_view +AgentMessageChannel::ReadString(const uint8_t** Pos) +{ + const size_t Length = ReadUnsignedVarInt(Pos); + const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); + return std::string_view(Start, Length); +} + +void +AgentMessageChannel::WriteOptionalString(const char* Text) +{ + // Optional strings use length+1 encoding: 0 means null/absent, + // N>0 means a string of length N-1 follows. This matches the UE + // FAgentMessageChannel serialization convention. + if (!Text) + { + WriteUnsignedVarInt(0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Length + 1); + WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h new file mode 100644 index 000000000..38c4375fd --- /dev/null +++ b/src/zenhorde/hordeagentmessage.h @@ -0,0 +1,161 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#include "hordecomputechannel.h" + +#include <cstddef> +#include <cstdint> +#include <string> +#include <string_view> +#include <vector> + +namespace zen::horde { + +/** Agent message types matching the UE EAgentMessageType byte values. + * These are the message opcodes exchanged over the agent/child channels. */ +enum class AgentMessageType : uint8_t +{ + None = 0x00, + Ping = 0x01, + Exception = 0x02, + Fork = 0x03, + Attach = 0x04, + WriteFiles = 0x10, + WriteFilesResponse = 0x11, + DeleteFiles = 0x12, + ExecuteV2 = 0x22, + ExecuteOutput = 0x17, + ExecuteResult = 0x18, + ReadBlob = 0x20, + ReadBlobResponse = 0x21, +}; + +/** Flags for the ExecuteV2 message. */ +enum class ExecuteProcessFlags : uint8_t +{ + None = 0, + UseWine = 1, ///< Run the executable under Wine on Linux agents +}; + +/** Parsed exception information from an Exception message. */ +struct ExceptionInfo +{ + std::string_view Message; + std::string_view Description; +}; + +/** Parsed blob read request from a ReadBlob message. */ +struct BlobRequest +{ + std::string_view Locator; + size_t Offset = 0; + size_t Length = 0; +}; + +/** Channel for sending and receiving agent messages over a ComputeChannel. + * + * Implements the Horde agent message protocol, matching the UE + * FAgentMessageChannel serialization format exactly. Messages are framed as + * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8; + * integers use variable-length encoding. + * + * The protocol has two directions: + * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob + * - Responses (remote -> initiator): ReadResponse returns the type, then call the + * appropriate Read* method to parse the payload. + */ +class AgentMessageChannel +{ +public: + explicit AgentMessageChannel(Ref<ComputeChannel> Channel); + ~AgentMessageChannel(); + + AgentMessageChannel(const AgentMessageChannel&) = delete; + AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + + // --- Requests (Initiator -> Remote) --- + + /** Close the channel. */ + void Close(); + + /** Send a keepalive ping. */ + void Ping(); + + /** Fork communication to a new channel with the given ID and buffer size. */ + void Fork(int ChannelId, int BufferSize); + + /** Send an attach request (used during channel setup handshake). */ + void Attach(); + + /** Request the remote agent to write files from the given bundle locator. */ + void UploadFiles(const char* Path, const char* Locator); + + /** Execute a process on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags = ExecuteProcessFlags::None); + + /** Send blob data in response to a ReadBlob request. */ + void Blob(const uint8_t* Data, size_t Length); + + // --- Responses (Remote -> Initiator) --- + + /** Read the next response message. Returns the message type, or None on timeout. + * After this returns, use GetResponseData()/GetResponseSize() or the typed + * Read* methods to access the payload. */ + AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + + const void* GetResponseData() const { return m_ResponseData; } + size_t GetResponseSize() const { return m_ResponseLength; } + + /** Parse an Exception response payload. */ + void ReadException(ExceptionInfo& Ex); + + /** Parse an ExecuteResult response payload. Returns the exit code. */ + int ReadExecuteResult(); + + /** Parse a ReadBlob response payload into a BlobRequest. */ + void ReadBlobRequest(BlobRequest& Req); + +private: + static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + + Ref<ComputeChannel> m_Channel; + + uint8_t* m_RequestData = nullptr; + size_t m_RequestSize = 0; + size_t m_MaxRequestSize = 0; + + AgentMessageType m_ResponseType = AgentMessageType::None; + const uint8_t* m_ResponseData = nullptr; + size_t m_ResponseLength = 0; + + void CreateMessage(AgentMessageType Type, size_t MaxLength); + void FlushMessage(); + + void WriteInt32(int Value); + static int ReadInt32(const uint8_t** Pos); + + void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); + static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); + + static size_t MeasureUnsignedVarInt(size_t Value); + void WriteUnsignedVarInt(size_t Value); + static size_t ReadUnsignedVarInt(const uint8_t** Pos); + + size_t MeasureString(const char* Text) const; + void WriteString(const char* Text); + void WriteString(std::string_view Text); + static std::string_view ReadString(const uint8_t** Pos); + + void WriteOptionalString(const char* Text); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp new file mode 100644 index 000000000..d3974bc28 --- /dev/null +++ b/src/zenhorde/hordebundle.cpp @@ -0,0 +1,619 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordebundle.h" + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/intmath.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <chrono> +#include <cstring> + +namespace zen::horde { + +static LoggerRef +Log() +{ + static auto s_Logger = zen::logging::Get("horde.bundle"); + return s_Logger; +} + +static constexpr uint8_t PacketSignature[3] = {'U', 'B', 'N'}; +static constexpr uint8_t PacketVersion = 5; +static constexpr int32_t CurrentPacketBaseIdx = -2; +static constexpr int ImportBias = 3; +static constexpr uint32_t ChunkSize = 64 * 1024; // 64KB fixed chunks +static constexpr uint32_t LargeFileThreshold = 128 * 1024; // 128KB + +// BlobType: 20 bytes each = FGuid (16 bytes, 4x uint32 LE) + Version (int32 LE) +// Values from UE SDK: GUIDs stored as 4 uint32 LE values. + +// ChunkLeaf v1: {0xB27AFB68, 0x4A4B9E20, 0x8A78D8A4, 0x39D49840} +static constexpr uint8_t BlobType_ChunkLeafV1[20] = {0x68, 0xFB, 0x7A, 0xB2, 0x20, 0x9E, 0x4B, 0x4A, 0xA4, 0xD8, + 0x78, 0x8A, 0x40, 0x98, 0xD4, 0x39, 0x01, 0x00, 0x00, 0x00}; // version 1 + +// ChunkInterior v2: {0xF4DEDDBC, 0x4C7A70CB, 0x11F04783, 0xB9CDCCAF} +static constexpr uint8_t BlobType_ChunkInteriorV2[20] = {0xBC, 0xDD, 0xDE, 0xF4, 0xCB, 0x70, 0x7A, 0x4C, 0x83, 0x47, + 0xF0, 0x11, 0xAF, 0xCC, 0xCD, 0xB9, 0x02, 0x00, 0x00, 0x00}; // version 2 + +// Directory v1: {0x0714EC11, 0x4D07291A, 0x8AE77F86, 0x799980D6} +static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1A, 0x29, 0x07, 0x4D, 0x86, 0x7F, + 0xE7, 0x8A, 0xD6, 0x80, 0x99, 0x79, 0x01, 0x00, 0x00, 0x00}; // version 1 + +static constexpr size_t BlobTypeSize = 20; + +// ─── VarInt helpers (UE format) ───────────────────────────────────────────── + +static size_t +MeasureVarInt(size_t Value) +{ + if (Value == 0) + { + return 1; + } + return (FloorLog2(static_cast<unsigned int>(Value)) / 7) + 1; +} + +static void +WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value) +{ + const size_t ByteCount = MeasureVarInt(Value); + const size_t Offset = Buffer.size(); + Buffer.resize(Offset + ByteCount); + + uint8_t* Output = Buffer.data() + Offset; + for (size_t i = 1; i < ByteCount; ++i) + { + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; + } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); +} + +// ─── Binary helpers ───────────────────────────────────────────────────────── + +static void +WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value) +{ + uint8_t Bytes[4]; + memcpy(Bytes, &Value, 4); + Buffer.insert(Buffer.end(), Bytes, Bytes + 4); +} + +static void +WriteByte(std::vector<uint8_t>& Buffer, uint8_t Value) +{ + Buffer.push_back(Value); +} + +static void +WriteBytes(std::vector<uint8_t>& Buffer, const void* Data, size_t Size) +{ + auto* Ptr = static_cast<const uint8_t*>(Data); + Buffer.insert(Buffer.end(), Ptr, Ptr + Size); +} + +static void +WriteString(std::vector<uint8_t>& Buffer, std::string_view Str) +{ + WriteVarInt(Buffer, Str.size()); + WriteBytes(Buffer, Str.data(), Str.size()); +} + +static void +AlignTo4(std::vector<uint8_t>& Buffer) +{ + while (Buffer.size() % 4 != 0) + { + Buffer.push_back(0); + } +} + +static void +PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value) +{ + memcpy(Buffer.data() + Offset, &Value, 4); +} + +// ─── Packet builder ───────────────────────────────────────────────────────── + +// Builds a single uncompressed Horde V2 packet. Layout: +// [Signature(3) + Version(1) + PacketLength(4)] 8 bytes (header) +// [TypeTableOffset(4) + ImportTableOffset(4) + ExportTableOffset(4)] 12 bytes +// [Export data...] +// [Type table: count(4) + count * 20 bytes] +// [Import table: count(4) + (count+1) offset entries(4 each) + import data] +// [Export table: count(4) + (count+1) offset entries(4 each)] +// +// ALL offsets are absolute from byte 0 of the full packet (including the 8-byte header). +// PacketLength in the header = total packet size including the 8-byte header. + +struct PacketBuilder +{ + std::vector<uint8_t> Data; + std::vector<int32_t> ExportOffsets; // Absolute byte offset of each export from byte 0 + + // Type table: unique 20-byte BlobType entries + std::vector<const uint8_t*> Types; + + // Import table entries: (baseIdx, fragment) + struct ImportEntry + { + int32_t BaseIdx; + std::string Fragment; + }; + std::vector<ImportEntry> Imports; + + // Current export's start offset (absolute from byte 0) + size_t CurrentExportStart = 0; + + PacketBuilder() + { + // Reserve packet header (8 bytes) + table offsets (12 bytes) = 20 bytes + Data.resize(20, 0); + + // Write signature + Data[0] = PacketSignature[0]; + Data[1] = PacketSignature[1]; + Data[2] = PacketSignature[2]; + Data[3] = PacketVersion; + // PacketLength, TypeTableOffset, ImportTableOffset, ExportTableOffset + // will be patched in Finish() + } + + int AddType(const uint8_t* BlobType) + { + for (size_t i = 0; i < Types.size(); ++i) + { + if (memcmp(Types[i], BlobType, BlobTypeSize) == 0) + { + return static_cast<int>(i); + } + } + Types.push_back(BlobType); + return static_cast<int>(Types.size() - 1); + } + + int AddImport(int32_t BaseIdx, std::string Fragment) + { + Imports.push_back({BaseIdx, std::move(Fragment)}); + return static_cast<int>(Imports.size() - 1); + } + + void BeginExport() + { + AlignTo4(Data); + CurrentExportStart = Data.size(); + // Reserve space for payload length + WriteLE32(Data, 0); + } + + // Write raw payload data into the current export + void WritePayload(const void* Payload, size_t Size) { WriteBytes(Data, Payload, Size); } + + // Complete the current export: patches payload length, writes type+imports metadata + int CompleteExport(const uint8_t* BlobType, const std::vector<int>& ImportIndices) + { + const int ExportIndex = static_cast<int>(ExportOffsets.size()); + + // Patch payload length (does not include the 4-byte length field itself) + const size_t PayloadStart = CurrentExportStart + 4; + const int32_t PayloadLen = static_cast<int32_t>(Data.size() - PayloadStart); + PatchLE32(Data, CurrentExportStart, PayloadLen); + + // Write type index (varint) + const int TypeIdx = AddType(BlobType); + WriteVarInt(Data, static_cast<size_t>(TypeIdx)); + + // Write import count + indices + WriteVarInt(Data, ImportIndices.size()); + for (int Idx : ImportIndices) + { + WriteVarInt(Data, static_cast<size_t>(Idx)); + } + + // Record export offset (absolute from byte 0) + ExportOffsets.push_back(static_cast<int32_t>(CurrentExportStart)); + + return ExportIndex; + } + + // Finalize the packet: write type/import/export tables, patch header. + std::vector<uint8_t> Finish() + { + AlignTo4(Data); + + // ── Type table: count(int32) + count * BlobTypeSize bytes ── + const int32_t TypeTableOffset = static_cast<int32_t>(Data.size()); + WriteLE32(Data, static_cast<int32_t>(Types.size())); + for (const uint8_t* TypeEntry : Types) + { + WriteBytes(Data, TypeEntry, BlobTypeSize); + } + + // ── Import table: count(int32) + (count+1) offsets(int32 each) + import data ── + const int32_t ImportTableOffset = static_cast<int32_t>(Data.size()); + const int32_t ImportCount = static_cast<int32_t>(Imports.size()); + WriteLE32(Data, ImportCount); + + // Reserve space for (count+1) offset entries — will be patched below + const size_t ImportOffsetsStart = Data.size(); + for (int32_t i = 0; i <= ImportCount; ++i) + { + WriteLE32(Data, 0); // placeholder + } + + // Write import data and record offsets + for (int32_t i = 0; i < ImportCount; ++i) + { + // Record absolute offset of this import's data + PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(i) * 4, static_cast<int32_t>(Data.size())); + + ImportEntry& Imp = Imports[static_cast<size_t>(i)]; + // BaseIdx encoded as unsigned VarInt with bias: VarInt(BaseIdx + ImportBias) + const size_t EncodedBaseIdx = static_cast<size_t>(static_cast<int64_t>(Imp.BaseIdx) + ImportBias); + WriteVarInt(Data, EncodedBaseIdx); + // Fragment: raw UTF-8 bytes, NO length prefix (length determined by offset table) + WriteBytes(Data, Imp.Fragment.data(), Imp.Fragment.size()); + } + + // Sentinel offset (points past the last import's data) + PatchLE32(Data, ImportOffsetsStart + static_cast<size_t>(ImportCount) * 4, static_cast<int32_t>(Data.size())); + + // ── Export table: count(int32) + (count+1) offsets(int32 each) ── + const int32_t ExportTableOffset = static_cast<int32_t>(Data.size()); + const int32_t ExportCount = static_cast<int32_t>(ExportOffsets.size()); + WriteLE32(Data, ExportCount); + + for (int32_t Off : ExportOffsets) + { + WriteLE32(Data, Off); + } + // Sentinel: points to the start of the type table (end of export data region) + WriteLE32(Data, TypeTableOffset); + + // ── Patch header ── + // PacketLength = total packet size including the 8-byte header + const int32_t PacketLength = static_cast<int32_t>(Data.size()); + PatchLE32(Data, 4, PacketLength); + PatchLE32(Data, 8, TypeTableOffset); + PatchLE32(Data, 12, ImportTableOffset); + PatchLE32(Data, 16, ExportTableOffset); + + return std::move(Data); + } +}; + +// ─── Encoded packet wrapper ───────────────────────────────────────────────── + +// Wraps an uncompressed packet with the encoded header: +// [Signature(3) + Version(1) + HeaderLength(4)] 8 bytes +// [DecompressedLength(4)] 4 bytes +// [CompressionFormat(1): 0=None] 1 byte +// [PacketData...] +// +// HeaderLength = total encoded packet size INCLUDING the 8-byte outer header. + +static std::vector<uint8_t> +EncodePacket(std::vector<uint8_t> UncompressedPacket) +{ + const int32_t DecompressedLen = static_cast<int32_t>(UncompressedPacket.size()); + // HeaderLength includes the 8-byte outer signature header itself + const int32_t HeaderLength = 8 + 4 + 1 + DecompressedLen; + + std::vector<uint8_t> Encoded; + Encoded.reserve(static_cast<size_t>(HeaderLength)); + + // Outer signature: 'U','B','N', version=5, HeaderLength (LE int32) + WriteByte(Encoded, PacketSignature[0]); // 'U' + WriteByte(Encoded, PacketSignature[1]); // 'B' + WriteByte(Encoded, PacketSignature[2]); // 'N' + WriteByte(Encoded, PacketVersion); // 5 + WriteLE32(Encoded, HeaderLength); + + // Decompressed length + compression format + WriteLE32(Encoded, DecompressedLen); + WriteByte(Encoded, 0); // CompressionFormat::None + + // Packet data + WriteBytes(Encoded, UncompressedPacket.data(), UncompressedPacket.size()); + + return Encoded; +} + +// ─── Bundle blob name generation ──────────────────────────────────────────── + +static std::string +GenerateBlobName() +{ + static std::atomic<uint32_t> s_Counter{0}; + + const int Pid = GetCurrentProcessId(); + + auto Now = std::chrono::steady_clock::now().time_since_epoch(); + auto Ms = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count(); + + ExtendableStringBuilder<64> Name; + Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1); + return std::string(Name.ToView()); +} + +// ─── File info for bundling ───────────────────────────────────────────────── + +struct FileInfo +{ + std::filesystem::path Path; + std::string Name; // Filename only (for directory entry) + uint64_t FileSize; + IoHash ContentHash; // IoHash of file content + BLAKE3 StreamHash; // Full BLAKE3 for stream hash + int DirectoryExportImportIndex; // Import index referencing this file's root export + IoHash RootExportHash; // IoHash of the root export for this file +}; + +// ─── CreateBundle implementation ──────────────────────────────────────────── + +bool +BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult) +{ + ZEN_TRACE_CPU("BundleCreator::CreateBundle"); + + std::error_code Ec; + + // Collect files that exist + std::vector<FileInfo> ValidFiles; + for (const BundleFile& F : Files) + { + if (!std::filesystem::exists(F.Path, Ec)) + { + if (F.Optional) + { + continue; + } + ZEN_ERROR("required bundle file does not exist: {}", F.Path.string()); + return false; + } + FileInfo Info; + Info.Path = F.Path; + Info.Name = F.Path.filename().string(); + Info.FileSize = std::filesystem::file_size(F.Path, Ec); + if (Ec) + { + ZEN_ERROR("failed to get file size: {}", F.Path.string()); + return false; + } + ValidFiles.push_back(std::move(Info)); + } + + if (ValidFiles.empty()) + { + ZEN_ERROR("no valid files to bundle"); + return false; + } + + std::filesystem::create_directories(OutputDir, Ec); + if (Ec) + { + ZEN_ERROR("failed to create output directory: {}", OutputDir.string()); + return false; + } + + const std::string BlobName = GenerateBlobName(); + PacketBuilder Packet; + + // Process each file: create chunk exports + for (FileInfo& Info : ValidFiles) + { + BasicFile File; + File.Open(Info.Path, BasicFile::Mode::kRead, Ec); + if (Ec) + { + ZEN_ERROR("failed to open file: {}", Info.Path.string()); + return false; + } + + // Compute stream hash (full BLAKE3) and content hash (IoHash) while reading + BLAKE3Stream StreamHasher; + IoHashStream ContentHasher; + + if (Info.FileSize <= LargeFileThreshold) + { + // Small file: single chunk leaf export + IoBuffer Content = File.ReadAll(); + const auto* Data = static_cast<const uint8_t*>(Content.GetData()); + const size_t Size = Content.GetSize(); + + StreamHasher.Append(Data, Size); + ContentHasher.Append(Data, Size); + + Packet.BeginExport(); + Packet.WritePayload(Data, Size); + + const IoHash ChunkHash = IoHash::HashBuffer(Data, Size); + const int ExportIndex = Packet.CompleteExport(BlobType_ChunkLeafV1, {}); + Info.RootExportHash = ChunkHash; + Info.ContentHash = ContentHasher.GetHash(); + Info.StreamHash = StreamHasher.GetHash(); + + // Add import for this file's root export (references export within same packet) + ExtendableStringBuilder<32> Fragment; + Fragment << "exp=" << ExportIndex; + Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView())); + } + else + { + // Large file: split into fixed 64KB chunks, then create interior node + std::vector<int> ChunkExportIndices; + std::vector<IoHash> ChunkHashes; + + uint64_t Remaining = Info.FileSize; + uint64_t Offset = 0; + + while (Remaining > 0) + { + const uint64_t ReadSize = std::min(static_cast<uint64_t>(ChunkSize), Remaining); + IoBuffer Chunk = File.ReadRange(Offset, ReadSize); + const auto* Data = static_cast<const uint8_t*>(Chunk.GetData()); + const size_t Size = Chunk.GetSize(); + + StreamHasher.Append(Data, Size); + ContentHasher.Append(Data, Size); + + Packet.BeginExport(); + Packet.WritePayload(Data, Size); + + const IoHash ChunkHash = IoHash::HashBuffer(Data, Size); + const int ExpIdx = Packet.CompleteExport(BlobType_ChunkLeafV1, {}); + + ChunkExportIndices.push_back(ExpIdx); + ChunkHashes.push_back(ChunkHash); + + Offset += ReadSize; + Remaining -= ReadSize; + } + + Info.ContentHash = ContentHasher.GetHash(); + Info.StreamHash = StreamHasher.GetHash(); + + // Create interior node referencing all chunk leaves + // Interior payload: for each child: [IoHash(20)][node_type=1(1)] + imports + std::vector<int> InteriorImports; + for (size_t i = 0; i < ChunkExportIndices.size(); ++i) + { + ExtendableStringBuilder<32> Fragment; + Fragment << "exp=" << ChunkExportIndices[i]; + const int ImportIdx = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView())); + InteriorImports.push_back(ImportIdx); + } + + Packet.BeginExport(); + + // Write interior payload: [hash(20)][type(1)] per child + for (size_t i = 0; i < ChunkHashes.size(); ++i) + { + Packet.WritePayload(ChunkHashes[i].Hash, sizeof(IoHash)); + const uint8_t NodeType = 1; // ChunkNode type + Packet.WritePayload(&NodeType, 1); + } + + // Hash the interior payload to get the interior node hash + const IoHash InteriorHash = IoHash::HashBuffer(Packet.Data.data() + (Packet.CurrentExportStart + 4), + Packet.Data.size() - (Packet.CurrentExportStart + 4)); + + const int InteriorExportIndex = Packet.CompleteExport(BlobType_ChunkInteriorV2, InteriorImports); + + Info.RootExportHash = InteriorHash; + + // Add import for directory to reference this interior node + ExtendableStringBuilder<32> Fragment; + Fragment << "exp=" << InteriorExportIndex; + Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView())); + } + } + + // Create directory node export + // Payload: [flags(varint=0)] [file_count(varint)] [file_entries...] [dir_count(varint=0)] + // FileEntry: [import(varint)] [IoHash(20)] [name(string)] [flags(varint)] [length(varint)] [IoHash_stream(20)] + + Packet.BeginExport(); + + // Build directory payload into a temporary buffer, then write it + std::vector<uint8_t> DirPayload; + WriteVarInt(DirPayload, 0); // flags + WriteVarInt(DirPayload, ValidFiles.size()); // file_count + + std::vector<int> DirImports; + for (size_t i = 0; i < ValidFiles.size(); ++i) + { + FileInfo& Info = ValidFiles[i]; + DirImports.push_back(Info.DirectoryExportImportIndex); + + // IoHash of target (20 bytes) — import is consumed sequentially from the + // export's import list by ReadBlobRef, not encoded in the payload + WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash)); + // name (string) + WriteString(DirPayload, Info.Name); + // flags (varint): 1 = Executable + WriteVarInt(DirPayload, 1); + // length (varint) + WriteVarInt(DirPayload, static_cast<size_t>(Info.FileSize)); + // stream hash: IoHash from full BLAKE3, truncated to 20 bytes + const IoHash StreamIoHash = IoHash::FromBLAKE3(Info.StreamHash); + WriteBytes(DirPayload, StreamIoHash.Hash, sizeof(IoHash)); + } + + WriteVarInt(DirPayload, 0); // dir_count + + Packet.WritePayload(DirPayload.data(), DirPayload.size()); + const int DirExportIndex = Packet.CompleteExport(BlobType_DirectoryV1, DirImports); + + // Finalize packet and encode + std::vector<uint8_t> UncompressedPacket = Packet.Finish(); + std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket)); + + // Write .blob file + const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob"); + { + BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec); + if (Ec) + { + ZEN_ERROR("failed to create blob file: {}", BlobFilePath.string()); + return false; + } + BlobFile.Write(EncodedPacket.data(), EncodedPacket.size(), 0); + } + + // Build locator: <blob_name>#pkt=0,<encoded_len>&exp=<dir_export_index> + ExtendableStringBuilder<256> Locator; + Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex; + const std::string LocatorStr(Locator.ToView()); + + // Write .ref file (use first file's name as the ref base) + const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref"); + { + BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec); + if (Ec) + { + ZEN_ERROR("failed to create ref file: {}", RefFilePath.string()); + return false; + } + RefFile.Write(LocatorStr.data(), LocatorStr.size(), 0); + } + + OutResult.Locator = LocatorStr; + OutResult.BundleDir = OutputDir; + + ZEN_INFO("created V2 bundle: blob={}.blob locator={} files={}", BlobName, LocatorStr, ValidFiles.size()); + return true; +} + +bool +BundleCreator::ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator) +{ + BasicFile File; + std::error_code Ec; + File.Open(RefFile, BasicFile::Mode::kRead, Ec); + if (Ec) + { + return false; + } + + IoBuffer Content = File.ReadAll(); + OutLocator.assign(static_cast<const char*>(Content.GetData()), Content.GetSize()); + + // Strip trailing whitespace/newlines + while (!OutLocator.empty() && (OutLocator.back() == '\n' || OutLocator.back() == '\r' || OutLocator.back() == '\0')) + { + OutLocator.pop_back(); + } + + return !OutLocator.empty(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordebundle.h b/src/zenhorde/hordebundle.h new file mode 100644 index 000000000..052f60435 --- /dev/null +++ b/src/zenhorde/hordebundle.h @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> +#include <string> +#include <vector> + +namespace zen::horde { + +/** Describes a file to include in a Horde bundle. */ +struct BundleFile +{ + std::filesystem::path Path; ///< Local file path + bool Optional; ///< If true, skip without error if missing +}; + +/** Result of a successful bundle creation. */ +struct BundleResult +{ + std::string Locator; ///< Root directory locator for WriteFiles + std::filesystem::path BundleDir; ///< Directory containing .blob files +}; + +/** Creates Horde V2 bundles from local files for upload to remote agents. + * + * Produces a proper Horde storage V2 bundle containing: + * - Chunk leaf exports for file data (split into 64KB chunks for large files) + * - Optional interior chunk nodes referencing leaf chunks + * - A directory node listing all bundled files with metadata + * + * The bundle is written as a single .blob file with a corresponding .ref file + * containing the locator string. The locator format is: + * <blob_name>#pkt=0,<encoded_len>&exp=<directory_export_index> + */ +struct BundleCreator +{ + /** Create a V2 bundle from one or more input files. + * @param Files Files to include in the bundle. + * @param OutputDir Directory where .blob and .ref files will be written. + * @param OutResult Receives the locator and output directory on success. + * @return True on success. */ + static bool CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult); + + /** Read a locator string from a .ref file. Strips trailing whitespace/newlines. */ + static bool ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp new file mode 100644 index 000000000..fb981f0ba --- /dev/null +++ b/src/zenhorde/hordeclient.cpp @@ -0,0 +1,382 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memoryview.h> +#include <zencore/trace.h> +#include <zenhorde/hordeclient.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::horde { + +HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client")) +{ +} + +HordeClient::~HordeClient() = default; + +bool +HordeClient::Initialize() +{ + ZEN_TRACE_CPU("HordeClient::Initialize"); + + HttpClientSettings Settings; + Settings.LogCategory = "horde.http"; + Settings.ConnectTimeout = std::chrono::milliseconds{10000}; + Settings.Timeout = std::chrono::milliseconds{60000}; + Settings.RetryCount = 1; + Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests}; + + if (!m_Config.AuthToken.empty()) + { + Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken { + HttpClientAccessToken Token; + Token.Value = token; + Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24}; + return Token; + }; + } + + m_Http = std::make_unique<zen::HttpClient>(m_Config.ServerUrl, Settings); + + if (!m_Config.AuthToken.empty()) + { + if (!m_Http->Authenticate()) + { + ZEN_WARN("failed to authenticate with Horde server"); + return false; + } + } + + return true; +} + +std::string +HordeClient::BuildRequestBody() const +{ + json11::Json::object Requirements; + + if (m_Config.Mode == ConnectionMode::Direct && !m_Config.Pool.empty()) + { + Requirements["pool"] = m_Config.Pool; + } + + std::string Condition; +#if ZEN_PLATFORM_WINDOWS + ExtendableStringBuilder<256> CondBuf; + CondBuf << "(OSFamily == 'Windows' || WineEnabled == '" << (m_Config.AllowWine ? "true" : "false") << "')"; + Condition = std::string(CondBuf); +#elif ZEN_PLATFORM_MAC + Condition = "OSFamily == 'MacOS'"; +#else + Condition = "OSFamily == 'Linux'"; +#endif + + if (!m_Config.Condition.empty()) + { + Condition += " "; + Condition += m_Config.Condition; + } + + Requirements["condition"] = Condition; + Requirements["exclusive"] = true; + + json11::Json::object Connection; + Connection["modePreference"] = ToString(m_Config.Mode); + + if (m_Config.EncryptionMode != Encryption::None) + { + Connection["encryption"] = ToString(m_Config.EncryptionMode); + } + + // Request configured zen service port to be forwarded. The Horde agent will map this + // to a local port on the provisioned machine and report it back in the response. + json11::Json::object PortsObj; + PortsObj["ZenPort"] = json11::Json(m_Config.ZenServicePort); + Connection["ports"] = PortsObj; + + json11::Json::object Root; + Root["requirements"] = Requirements; + Root["connection"] = Connection; + + return json11::Json(Root).dump(); +} + +bool +HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster) +{ + ZEN_TRACE_CPU("HordeClient::ResolveCluster"); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON); + + const HttpClient::Response Response = m_Http->Post("api/v2/compute/_cluster", Payload); + + if (Response.Error) + { + ZEN_WARN("cluster resolution failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast<int>(Response.StatusCode); + + if (StatusCode == 503 || StatusCode == 429) + { + ZEN_DEBUG("cluster resolution returned HTTP/{}: no resources", StatusCode); + return false; + } + + if (StatusCode == 401) + { + ZEN_WARN("cluster resolution returned HTTP/401: token expired"); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("cluster resolution failed with HTTP/{}", StatusCode); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response for cluster resolution: {}", Err); + return false; + } + + const json11::Json ClusterIdVal = Json["clusterId"]; + if (!ClusterIdVal.is_string() || ClusterIdVal.string_value().empty()) + { + ZEN_WARN("missing 'clusterId' in cluster resolution response"); + return false; + } + + OutCluster.ClusterId = ClusterIdVal.string_value(); + return true; +} + +bool +HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize) +{ + if (Hex.size() != OutSize * 2) + { + return false; + } + + for (size_t i = 0; i < OutSize; ++i) + { + auto HexToByte = [](char c) -> int { + if (c >= '0' && c <= '9') + return c - '0'; + if (c >= 'a' && c <= 'f') + return c - 'a' + 10; + if (c >= 'A' && c <= 'F') + return c - 'A' + 10; + return -1; + }; + + const int Hi = HexToByte(Hex[i * 2]); + const int Lo = HexToByte(Hex[i * 2 + 1]); + if (Hi < 0 || Lo < 0) + { + return false; + } + Out[i] = static_cast<uint8_t>((Hi << 4) | Lo); + } + + return true; +} + +bool +HordeClient::RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine) +{ + ZEN_TRACE_CPU("HordeClient::RequestMachine"); + + ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str()); + + ExtendableStringBuilder<128> ResourcePath; + ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? "default" : ClusterId.c_str()); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON); + const HttpClient::Response Response = m_Http->Post(ResourcePath.ToView(), Payload); + + // Reset output to invalid state + OutMachine = {}; + OutMachine.Port = 0xFFFF; + + if (Response.Error) + { + ZEN_WARN("machine request failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast<int>(Response.StatusCode); + + if (StatusCode == 404 || StatusCode == 503 || StatusCode == 429) + { + ZEN_DEBUG("machine request returned HTTP/{}: no resources", StatusCode); + return false; + } + + if (StatusCode == 401) + { + ZEN_WARN("machine request returned HTTP/401: token expired"); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("machine request failed with HTTP/{}", StatusCode); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response for machine request: {}", Err); + return false; + } + + // Required fields + const json11::Json NonceVal = Json["nonce"]; + const json11::Json IpVal = Json["ip"]; + const json11::Json PortVal = Json["port"]; + + if (!NonceVal.is_string() || !IpVal.is_string() || !PortVal.is_number()) + { + ZEN_WARN("missing 'nonce', 'ip', or 'port' in machine response"); + return false; + } + + OutMachine.Ip = IpVal.string_value(); + OutMachine.Port = static_cast<uint16_t>(PortVal.int_value()); + + if (!ParseHexBytes(NonceVal.string_value(), OutMachine.Nonce, NonceSize)) + { + ZEN_WARN("invalid nonce hex string in machine response"); + return false; + } + + if (const json11::Json PortsVal = Json["ports"]; PortsVal.is_object()) + { + for (const auto& [Key, Val] : PortsVal.object_items()) + { + PortInfo Info; + if (Val["port"].is_number()) + { + Info.Port = static_cast<uint16_t>(Val["port"].int_value()); + } + if (Val["agentPort"].is_number()) + { + Info.AgentPort = static_cast<uint16_t>(Val["agentPort"].int_value()); + } + OutMachine.Ports[Key] = Info; + } + } + + if (const json11::Json ConnectionModeVal = Json["connectionMode"]; ConnectionModeVal.is_string()) + { + if (FromString(OutMachine.Mode, ConnectionModeVal.string_value())) + { + if (const json11::Json ConnectionAddressVal = Json["connectionAddress"]; ConnectionAddressVal.is_string()) + { + OutMachine.ConnectionAddress = ConnectionAddressVal.string_value(); + } + } + } + + // Properties are a flat string array of "Key=Value" pairs describing the machine. + // We extract OS family and core counts for sizing decisions. If neither core count + // is available, we fall back to 16 as a conservative default. + uint16_t LogicalCores = 0; + uint16_t PhysicalCores = 0; + + if (const json11::Json PropertiesVal = Json["properties"]; PropertiesVal.is_array()) + { + for (const json11::Json& PropVal : PropertiesVal.array_items()) + { + if (!PropVal.is_string()) + { + continue; + } + + const std::string Prop = PropVal.string_value(); + if (Prop.starts_with("OSFamily=")) + { + if (Prop.substr(9) == "Windows") + { + OutMachine.IsWindows = true; + } + } + else if (Prop.starts_with("LogicalCores=")) + { + LogicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 13)); + } + else if (Prop.starts_with("PhysicalCores=")) + { + PhysicalCores = static_cast<uint16_t>(std::atoi(Prop.c_str() + 14)); + } + } + } + + if (LogicalCores > 0) + { + OutMachine.LogicalCores = LogicalCores; + } + else if (PhysicalCores > 0) + { + OutMachine.LogicalCores = PhysicalCores * 2; + } + else + { + OutMachine.LogicalCores = 16; + } + + if (const json11::Json EncryptionVal = Json["encryption"]; EncryptionVal.is_string()) + { + if (FromString(OutMachine.EncryptionMode, EncryptionVal.string_value())) + { + if (OutMachine.EncryptionMode == Encryption::AES) + { + const json11::Json KeyVal = Json["key"]; + if (KeyVal.is_string() && !KeyVal.string_value().empty()) + { + if (!ParseHexBytes(KeyVal.string_value(), OutMachine.Key, KeySize)) + { + ZEN_WARN("invalid AES key in machine response"); + } + } + else + { + ZEN_WARN("AES encryption requested but no key provided"); + } + } + } + } + + if (const json11::Json LeaseIdVal = Json["leaseId"]; LeaseIdVal.is_string()) + { + OutMachine.LeaseId = LeaseIdVal.string_value(); + } + + ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}", + OutMachine.GetConnectionAddress(), + OutMachine.GetConnectionPort(), + OutMachine.LogicalCores, + OutMachine.LeaseId); + + return true; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp new file mode 100644 index 000000000..0d032b5d5 --- /dev/null +++ b/src/zenhorde/hordecomputebuffer.cpp @@ -0,0 +1,454 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputebuffer.h" + +#include <algorithm> +#include <cassert> +#include <chrono> +#include <condition_variable> +#include <cstring> + +namespace zen::horde { + +// Simplified ring buffer implementation for in-process use only. +// Uses a single contiguous buffer with write/read cursors and +// mutex+condvar for synchronization. This is simpler than the UE version +// which uses lock-free atomics and shared memory, but sufficient for our +// use case where we're the initiator side of the compute protocol. + +struct ComputeBuffer::Detail : TRefCounted<Detail> +{ + std::vector<uint8_t> Data; + size_t NumChunks = 0; + size_t ChunkLength = 0; + + // Current write state + size_t WriteChunkIdx = 0; + size_t WriteOffset = 0; + bool WriteComplete = false; + + // Current read state + size_t ReadChunkIdx = 0; + size_t ReadOffset = 0; + bool Detached = false; + + // Per-chunk written length + std::vector<size_t> ChunkWrittenLength; + std::vector<bool> ChunkFinished; // Writer moved to next chunk + + std::mutex Mutex; + std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes + std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space + + bool HasWriter = false; + bool HasReader = false; + + uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } + const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } +}; + +// ComputeBuffer + +ComputeBuffer::ComputeBuffer() +{ +} +ComputeBuffer::~ComputeBuffer() +{ +} + +bool +ComputeBuffer::CreateNew(const Params& InParams) +{ + auto* NewDetail = new Detail(); + NewDetail->NumChunks = InParams.NumChunks; + NewDetail->ChunkLength = InParams.ChunkLength; + NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); + NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); + NewDetail->ChunkFinished.resize(InParams.NumChunks, false); + + m_Detail = NewDetail; + return true; +} + +void +ComputeBuffer::Close() +{ + m_Detail = nullptr; +} + +bool +ComputeBuffer::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +ComputeBufferReader +ComputeBuffer::CreateReader() +{ + assert(m_Detail); + m_Detail->HasReader = true; + return ComputeBufferReader(m_Detail); +} + +ComputeBufferWriter +ComputeBuffer::CreateWriter() +{ + assert(m_Detail); + m_Detail->HasWriter = true; + return ComputeBufferWriter(m_Detail); +} + +// ComputeBufferReader + +ComputeBufferReader::ComputeBufferReader() +{ +} +ComputeBufferReader::~ComputeBufferReader() +{ +} + +ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; +ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; +ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; +ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; + +ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferReader::Close() +{ + m_Detail = nullptr; +} + +void +ComputeBufferReader::Detach() +{ + if (m_Detail) + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + m_Detail->Detached = true; + m_Detail->ReadCV.notify_all(); + } +} + +bool +ComputeBufferReader::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +bool +ComputeBufferReader::IsComplete() const +{ + if (!m_Detail) + { + return true; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + if (m_Detail->Detached) + { + return true; + } + return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && + m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; +} + +void +ComputeBufferReader::AdvanceReadPosition(size_t Size) +{ + if (!m_Detail) + { + return; + } + + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + + m_Detail->ReadOffset += Size; + + // Check if we need to move to next chunk + const size_t ReadChunk = m_Detail->ReadChunkIdx; + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + } + + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferReader::GetMaxReadSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t ReadChunk = m_Detail->ReadChunkIdx; + return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; +} + +const uint8_t* +ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock<std::mutex> Lock(m_Detail->Mutex); + + auto Predicate = [&]() -> bool { + if (m_Detail->Detached) + { + return true; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available >= MinSize) + { + return true; + } + + // If chunk is finished and we've read everything, try to move to next + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + // Move to next chunk + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + return false; // Re-check with new chunk + } + + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + + return false; + }; + + if (TimeoutMs < 0) + { + m_Detail->ReadCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) + { + if (OutTimedOut) + { + *OutTimedOut = true; + } + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available < MinSize) + { + return nullptr; // End of stream + } + + return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; +} + +size_t +ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) +{ + const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); + if (!Data) + { + return 0; + } + + const size_t Available = GetMaxReadSize(); + const size_t ToCopy = std::min(Available, MaxSize); + memcpy(Buffer, Data, ToCopy); + AdvanceReadPosition(ToCopy); + return ToCopy; +} + +// ComputeBufferWriter + +ComputeBufferWriter::ComputeBufferWriter() = default; +ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; +ComputeBufferWriter::~ComputeBufferWriter() = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; + +ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferWriter::Close() +{ + if (m_Detail) + { + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + if (!m_Detail->WriteComplete) + { + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } + } + m_Detail = nullptr; + } +} + +bool +ComputeBufferWriter::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +void +ComputeBufferWriter::MarkComplete() +{ + if (m_Detail) + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } +} + +void +ComputeBufferWriter::AdvanceWritePosition(size_t Size) +{ + if (!m_Detail || Size == 0) + { + return; + } + + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t WriteChunk = m_Detail->WriteChunkIdx; + m_Detail->ChunkWrittenLength[WriteChunk] += Size; + m_Detail->WriteOffset += Size; + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferWriter::GetMaxWriteSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t WriteChunk = m_Detail->WriteChunkIdx; + return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; +} + +size_t +ComputeBufferWriter::GetChunkMaxLength() const +{ + if (!m_Detail) + { + return 0; + } + return m_Detail->ChunkLength; +} + +size_t +ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) +{ + uint8_t* Dest = WaitToWrite(1, TimeoutMs); + if (!Dest) + { + return 0; + } + + const size_t Available = GetMaxWriteSize(); + const size_t ToCopy = std::min(Available, MaxSize); + memcpy(Dest, Buffer, ToCopy); + AdvanceWritePosition(ToCopy); + return ToCopy; +} + +uint8_t* +ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock<std::mutex> Lock(m_Detail->Mutex); + + if (m_Detail->WriteComplete) + { + return nullptr; + } + + const size_t WriteChunk = m_Detail->WriteChunkIdx; + const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; + + // If current chunk has enough space, return pointer + if (Available >= MinSize) + { + return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; + } + + // Current chunk is full - mark it as finished and move to next. + // The writer cannot advance until the reader has fully consumed the next chunk, + // preventing the writer from overwriting data the reader hasn't processed yet. + m_Detail->ChunkFinished[WriteChunk] = true; + m_Detail->ReadCV.notify_all(); + + const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; + + // Wait until reader has consumed the next chunk + auto Predicate = [&]() -> bool { + // Check if read has moved past this chunk + return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; + }; + + if (TimeoutMs < 0) + { + m_Detail->WriteCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) + { + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + // Reset next chunk + m_Detail->ChunkWrittenLength[NextChunk] = 0; + m_Detail->ChunkFinished[NextChunk] = false; + m_Detail->WriteChunkIdx = NextChunk; + m_Detail->WriteOffset = 0; + + return m_Detail->ChunkPtr(NextChunk); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h new file mode 100644 index 000000000..64ef91b7a --- /dev/null +++ b/src/zenhorde/hordecomputebuffer.h @@ -0,0 +1,136 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/refcount.h> + +#include <cstddef> +#include <cstdint> +#include <mutex> +#include <vector> + +namespace zen::horde { + +class ComputeBufferReader; +class ComputeBufferWriter; + +/** Simplified in-process ring buffer for the Horde compute protocol. + * + * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files, + * this implementation uses plain heap-allocated memory since we only need in-process + * communication between channel and transport threads. The buffer is divided into + * fixed-size chunks; readers and writers block when no space is available. + */ +class ComputeBuffer +{ +public: + struct Params + { + size_t NumChunks = 2; + size_t ChunkLength = 512 * 1024; + }; + + ComputeBuffer(); + ~ComputeBuffer(); + + ComputeBuffer(const ComputeBuffer&) = delete; + ComputeBuffer& operator=(const ComputeBuffer&) = delete; + + bool CreateNew(const Params& InParams); + void Close(); + + bool IsValid() const; + + ComputeBufferReader CreateReader(); + ComputeBufferWriter CreateWriter(); + +private: + struct Detail; + Ref<Detail> m_Detail; + + friend class ComputeBufferReader; + friend class ComputeBufferWriter; +}; + +/** Read endpoint for a ComputeBuffer. + * + * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer + * directly into the buffer memory (zero-copy); the caller must call + * AdvanceReadPosition() after consuming the data. + */ +class ComputeBufferReader +{ +public: + ComputeBufferReader(); + ComputeBufferReader(const ComputeBufferReader&); + ComputeBufferReader(ComputeBufferReader&&) noexcept; + ~ComputeBufferReader(); + + ComputeBufferReader& operator=(const ComputeBufferReader&); + ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; + + void Close(); + void Detach(); + bool IsValid() const; + bool IsComplete() const; + + void AdvanceReadPosition(size_t Size); + size_t GetMaxReadSize() const; + + /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ + size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); + + /** Wait until at least MinSize bytes are available and return a direct pointer. + * Returns nullptr on timeout or if the writer has completed. */ + const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); + +private: + friend class ComputeBuffer; + explicit ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail); + + Ref<ComputeBuffer::Detail> m_Detail; +}; + +/** Write endpoint for a ComputeBuffer. + * + * Provides blocking writes into the ring buffer. WaitToWrite() returns a pointer + * directly into the buffer memory (zero-copy); the caller must call + * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal + * that no more data will be written. + */ +class ComputeBufferWriter +{ +public: + ComputeBufferWriter(); + ComputeBufferWriter(const ComputeBufferWriter&); + ComputeBufferWriter(ComputeBufferWriter&&) noexcept; + ~ComputeBufferWriter(); + + ComputeBufferWriter& operator=(const ComputeBufferWriter&); + ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept; + + void Close(); + bool IsValid() const; + + /** Signal that no more data will be written. Unblocks any waiting readers. */ + void MarkComplete(); + + void AdvanceWritePosition(size_t Size); + size_t GetMaxWriteSize() const; + size_t GetChunkMaxLength() const; + + /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */ + size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1); + + /** Wait until at least MinSize bytes of write space are available and return a direct pointer. + * Returns nullptr on timeout. */ + uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1); + +private: + friend class ComputeBuffer; + explicit ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail); + + Ref<ComputeBuffer::Detail> m_Detail; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp new file mode 100644 index 000000000..ee2a6f327 --- /dev/null +++ b/src/zenhorde/hordecomputechannel.cpp @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputechannel.h" + +namespace zen::horde { + +ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) +: Reader(std::move(InReader)) +, Writer(std::move(InWriter)) +{ +} + +bool +ComputeChannel::IsValid() const +{ + return Reader.IsValid() && Writer.IsValid(); +} + +size_t +ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) +{ + return Writer.Write(Data, Size, TimeoutMs); +} + +size_t +ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) +{ + return Reader.Read(Data, Size, TimeoutMs); +} + +void +ComputeChannel::MarkComplete() +{ + Writer.MarkComplete(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h new file mode 100644 index 000000000..c1dff20e4 --- /dev/null +++ b/src/zenhorde/hordecomputechannel.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" + +namespace zen::horde { + +/** Bidirectional communication channel using a pair of compute buffers. + * + * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter + * (for sending data). Used by ComputeSocket to represent one logical channel + * within a multiplexed connection. + */ +class ComputeChannel : public TRefCounted<ComputeChannel> +{ +public: + ComputeBufferReader Reader; + ComputeBufferWriter Writer; + + ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); + + bool IsValid() const; + + size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); + size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); + + /** Signal that no more data will be sent on this channel. */ + void MarkComplete(); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp new file mode 100644 index 000000000..6ef67760c --- /dev/null +++ b/src/zenhorde/hordecomputesocket.cpp @@ -0,0 +1,204 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputesocket.h" + +#include <zencore/logging.h> + +namespace zen::horde { + +ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) +: m_Log(zen::logging::Get("horde.socket")) +, m_Transport(std::move(Transport)) +{ +} + +ComputeSocket::~ComputeSocket() +{ + // Shutdown order matters: first stop the ping thread, then unblock send threads + // by detaching readers, then join send threads, and finally close the transport + // to unblock the recv thread (which is blocked on RecvMessage). + { + std::lock_guard<std::mutex> Lock(m_PingMutex); + m_PingShouldStop = true; + m_PingCV.notify_all(); + } + + for (auto& Reader : m_Readers) + { + Reader.Detach(); + } + + for (auto& [Id, Thread] : m_SendThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Transport->Close(); + + if (m_RecvThread.joinable()) + { + m_RecvThread.join(); + } + if (m_PingThread.joinable()) + { + m_PingThread.join(); + } +} + +Ref<ComputeChannel> +ComputeSocket::CreateChannel(int ChannelId) +{ + ComputeBuffer::Params Params; + + ComputeBuffer RecvBuffer; + if (!RecvBuffer.CreateNew(Params)) + { + return {}; + } + + ComputeBuffer SendBuffer; + if (!SendBuffer.CreateNew(Params)) + { + return {}; + } + + Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); + + // Attach recv buffer writer (transport recv thread writes into this) + { + std::lock_guard<std::mutex> Lock(m_WritersMutex); + m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); + } + + // Attach send buffer reader (send thread reads from this) + { + ComputeBufferReader Reader = SendBuffer.CreateReader(); + m_Readers.push_back(Reader); + m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); + } + + return Channel; +} + +void +ComputeSocket::StartCommunication() +{ + m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); + m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); +} + +void +ComputeSocket::PingThreadProc() +{ + while (true) + { + { + std::unique_lock<std::mutex> Lock(m_PingMutex); + if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + { + break; + } + } + + std::lock_guard<std::mutex> Lock(m_SendMutex); + FrameHeader Header; + Header.Channel = 0; + Header.Size = ControlPing; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +void +ComputeSocket::RecvThreadProc() +{ + // Writers are cached locally to avoid taking m_WritersMutex on every frame. + // The shared m_Writers map is only accessed when a channel is seen for the first time. + std::unordered_map<int, ComputeBufferWriter> CachedWriters; + + FrameHeader Header; + while (m_Transport->RecvMessage(&Header, sizeof(Header))) + { + if (Header.Size >= 0) + { + // Data frame + auto It = CachedWriters.find(Header.Channel); + if (It == CachedWriters.end()) + { + std::lock_guard<std::mutex> Lock(m_WritersMutex); + auto WIt = m_Writers.find(Header.Channel); + if (WIt == m_Writers.end()) + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + // Skip the data + std::vector<uint8_t> Discard(Header.Size); + m_Transport->RecvMessage(Discard.data(), Header.Size); + continue; + } + It = CachedWriters.emplace(Header.Channel, WIt->second).first; + } + + ComputeBufferWriter& Writer = It->second; + uint8_t* Dest = Writer.WaitToWrite(Header.Size); + if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) + { + ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); + return; + } + Writer.AdvanceWritePosition(Header.Size); + } + else if (Header.Size == ControlDetach) + { + // Detach the recv buffer for this channel + CachedWriters.erase(Header.Channel); + + std::lock_guard<std::mutex> Lock(m_WritersMutex); + auto It = m_Writers.find(Header.Channel); + if (It != m_Writers.end()) + { + It->second.MarkComplete(); + m_Writers.erase(It); + } + } + else if (Header.Size == ControlPing) + { + // Ping response - ignore + } + else + { + ZEN_WARN("invalid frame header size: {}", Header.Size); + return; + } + } +} + +void +ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +{ + // Each channel has its own send thread. All send threads share m_SendMutex + // to serialize writes to the transport, since TCP requires atomic frame writes. + FrameHeader Header; + Header.Channel = Channel; + + const uint8_t* Data; + while ((Data = Reader.WaitToRead(1)) != nullptr) + { + std::lock_guard<std::mutex> Lock(m_SendMutex); + + Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); + m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->SendMessage(Data, Header.Size); + Reader.AdvanceReadPosition(Header.Size); + } + + if (Reader.IsComplete()) + { + std::lock_guard<std::mutex> Lock(m_SendMutex); + Header.Size = ControlDetach; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h new file mode 100644 index 000000000..0c3cb4195 --- /dev/null +++ b/src/zenhorde/hordecomputesocket.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" +#include "hordecomputechannel.h" +#include "hordetransport.h" + +#include <zencore/logbase.h> + +#include <condition_variable> +#include <memory> +#include <mutex> +#include <thread> +#include <unordered_map> +#include <vector> + +namespace zen::horde { + +/** Multiplexed socket that routes data between multiple channels over a single transport. + * + * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. + * A recv thread demultiplexes incoming frames to channel-specific buffers, while + * per-channel send threads multiplex outgoing data onto the shared transport. + * + * Wire format per frame: [channelId (4B)][size (4B)][data] + * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. + */ +class ComputeSocket +{ +public: + explicit ComputeSocket(std::unique_ptr<ComputeTransport> Transport); + ~ComputeSocket(); + + ComputeSocket(const ComputeSocket&) = delete; + ComputeSocket& operator=(const ComputeSocket&) = delete; + + /** Create a channel with the given ID. + * Allocates anonymous in-process buffers and spawns a send thread for the channel. */ + Ref<ComputeChannel> CreateChannel(int ChannelId); + + /** Start the recv pump and ping threads. Must be called after all channels are created. */ + void StartCommunication(); + +private: + struct FrameHeader + { + int32_t Channel = 0; + int32_t Size = 0; + }; + + static constexpr int32_t ControlDetach = -2; + static constexpr int32_t ControlPing = -3; + + LoggerRef Log() { return m_Log; } + + void RecvThreadProc(); + void SendThreadProc(int Channel, ComputeBufferReader Reader); + void PingThreadProc(); + + LoggerRef m_Log; + std::unique_ptr<ComputeTransport> m_Transport; + std::mutex m_SendMutex; ///< Serializes writes to the transport + + std::mutex m_WritersMutex; + std::unordered_map<int, ComputeBufferWriter> m_Writers; ///< Recv-side: writers keyed by channel ID + + std::vector<ComputeBufferReader> m_Readers; ///< Send-side: readers for join on destruction + std::unordered_map<int, std::thread> m_SendThreads; ///< One send thread per channel + + std::thread m_RecvThread; + std::thread m_PingThread; + + bool m_PingShouldStop = false; + std::mutex m_PingMutex; + std::condition_variable m_PingCV; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp new file mode 100644 index 000000000..2dca228d9 --- /dev/null +++ b/src/zenhorde/hordeconfig.cpp @@ -0,0 +1,89 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhorde/hordeconfig.h> + +namespace zen::horde { + +bool +HordeConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + // Relay mode implies AES encryption + if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) + { + return false; + } + + return true; +} + +const char* +ToString(ConnectionMode Mode) +{ + switch (Mode) + { + case ConnectionMode::Direct: + return "direct"; + case ConnectionMode::Tunnel: + return "tunnel"; + case ConnectionMode::Relay: + return "relay"; + } + return "direct"; +} + +const char* +ToString(Encryption Enc) +{ + switch (Enc) + { + case Encryption::None: + return "none"; + case Encryption::AES: + return "aes"; + } + return "none"; +} + +bool +FromString(ConnectionMode& OutMode, std::string_view Str) +{ + if (Str == "direct") + { + OutMode = ConnectionMode::Direct; + return true; + } + if (Str == "tunnel") + { + OutMode = ConnectionMode::Tunnel; + return true; + } + if (Str == "relay") + { + OutMode = ConnectionMode::Relay; + return true; + } + return false; +} + +bool +FromString(Encryption& OutEnc, std::string_view Str) +{ + if (Str == "none") + { + OutEnc = Encryption::None; + return true; + } + if (Str == "aes") + { + OutEnc = Encryption::AES; + return true; + } + return false; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp new file mode 100644 index 000000000..f88c95da2 --- /dev/null +++ b/src/zenhorde/hordeprovisioner.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhorde/hordeclient.h> +#include <zenhorde/hordeprovisioner.h> + +#include "hordeagent.h" +#include "hordebundle.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +#include <chrono> +#include <thread> + +namespace zen::horde { + +struct HordeProvisioner::AgentWrapper +{ + std::thread Thread; + std::atomic<bool> ShouldExit{false}; +}; + +HordeProvisioner::HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint) +: m_Config(Config) +, m_BinariesPath(BinariesPath) +, m_WorkingDir(WorkingDir) +, m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_Log(zen::logging::Get("horde.provisioner")) +{ +} + +HordeProvisioner::~HordeProvisioner() +{ + std::lock_guard<std::mutex> Lock(m_AgentsLock); + for (auto& Agent : m_Agents) + { + Agent->ShouldExit.store(true); + } + for (auto& Agent : m_Agents) + { + if (Agent->Thread.joinable()) + { + Agent->Thread.join(); + } + } +} + +void +HordeProvisioner::SetTargetCoreCount(uint32_t Count) +{ + ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); + + m_TargetCoreCount.store(std::min(Count, static_cast<uint32_t>(m_Config.MaxCores))); + + while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + { + if (!m_AskForAgents.load()) + { + return; + } + RequestAgent(); + } + + // Clean up finished agent threads + std::lock_guard<std::mutex> Lock(m_AgentsLock); + for (auto It = m_Agents.begin(); It != m_Agents.end();) + { + if ((*It)->ShouldExit.load()) + { + if ((*It)->Thread.joinable()) + { + (*It)->Thread.join(); + } + It = m_Agents.erase(It); + } + else + { + ++It; + } + } +} + +ProvisioningStats +HordeProvisioner::GetStats() const +{ + ProvisioningStats Stats; + Stats.TargetCoreCount = m_TargetCoreCount.load(); + Stats.EstimatedCoreCount = m_EstimatedCoreCount.load(); + Stats.ActiveCoreCount = m_ActiveCoreCount.load(); + Stats.AgentsActive = m_AgentsActive.load(); + Stats.AgentsRequesting = m_AgentsRequesting.load(); + return Stats; +} + +uint32_t +HordeProvisioner::GetAgentCount() const +{ + std::lock_guard<std::mutex> Lock(m_AgentsLock); + return static_cast<uint32_t>(m_Agents.size()); +} + +void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + + std::lock_guard<std::mutex> Lock(m_AgentsLock); + + auto Wrapper = std::make_unique<AgentWrapper>(); + AgentWrapper& Ref = *Wrapper; + Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + + m_Agents.push_back(std::move(Wrapper)); +} + +void +HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +{ + ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + + static std::atomic<uint32_t> ThreadIndex{0}; + const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + + zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + + std::unique_ptr<HordeAgent> Agent; + uint32_t MachineCoreCount = 0; + + auto _ = MakeGuard([&] { + if (Agent) + { + Agent->CloseConnection(); + } + Wrapper.ShouldExit.store(true); + }); + + { + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. + auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + + { + ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + + std::lock_guard<std::mutex> BundleLock(m_BundleLock); + + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector<BundleFile> Files; + +#if ZEN_PLATFORM_WINDOWS + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); +#elif ZEN_PLATFORM_LINUX + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); +#elif ZEN_PLATFORM_MAC + Files.emplace_back(m_BinariesPath / "zenserver", false); +#endif + + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + return; + } + + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique<HordeClient>(m_Config); + if (!m_HordeClient->Initialize()) + { + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + return; + } + } + } + + if (!m_AskForAgents.load()) + { + return; + } + + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. + + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) + { + const uint64_t WaitMs = 5000 - ElapsedMs; + for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + if (Wrapper.ShouldExit.load()) + { + return; + } + } + } + + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) + { + ZEN_WARN("failed to resolve cluster"); + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } + ClusterId = Cluster.ClusterId; + } + + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast<uint64_t>(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } + + m_LastRequestFailTime.store(0); + + if (Wrapper.ShouldExit.load()) + { + return; + } + + // Connect to agent and perform handshake + Agent = std::make_unique<HordeAgent>(Machine); + if (!Agent->IsValid()) + { + ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); + return; + } + + if (!Agent->BeginCommunication()) + { + ZEN_WARN("BeginCommunication failed"); + return; + } + + for (auto& [Locator, BundleDir] : m_Bundles) + { + if (Wrapper.ShouldExit.load()) + { + return; + } + + if (!Agent->UploadBinaries(BundleDir, Locator)) + { + ZEN_WARN("UploadBinaries failed"); + return; + } + } + + if (Wrapper.ShouldExit.load()) + { + return; + } + + // Build command line for remote zenserver + std::vector<std::string> ArgStrings; + ArgStrings.push_back("compute"); + ArgStrings.push_back("--http=asio"); + + // TEMP HACK - these should be made fully dynamic + // these are currently here to allow spawning the compute agent locally + // for debugging purposes (i.e with a local Horde Server+Agent setup) + ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + ArgStrings.push_back("--data-dir=c:\\temp\\123"); + + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + ArgStrings.emplace_back(CoordArg.ToView()); + } + + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + ArgStrings.emplace_back(IdArg.ToView()); + } + + std::vector<const char*> Args; + Args.reserve(ArgStrings.size()); + for (const std::string& Arg : ArgStrings) + { + Args.push_back(Arg.c_str()); + } + +#if ZEN_PLATFORM_WINDOWS + const bool UseWine = !Machine.IsWindows; + const char* AppName = "zenserver.exe"; +#else + const bool UseWine = false; + const char* AppName = "zenserver"; +#endif + + Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + Machine.GetConnectionAddress(), + Machine.GetConnectionPort(), + Machine.LeaseId); + + MachineCoreCount = Machine.LogicalCores; + m_EstimatedCoreCount.fetch_add(MachineCoreCount); + m_ActiveCoreCount.fetch_add(MachineCoreCount); + m_AgentsActive.fetch_add(1); + } + + // Agent poll loop + + auto ActiveGuard = MakeGuard([&]() { + m_EstimatedCoreCount.fetch_sub(MachineCoreCount); + m_ActiveCoreCount.fetch_sub(MachineCoreCount); + m_AgentsActive.fetch_sub(1); + }); + + while (Agent->IsValid() && !Wrapper.ShouldExit.load()) + { + const bool LogOutput = false; + if (!Agent->Poll(LogOutput)) + { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp new file mode 100644 index 000000000..69766e73e --- /dev/null +++ b/src/zenhorde/hordetransport.cpp @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransport.h" + +#include <zencore/logging.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +namespace zen::horde { + +// ComputeTransport base + +bool +ComputeTransport::SendMessage(const void* Data, size_t Size) +{ + const uint8_t* Ptr = static_cast<const uint8_t*>(Data); + size_t Remaining = Size; + + while (Remaining > 0) + { + const size_t Sent = Send(Ptr, Remaining); + if (Sent == 0) + { + return false; + } + Ptr += Sent; + Remaining -= Sent; + } + + return true; +} + +bool +ComputeTransport::RecvMessage(void* Data, size_t Size) +{ + uint8_t* Ptr = static_cast<uint8_t*>(Data); + size_t Remaining = Size; + + while (Remaining > 0) + { + const size_t Received = Recv(Ptr, Remaining); + if (Received == 0) + { + return false; + } + Ptr += Received; + Remaining -= Received; + } + + return true; +} + +// TcpComputeTransport - ASIO pimpl + +struct TcpComputeTransport::Impl +{ + asio::io_context IoContext; + asio::ip::tcp::socket Socket; + + Impl() : Socket(IoContext) {} +}; + +// Uses ASIO in synchronous mode only — no async operations or io_context::run(). +// The io_context is only needed because ASIO sockets require one to be constructed. +TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) +: m_Impl(std::make_unique<Impl>()) +, m_Log(zen::logging::Get("horde.transport")) +{ + ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + + asio::error_code Ec; + + const asio::ip::address Address = asio::ip::make_address(Info.GetConnectionAddress(), Ec); + if (Ec) + { + ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); + m_HasErrors = true; + return; + } + + const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); + + m_Impl->Socket.connect(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); + m_HasErrors = true; + return; + } + + // Disable Nagle's algorithm for lower latency + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); +} + +TcpComputeTransport::~TcpComputeTransport() +{ + Close(); +} + +bool +TcpComputeTransport::IsValid() const +{ + return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; +} + +size_t +TcpComputeTransport::Send(const void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + m_HasErrors = true; + return 0; + } + + return Sent; +} + +size_t +TcpComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + return 0; + } + + return Received; +} + +void +TcpComputeTransport::MarkComplete() +{ +} + +void +TcpComputeTransport::Close() +{ + if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) + { + asio::error_code Ec; + m_Impl->Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + m_Impl->Socket.close(Ec); + } + m_IsClosed = true; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.h b/src/zenhorde/hordetransport.h new file mode 100644 index 000000000..1b178dc0f --- /dev/null +++ b/src/zenhorde/hordetransport.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/hordeclient.h> + +#include <zencore/logbase.h> + +#include <cstddef> +#include <cstdint> +#include <memory> + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +namespace zen::horde { + +/** Abstract base interface for compute transports. + * + * Matches the UE FComputeTransport pattern. Concrete implementations handle + * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides + * blocking message helpers on top. + */ +class ComputeTransport +{ +public: + virtual ~ComputeTransport() = default; + + virtual bool IsValid() const = 0; + virtual size_t Send(const void* Data, size_t Size) = 0; + virtual size_t Recv(void* Data, size_t Size) = 0; + virtual void MarkComplete() = 0; + virtual void Close() = 0; + + /** Blocking send that loops until all bytes are transferred. Returns false on error. */ + bool SendMessage(const void* Data, size_t Size); + + /** Blocking receive that loops until all bytes are transferred. Returns false on error. */ + bool RecvMessage(void* Data, size_t Size); +}; + +/** TCP socket transport using ASIO. + * + * Connects to the Horde compute endpoint specified by MachineInfo and provides + * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the + * header clean. + */ +class TcpComputeTransport final : public ComputeTransport +{ +public: + explicit TcpComputeTransport(const MachineInfo& Info); + ~TcpComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + LoggerRef Log() { return m_Log; } + + struct Impl; + std::unique_ptr<Impl> m_Impl; + LoggerRef m_Log; + bool m_IsClosed = false; + bool m_HasErrors = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp new file mode 100644 index 000000000..986dd3705 --- /dev/null +++ b/src/zenhorde/hordetransportaes.cpp @@ -0,0 +1,425 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransportaes.h" + +#include <zencore/logging.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <cstring> +#include <random> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <bcrypt.h> +# pragma comment(lib, "Bcrypt.lib") +#else +ZEN_THIRD_PARTY_INCLUDES_START +# include <openssl/evp.h> +# include <openssl/err.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +namespace zen::horde { + +struct AesComputeTransport::CryptoContext +{ + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; + +#if !ZEN_PLATFORM_WINDOWS + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif + + CryptoContext(const uint8_t (&InKey)[KeySize]) + { + memcpy(Key, InKey, KeySize); + + // The encrypt nonce is randomly initialized and then deterministically mutated + // per message via UpdateNonce(). The decrypt nonce is not used — it comes from + // the wire (each received message carries its own nonce in the header). + std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution<int> Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast<uint8_t>(Dist(Gen)); + } + +#if !ZEN_PLATFORM_WINDOWS + // Drain any stale OpenSSL errors + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); +#endif + } + + ~CryptoContext() + { +#if ZEN_PLATFORM_WINDOWS + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif + } + + void UpdateNonce() + { + uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; + } + + // Returns total encrypted message size, or 0 on failure + // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); + + // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than + // caching but has some overhead. For our use case (relatively large, infrequent messages) + // this is acceptable. +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = + BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + // Write header: length + nonce + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + // Write tag after ciphertext + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; +#else + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + // Write length + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + // Write nonce + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + // Encrypt + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + // Finalize + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + // Get tag + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif + } + + // Decrypt a message. Returns decrypted data length, or 0 on failure. + // Input must be [ciphertext][tag], with nonce provided separately. + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) + { +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } + + return static_cast<int32_t>(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + // Set the tag for verification + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif + } +}; + +AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +: m_Crypto(std::make_unique<CryptoContext>(Key)) +, m_Inner(std::move(InnerTransport)) +{ +} + +AesComputeTransport::~AesComputeTransport() +{ + Close(); +} + +bool +AesComputeTransport::IsValid() const +{ + return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; +} + +size_t +AesComputeTransport::Send(const void* Data, size_t Size) +{ + ZEN_TRACE_CPU("AesComputeTransport::Send"); + + if (!IsValid()) + { + return 0; + } + + std::lock_guard<std::mutex> Lock(m_Lock); + + const int32_t DataLength = static_cast<int32_t>(Size); + const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + if (EncryptedLen == 0) + { + return 0; + } + + if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) + { + return 0; + } + + return Size; +} + +size_t +AesComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes + // than the decrypted message contains. Excess bytes are buffered in m_RemainingData + // and returned on subsequent Recv calls without another decryption round-trip. + ZEN_TRACE_CPU("AesComputeTransport::Recv"); + + std::lock_guard<std::mutex> Lock(m_Lock); + + if (!m_RemainingData.empty()) + { + const size_t Available = m_RemainingData.size() - m_RemainingOffset; + const size_t ToCopy = std::min(Available, Size); + + memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + m_RemainingOffset += ToCopy; + + if (m_RemainingOffset >= m_RemainingData.size()) + { + m_RemainingData.clear(); + m_RemainingOffset = 0; + } + + return ToCopy; + } + + // Receive packet header: [length(4B)][nonce(12B)] + struct PacketHeader + { + int32_t DataLength = 0; + uint8_t Nonce[NonceBytes] = {}; + } Header; + + if (!m_Inner->RecvMessage(&Header, sizeof(Header))) + { + return 0; + } + + // Validate DataLength to prevent OOM from malicious/corrupt peers + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB + + if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) + { + ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); + return 0; + } + + // Receive ciphertext + tag + const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + { + return 0; + } + + // Decrypt + const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); + + // We need a temporary buffer for decryption if we can't decrypt directly into output + std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); + + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + if (Decrypted == 0) + { + return 0; + } + + memcpy(Data, DecryptedBuf.data(), BytesToReturn); + + // Store remaining data if we couldn't return everything + if (static_cast<size_t>(Header.DataLength) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + } + + return BytesToReturn; +} + +void +AesComputeTransport::MarkComplete() +{ + if (IsValid()) + { + m_Inner->MarkComplete(); + } +} + +void +AesComputeTransport::Close() +{ + if (!m_IsClosed) + { + if (m_Inner && m_Inner->IsValid()) + { + m_Inner->Close(); + } + m_IsClosed = true; + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h new file mode 100644 index 000000000..efcad9835 --- /dev/null +++ b/src/zenhorde/hordetransportaes.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordetransport.h" + +#include <cstdint> +#include <memory> +#include <mutex> +#include <vector> + +namespace zen::horde { + +/** AES-256-GCM encrypted transport wrapper. + * + * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting + * all incoming data using AES-256-GCM. The nonce is mutated per message using + * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * + * Wire format per encrypted message: + * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] + * + * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). + */ +class AesComputeTransport final : public ComputeTransport +{ +public: + AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport); + ~AesComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size + static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + + struct CryptoContext; + + std::unique_ptr<CryptoContext> m_Crypto; + std::unique_ptr<ComputeTransport> m_Inner; + std::vector<uint8_t> m_EncryptBuffer; + std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv + size_t m_RemainingOffset = 0; + std::mutex m_Lock; + bool m_IsClosed = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h new file mode 100644 index 000000000..201d68b83 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/hordeconfig.h> + +#include <zencore/logbase.h> + +#include <cstdint> +#include <map> +#include <memory> +#include <string> +#include <vector> + +namespace zen { +class HttpClient; +} + +namespace zen::horde { + +static constexpr size_t NonceSize = 64; +static constexpr size_t KeySize = 32; + +/** Port mapping information returned by Horde for a provisioned machine. */ +struct PortInfo +{ + uint16_t Port = 0; + uint16_t AgentPort = 0; +}; + +/** Describes a provisioned compute machine returned by the Horde API. + * + * Contains the network address, encryption credentials, and capabilities + * needed to establish a compute transport connection to the machine. + */ +struct MachineInfo +{ + std::string Ip; + ConnectionMode Mode = ConnectionMode::Direct; + std::string ConnectionAddress; ///< Relay/tunnel address (used when Mode != Direct) + uint16_t Port = 0; + uint16_t LogicalCores = 0; + Encryption EncryptionMode = Encryption::None; + uint8_t Nonce[NonceSize] = {}; ///< 64-byte nonce sent during TCP handshake + uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) + bool IsWindows = false; + std::string LeaseId; + + std::map<std::string, PortInfo> Ports; + + /** Return the address to connect to, accounting for connection mode. */ + const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + + /** Return the port to connect to, accounting for connection mode and port mapping. */ + uint16_t GetConnectionPort() const + { + if (Mode == ConnectionMode::Relay) + { + auto It = Ports.find("_horde_compute"); + if (It != Ports.end()) + { + return It->second.Port; + } + } + return Port; + } + + bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } +}; + +/** Result of cluster auto-resolution via the Horde API. */ +struct ClusterInfo +{ + std::string ClusterId = "default"; +}; + +/** HTTP client for the Horde compute REST API. + * + * Handles cluster resolution and machine provisioning requests. Each call + * is synchronous and returns success/failure. Thread safety: individual + * methods are not thread-safe; callers must synchronize access. + */ +class HordeClient +{ +public: + explicit HordeClient(const HordeConfig& Config); + ~HordeClient(); + + HordeClient(const HordeClient&) = delete; + HordeClient& operator=(const HordeClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the JSON request body for cluster resolution and machine requests. + * Encodes pool, condition, connection mode, encryption, and port requirements. */ + std::string BuildRequestBody() const; + + /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ + bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + + /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. + * On success, populates OutMachine with connection details and credentials. */ + bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + + LoggerRef Log() { return m_Log; } + +private: + bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); + + HordeConfig m_Config; + std::unique_ptr<zen::HttpClient> m_Http; + LoggerRef m_Log; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h new file mode 100644 index 000000000..dd70f9832 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/zenhorde.h> + +#include <string> + +namespace zen::horde { + +/** Transport connection mode for Horde compute agents. */ +enum class ConnectionMode +{ + Direct, ///< Connect directly to the agent IP + Tunnel, ///< Connect through a Horde tunnel relay + Relay, ///< Connect through a Horde relay with port mapping +}; + +/** Transport encryption mode for Horde compute channels. */ +enum class Encryption +{ + None, ///< No encryption + AES, ///< AES-256-GCM encryption (required for Relay mode) +}; + +/** Configuration for connecting to an Epic Horde compute cluster. + * + * Specifies the Horde server URL, authentication token, pool selection, + * connection mode, and resource limits. Used by HordeClient and HordeProvisioner. + */ +struct HordeConfig +{ + static constexpr const char* ClusterDefault = "default"; + static constexpr const char* ClusterAuto = "_auto"; + + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; + + /** Validate the configuration. Returns false if the configuration is invalid + * (e.g. Relay mode without AES encryption). */ + bool Validate() const; +}; + +const char* ToString(ConnectionMode Mode); +const char* ToString(Encryption Enc); + +bool FromString(ConnectionMode& OutMode, std::string_view Str); +bool FromString(Encryption& OutEnc, std::string_view Str); + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h new file mode 100644 index 000000000..4e2e63bbd --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -0,0 +1,110 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhorde/hordeconfig.h> + +#include <zencore/logbase.h> + +#include <atomic> +#include <cstdint> +#include <filesystem> +#include <memory> +#include <mutex> +#include <string> +#include <vector> + +namespace zen::horde { + +class HordeClient; + +/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */ +struct ProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected once pending requests complete + uint32_t ActiveCoreCount = 0; ///< Cores on machines that are currently running zenserver + uint32_t AgentsActive = 0; ///< Number of agents with a running remote process + uint32_t AgentsRequesting = 0; ///< Number of agents currently requesting a machine from Horde +}; + +/** Multi-agent lifecycle manager for Horde worker provisioning. + * + * Provisions remote compute workers by requesting machines from the Horde API, + * connecting via the Horde compute transport protocol, uploading the zenserver + * binary, and executing it remotely. Each provisioned machine runs zenserver + * in compute mode, which announces itself back to the orchestrator. + * + * Spawns one thread per agent. Each thread handles the full lifecycle: + * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> + * channel setup -> binary upload -> remote execution -> poll until exit. + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class HordeProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Horde connection and pool configuration. + * @param BinariesPath Directory containing the zenserver binary to upload. + * @param WorkingDir Local directory for bundle staging and working files. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */ + HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint); + + /** Signals all agent threads to exit and joins them. */ + ~HordeProvisioner(); + + HordeProvisioner(const HordeProvisioner&) = delete; + HordeProvisioner& operator=(const HordeProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the + * estimated core count is below the target. Also joins any finished + * agent threads. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + ProvisioningStats GetStats() const; + + uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const; + +private: + LoggerRef Log() { return m_Log; } + + struct AgentWrapper; + + void RequestAgent(); + void ThreadAgent(AgentWrapper& Wrapper); + + HordeConfig m_Config; + std::filesystem::path m_BinariesPath; + std::filesystem::path m_WorkingDir; + std::string m_OrchestratorEndpoint; + + std::unique_ptr<HordeClient> m_HordeClient; + + std::mutex m_BundleLock; + std::vector<std::pair<std::string, std::filesystem::path>> m_Bundles; ///< (locator, bundleDir) pairs + bool m_BundlesCreated = false; + + mutable std::mutex m_AgentsLock; + std::vector<std::unique_ptr<AgentWrapper>> m_Agents; + + std::atomic<uint64_t> m_LastRequestFailTime{0}; + std::atomic<uint32_t> m_TargetCoreCount{0}; + std::atomic<uint32_t> m_EstimatedCoreCount{0}; + std::atomic<uint32_t> m_ActiveCoreCount{0}; + std::atomic<uint32_t> m_AgentsActive{0}; + std::atomic<uint32_t> m_AgentsRequesting{0}; + std::atomic<bool> m_AskForAgents{true}; + + LoggerRef m_Log; + + static constexpr uint32_t EstimatedCoresPerAgent = 32; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/zenhorde.h b/src/zenhorde/include/zenhorde/zenhorde.h new file mode 100644 index 000000000..35147ff75 --- /dev/null +++ b/src/zenhorde/include/zenhorde/zenhorde.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_HORDE) +# define ZEN_WITH_HORDE 1 +#endif diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua new file mode 100644 index 000000000..48d028e86 --- /dev/null +++ b/src/zenhorde/xmake.lua @@ -0,0 +1,22 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenhorde') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zencompute", "zenutil") + add_packages("asio", "json11") + + if is_plat("windows") then + add_syslinks("Ws2_32", "Bcrypt") + end + + if is_plat("linux") or is_plat("macosx") then + add_packages("openssl") + end + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end |