// Copyright Epic Games, Inc. All Rights Reserved. #include "hordeagentmessage.h" #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #include #include #include #include namespace zen::horde { // --- AsyncAgentMessageChannel --- AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr Socket, int ChannelId, asio::io_context& IoContext) : m_Socket(std::move(Socket)) , m_ChannelId(ChannelId) , m_IoContext(IoContext) , m_TimeoutTimer(std::make_unique(m_Socket->GetStrand())) { } AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { if (m_TimeoutTimer) { m_TimeoutTimer->cancel(); } } // --- Message building helpers --- std::vector AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { std::vector Buf; Buf.reserve(MessageHeaderLength + ReservePayload); Buf.push_back(static_cast(Type)); Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder return Buf; } void AsyncAgentMessageChannel::FinalizeAndSend(std::vector Msg) { const uint32_t PayloadSize = static_cast(Msg.size() - MessageHeaderLength); memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void AsyncAgentMessageChannel::WriteInt32(std::vector& Buf, int Value) { const uint8_t* Ptr = reinterpret_cast(&Value); Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } int AsyncAgentMessageChannel::ReadInt32(ReadCursor& C) { if (!C.CheckAvailable(sizeof(int32_t))) { return 0; } int32_t Value; memcpy(&Value, C.Pos, sizeof(int32_t)); C.Pos += sizeof(int32_t); return Value; } void AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector& Buf, const uint8_t* Data, size_t Length) { Buf.insert(Buf.end(), Data, Data + Length); } const uint8_t* AsyncAgentMessageChannel::ReadFixedLengthBytes(ReadCursor& C, size_t Length) { if (!C.CheckAvailable(Length)) { return nullptr; } const uint8_t* Data = C.Pos; C.Pos += Length; return Data; } size_t AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) { if (Value == 0) { return 1; } return (FloorLog2_64(static_cast(Value)) / 7) + 1; } void AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector& Buf, size_t Value) { const size_t ByteCount = MeasureUnsignedVarInt(Value); const size_t StartPos = Buf.size(); Buf.resize(StartPos + ByteCount); uint8_t* Output = Buf.data() + StartPos; for (size_t i = 1; i < ByteCount; ++i) { Output[ByteCount - i] = static_cast(Value); Value >>= 8; } Output[0] = static_cast((0xFF << (9 - static_cast(ByteCount))) | static_cast(Value)); } size_t AsyncAgentMessageChannel::ReadUnsignedVarInt(ReadCursor& C) { // Need at least the leading byte to determine the encoded length. if (!C.CheckAvailable(1)) { return 0; } const uint8_t FirstByte = C.Pos[0]; const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast(FirstByte))) + 1 - 24; // The encoded length implied by the leading 0xFF-run may be 1..9 bytes; ensure the remaining bytes are in-bounds. if (!C.CheckAvailable(NumBytes)) { return 0; } size_t Value = static_cast(FirstByte & (0xFF >> NumBytes)); for (size_t i = 1; i < NumBytes; ++i) { Value <<= 8; Value |= C.Pos[i]; } C.Pos += NumBytes; return Value; } void AsyncAgentMessageChannel::WriteString(std::vector& Buf, const char* Text) { const size_t Length = strlen(Text); WriteUnsignedVarInt(Buf, Length); WriteFixedLengthBytes(Buf, reinterpret_cast(Text), Length); } void AsyncAgentMessageChannel::WriteString(std::vector& Buf, std::string_view Text) { WriteUnsignedVarInt(Buf, Text.size()); WriteFixedLengthBytes(Buf, reinterpret_cast(Text.data()), Text.size()); } std::string_view AsyncAgentMessageChannel::ReadString(ReadCursor& C) { const size_t Length = ReadUnsignedVarInt(C); const uint8_t* Start = ReadFixedLengthBytes(C, Length); if (C.ParseError || !Start) { return {}; } return std::string_view(reinterpret_cast(Start), Length); } void AsyncAgentMessageChannel::WriteOptionalString(std::vector& Buf, const char* Text) { if (!Text) { WriteUnsignedVarInt(Buf, 0); } else { const size_t Length = strlen(Text); WriteUnsignedVarInt(Buf, Length + 1); WriteFixedLengthBytes(Buf, reinterpret_cast(Text), Length); } } // --- Send methods --- void AsyncAgentMessageChannel::Close() { auto Msg = BeginMessage(AgentMessageType::None, 0); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Ping() { auto Msg = BeginMessage(AgentMessageType::Ping, 0); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); WriteInt32(Msg, ChannelId); WriteInt32(Msg, BufferSize); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Attach() { auto Msg = BeginMessage(AgentMessageType::Attach, 0); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); WriteString(Msg, Path); WriteString(Msg, Locator); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Execute(const char* Exe, const char* const* Args, size_t NumArgs, const char* WorkingDir, const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags) { size_t ReserveSize = 50 + strlen(Exe); for (size_t i = 0; i < NumArgs; ++i) { ReserveSize += strlen(Args[i]) + 10; } if (WorkingDir) { ReserveSize += strlen(WorkingDir) + 10; } for (size_t i = 0; i < NumEnvVars; ++i) { ReserveSize += strlen(EnvVars[i]) + 20; } auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); WriteString(Msg, Exe); WriteUnsignedVarInt(Msg, NumArgs); for (size_t i = 0; i < NumArgs; ++i) { WriteString(Msg, Args[i]); } WriteOptionalString(Msg, WorkingDir); WriteUnsignedVarInt(Msg, NumEnvVars); for (size_t i = 0; i < NumEnvVars; ++i) { const char* Eq = strchr(EnvVars[i], '='); if (Eq == nullptr) { // assert() would be compiled out in release and leave *(Eq+1) as UB - // refuse to build the message for a malformed KEY=VALUE string instead. throw zen::runtime_error("horde agent env var at index {} missing '=' separator", i); } WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); if (*(Eq + 1) == '\0') { WriteOptionalString(Msg, nullptr); } else { WriteOptionalString(Msg, Eq + 1); } } WriteInt32(Msg, static_cast(Flags)); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { static constexpr size_t MaxBlobChunkSize = 512 * 1024; // The Horde ReadBlobResponse wire format encodes both the chunk Offset and the total // Length as int32. Lengths of 2 GiB or more would wrap to negative and confuse the // remote parser. Refuse the send rather than produce a protocol violation. if (Length > static_cast(std::numeric_limits::max())) { throw zen::runtime_error("horde ReadBlobResponse length {} exceeds int32 wire limit", Length); } for (size_t ChunkOffset = 0; ChunkOffset < Length;) { const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); WriteInt32(Msg, static_cast(ChunkOffset)); WriteInt32(Msg, static_cast(Length)); WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); FinalizeAndSend(std::move(Msg)); ChunkOffset += ChunkLength; } } // --- Async response reading --- void AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { // Serialize all access to m_IncomingFrames / m_PendingHandler / m_TimeoutTimer onto // the socket's strand; OnFrame/OnDetach also run on that strand. Without this, the // timer wait completion would run on a bare io_context thread (3 concurrent run() // loops in the provisioner) and race with OnFrame on m_PendingHandler. asio::dispatch(m_Socket->GetStrand(), [this, TimeoutMs, Handler = std::move(Handler)]() mutable { if (!m_IncomingFrames.empty()) { std::vector Frame = std::move(m_IncomingFrames.front()); m_IncomingFrames.pop_front(); if (Frame.size() >= MessageHeaderLength) { AgentMessageType Type = static_cast(Frame[0]); const uint8_t* Data = Frame.data() + MessageHeaderLength; size_t Size = Frame.size() - MessageHeaderLength; asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { // The Frame is captured to keep Data pointer valid Handler(Type, Data, Size); }); } else { asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); } return; } if (m_Detached) { asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); return; } // No frames queued - store pending handler and arm timeout m_PendingHandler = std::move(Handler); if (TimeoutMs >= 0) { m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); m_TimeoutTimer->async_wait(asio::bind_executor(m_Socket->GetStrand(), [this](const asio::error_code& Ec) { if (Ec) { return; // Cancelled - frame arrived before timeout } // Already on the strand: safe to mutate m_PendingHandler. if (m_PendingHandler) { AsyncResponseHandler Handler = std::move(m_PendingHandler); m_PendingHandler = nullptr; Handler(AgentMessageType::None, nullptr, 0); } })); } }); } void AsyncAgentMessageChannel::OnFrame(std::vector Data) { if (m_PendingHandler) { // Cancel the timeout timer m_TimeoutTimer->cancel(); AsyncResponseHandler Handler = std::move(m_PendingHandler); m_PendingHandler = nullptr; if (Data.size() >= MessageHeaderLength) { AgentMessageType Type = static_cast(Data[0]); const uint8_t* Payload = Data.data() + MessageHeaderLength; size_t PayloadSize = Data.size() - MessageHeaderLength; Handler(Type, Payload, PayloadSize); } else { Handler(AgentMessageType::None, nullptr, 0); } } else { m_IncomingFrames.push_back(std::move(Data)); } } void AsyncAgentMessageChannel::OnDetach() { m_Detached = true; if (m_PendingHandler) { m_TimeoutTimer->cancel(); AsyncResponseHandler Handler = std::move(m_PendingHandler); m_PendingHandler = nullptr; Handler(AgentMessageType::None, nullptr, 0); } } // --- Response parsing helpers --- bool AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex) { ReadCursor C{Data, Data + Size, false}; Ex.Message = ReadString(C); Ex.Description = ReadString(C); if (C.ParseError) { Ex = {}; return false; } return true; } bool AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode) { ReadCursor C{Data, Data + Size, false}; OutExitCode = ReadInt32(C); return !C.ParseError; } static bool IsSafeLocator(std::string_view Locator) { // Reject empty, overlong, path-separator-containing, parent-relative, absolute, or // control-character-containing locators. The locator is used as a filename component // joined with a trusted BundleDir, so the only safe characters are a restricted // filename alphabet. if (Locator.empty() || Locator.size() > 255) { return false; } if (Locator == "." || Locator == "..") { return false; } for (char Ch : Locator) { const unsigned char U = static_cast(Ch); if (U < 0x20 || U == 0x7F) { return false; // control / NUL / DEL } if (Ch == '/' || Ch == '\\' || Ch == ':') { return false; // path separators / drive letters } } // Disallow leading/trailing dot or whitespace (Windows quirks + hidden-file dodges) if (Locator.front() == '.' || Locator.front() == ' ' || Locator.back() == '.' || Locator.back() == ' ') { return false; } return true; } bool AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req) { ReadCursor C{Data, Data + Size, false}; Req.Locator = ReadString(C); Req.Offset = ReadUnsignedVarInt(C); Req.Length = ReadUnsignedVarInt(C); if (C.ParseError || !IsSafeLocator(Req.Locator)) { Req = {}; return false; } return true; } } // namespace zen::horde