// Copyright Epic Games, Inc. All Rights Reserved. #include "hordeagentmessage.h" #include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END #include #include namespace zen::horde { // --- AsyncAgentMessageChannel --- AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr Socket, int ChannelId, asio::io_context& IoContext) : m_Socket(std::move(Socket)) , m_ChannelId(ChannelId) , m_IoContext(IoContext) , m_TimeoutTimer(std::make_unique(IoContext)) { } AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { if (m_TimeoutTimer) { m_TimeoutTimer->cancel(); } } // --- Message building helpers --- std::vector AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { std::vector Buf; Buf.reserve(MessageHeaderLength + ReservePayload); Buf.push_back(static_cast(Type)); Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder return Buf; } void AsyncAgentMessageChannel::FinalizeAndSend(std::vector Msg) { const uint32_t PayloadSize = static_cast(Msg.size() - MessageHeaderLength); memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void AsyncAgentMessageChannel::WriteInt32(std::vector& Buf, int Value) { const uint8_t* Ptr = reinterpret_cast(&Value); Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } int AsyncAgentMessageChannel::ReadInt32(const uint8_t** Pos) { int Value; memcpy(&Value, *Pos, sizeof(int)); *Pos += sizeof(int); return Value; } void AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector& Buf, const uint8_t* Data, size_t Length) { Buf.insert(Buf.end(), Data, Data + Length); } const uint8_t* AsyncAgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) { const uint8_t* Data = *Pos; *Pos += Length; return Data; } size_t AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) { if (Value == 0) { return 1; } return (FloorLog2_64(static_cast(Value)) / 7) + 1; } void AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector& Buf, size_t Value) { const size_t ByteCount = MeasureUnsignedVarInt(Value); const size_t StartPos = Buf.size(); Buf.resize(StartPos + ByteCount); uint8_t* Output = Buf.data() + StartPos; for (size_t i = 1; i < ByteCount; ++i) { Output[ByteCount - i] = static_cast(Value); Value >>= 8; } Output[0] = static_cast((0xFF << (9 - static_cast(ByteCount))) | static_cast(Value)); } size_t AsyncAgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) { const uint8_t* Data = *Pos; const uint8_t FirstByte = Data[0]; const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast(FirstByte))) + 1 - 24; size_t Value = static_cast(FirstByte & (0xFF >> NumBytes)); for (size_t i = 1; i < NumBytes; ++i) { Value <<= 8; Value |= Data[i]; } *Pos += NumBytes; return Value; } void AsyncAgentMessageChannel::WriteString(std::vector& Buf, const char* Text) { const size_t Length = strlen(Text); WriteUnsignedVarInt(Buf, Length); WriteFixedLengthBytes(Buf, reinterpret_cast(Text), Length); } void AsyncAgentMessageChannel::WriteString(std::vector& Buf, std::string_view Text) { WriteUnsignedVarInt(Buf, Text.size()); WriteFixedLengthBytes(Buf, reinterpret_cast(Text.data()), Text.size()); } std::string_view AsyncAgentMessageChannel::ReadString(const uint8_t** Pos) { const size_t Length = ReadUnsignedVarInt(Pos); const char* Start = reinterpret_cast(ReadFixedLengthBytes(Pos, Length)); return std::string_view(Start, Length); } void AsyncAgentMessageChannel::WriteOptionalString(std::vector& Buf, const char* Text) { if (!Text) { WriteUnsignedVarInt(Buf, 0); } else { const size_t Length = strlen(Text); WriteUnsignedVarInt(Buf, Length + 1); WriteFixedLengthBytes(Buf, reinterpret_cast(Text), Length); } } // --- Send methods --- void AsyncAgentMessageChannel::Close() { auto Msg = BeginMessage(AgentMessageType::None, 0); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Ping() { auto Msg = BeginMessage(AgentMessageType::Ping, 0); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); WriteInt32(Msg, ChannelId); WriteInt32(Msg, BufferSize); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Attach() { auto Msg = BeginMessage(AgentMessageType::Attach, 0); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); WriteString(Msg, Path); WriteString(Msg, Locator); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Execute(const char* Exe, const char* const* Args, size_t NumArgs, const char* WorkingDir, const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags) { size_t ReserveSize = 50 + strlen(Exe); for (size_t i = 0; i < NumArgs; ++i) { ReserveSize += strlen(Args[i]) + 10; } if (WorkingDir) { ReserveSize += strlen(WorkingDir) + 10; } for (size_t i = 0; i < NumEnvVars; ++i) { ReserveSize += strlen(EnvVars[i]) + 20; } auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); WriteString(Msg, Exe); WriteUnsignedVarInt(Msg, NumArgs); for (size_t i = 0; i < NumArgs; ++i) { WriteString(Msg, Args[i]); } WriteOptionalString(Msg, WorkingDir); WriteUnsignedVarInt(Msg, NumEnvVars); for (size_t i = 0; i < NumEnvVars; ++i) { const char* Eq = strchr(EnvVars[i], '='); assert(Eq != nullptr); WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); if (*(Eq + 1) == '\0') { WriteOptionalString(Msg, nullptr); } else { WriteOptionalString(Msg, Eq + 1); } } WriteInt32(Msg, static_cast(Flags)); FinalizeAndSend(std::move(Msg)); } void AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { static constexpr size_t MaxBlobChunkSize = 512 * 1024; for (size_t ChunkOffset = 0; ChunkOffset < Length;) { const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); WriteInt32(Msg, static_cast(ChunkOffset)); WriteInt32(Msg, static_cast(Length)); WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); FinalizeAndSend(std::move(Msg)); ChunkOffset += ChunkLength; } } // --- Async response reading --- void AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { // If frames are already queued, dispatch immediately if (!m_IncomingFrames.empty()) { std::vector Frame = std::move(m_IncomingFrames.front()); m_IncomingFrames.pop_front(); if (Frame.size() >= MessageHeaderLength) { AgentMessageType Type = static_cast(Frame[0]); const uint8_t* Data = Frame.data() + MessageHeaderLength; size_t Size = Frame.size() - MessageHeaderLength; asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { // The Frame is captured to keep Data pointer valid Handler(Type, Data, Size); }); } else { asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); } return; } if (m_Detached) { asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); return; } // No frames queued - store pending handler and arm timeout m_PendingHandler = std::move(Handler); if (TimeoutMs >= 0) { m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); m_TimeoutTimer->async_wait([this](const asio::error_code& Ec) { if (Ec) { return; // Cancelled - frame arrived before timeout } if (m_PendingHandler) { AsyncResponseHandler Handler = std::move(m_PendingHandler); m_PendingHandler = nullptr; Handler(AgentMessageType::None, nullptr, 0); } }); } } void AsyncAgentMessageChannel::OnFrame(std::vector Data) { if (m_PendingHandler) { // Cancel the timeout timer m_TimeoutTimer->cancel(); AsyncResponseHandler Handler = std::move(m_PendingHandler); m_PendingHandler = nullptr; if (Data.size() >= MessageHeaderLength) { AgentMessageType Type = static_cast(Data[0]); const uint8_t* Payload = Data.data() + MessageHeaderLength; size_t PayloadSize = Data.size() - MessageHeaderLength; Handler(Type, Payload, PayloadSize); } else { Handler(AgentMessageType::None, nullptr, 0); } } else { m_IncomingFrames.push_back(std::move(Data)); } } void AsyncAgentMessageChannel::OnDetach() { m_Detached = true; if (m_PendingHandler) { m_TimeoutTimer->cancel(); AsyncResponseHandler Handler = std::move(m_PendingHandler); m_PendingHandler = nullptr; Handler(AgentMessageType::None, nullptr, 0); } } // --- Response parsing helpers --- void AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t /*Size*/, ExceptionInfo& Ex) { const uint8_t* Pos = Data; Ex.Message = ReadString(&Pos); Ex.Description = ReadString(&Pos); } int AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t /*Size*/) { const uint8_t* Pos = Data; return ReadInt32(&Pos); } void AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t /*Size*/, BlobRequest& Req) { const uint8_t* Pos = Data; Req.Locator = ReadString(&Pos); Req.Offset = ReadUnsignedVarInt(&Pos); Req.Length = ReadUnsignedVarInt(&Pos); } } // namespace zen::horde