// Copyright Epic Games, Inc. All Rights Reserved. #include "hordeagentmessage.h" #include #include #include namespace zen::horde { AgentMessageChannel::AgentMessageChannel(Ref Channel) : m_Channel(std::move(Channel)) { } AgentMessageChannel::~AgentMessageChannel() = default; void AgentMessageChannel::Close() { CreateMessage(AgentMessageType::None, 0); FlushMessage(); } void AgentMessageChannel::Ping() { CreateMessage(AgentMessageType::Ping, 0); FlushMessage(); } void AgentMessageChannel::Fork(int ChannelId, int BufferSize) { CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); WriteInt32(ChannelId); WriteInt32(BufferSize); FlushMessage(); } void AgentMessageChannel::Attach() { CreateMessage(AgentMessageType::Attach, 0); FlushMessage(); } void AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); WriteString(Path); WriteString(Locator); FlushMessage(); } void AgentMessageChannel::Execute(const char* Exe, const char* const* Args, size_t NumArgs, const char* WorkingDir, const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags) { size_t RequiredSize = 50 + strlen(Exe); for (size_t i = 0; i < NumArgs; ++i) { RequiredSize += strlen(Args[i]) + 10; } if (WorkingDir) { RequiredSize += strlen(WorkingDir) + 10; } for (size_t i = 0; i < NumEnvVars; ++i) { RequiredSize += strlen(EnvVars[i]) + 20; } CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); WriteString(Exe); WriteUnsignedVarInt(NumArgs); for (size_t i = 0; i < NumArgs; ++i) { WriteString(Args[i]); } WriteOptionalString(WorkingDir); // ExecuteV2 protocol requires env vars as separate key/value pairs. // Callers pass "KEY=VALUE" strings; we split on the first '=' here. WriteUnsignedVarInt(NumEnvVars); for (size_t i = 0; i < NumEnvVars; ++i) { const char* Eq = strchr(EnvVars[i], '='); assert(Eq != nullptr); WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); if (*(Eq + 1) == '\0') { WriteOptionalString(nullptr); } else { WriteOptionalString(Eq + 1); } } WriteInt32(static_cast(Flags)); FlushMessage(); } void AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { // Blob responses are chunked to fit within the compute buffer's chunk size. // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; for (size_t ChunkOffset = 0; ChunkOffset < Length;) { const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); WriteInt32(static_cast(ChunkOffset)); WriteInt32(static_cast(Length)); WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); FlushMessage(); ChunkOffset += ChunkLength; } } AgentMessageType AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) { // Deferred advance: the previous response's buffer is only released when the next // ReadResponse is called. This allows callers to read response data between calls // without copying, since the pointer comes directly from the ring buffer. if (m_ResponseData) { m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); m_ResponseData = nullptr; m_ResponseLength = 0; } const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); if (!Header) { return AgentMessageType::None; } uint32_t Length; memcpy(&Length, Header + 1, sizeof(uint32_t)); Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); if (!Header) { return AgentMessageType::None; } m_ResponseType = static_cast(Header[0]); m_ResponseData = Header + MessageHeaderLength; m_ResponseLength = Length; return m_ResponseType; } void AgentMessageChannel::ReadException(ExceptionInfo& Ex) { assert(m_ResponseType == AgentMessageType::Exception); const uint8_t* Pos = m_ResponseData; Ex.Message = ReadString(&Pos); Ex.Description = ReadString(&Pos); } int AgentMessageChannel::ReadExecuteResult() { assert(m_ResponseType == AgentMessageType::ExecuteResult); const uint8_t* Pos = m_ResponseData; return ReadInt32(&Pos); } void AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) { assert(m_ResponseType == AgentMessageType::ReadBlob); const uint8_t* Pos = m_ResponseData; Req.Locator = ReadString(&Pos); Req.Offset = ReadUnsignedVarInt(&Pos); Req.Length = ReadUnsignedVarInt(&Pos); } void AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) { m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); m_RequestData[0] = static_cast(Type); m_MaxRequestSize = MaxLength; m_RequestSize = 0; } void AgentMessageChannel::FlushMessage() { const uint32_t Size = static_cast(m_RequestSize); memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); m_RequestSize = 0; m_MaxRequestSize = 0; m_RequestData = nullptr; } void AgentMessageChannel::WriteInt32(int Value) { WriteFixedLengthBytes(reinterpret_cast(&Value), sizeof(int)); } int AgentMessageChannel::ReadInt32(const uint8_t** Pos) { int Value; memcpy(&Value, *Pos, sizeof(int)); *Pos += sizeof(int); return Value; } void AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) { assert(m_RequestSize + Length <= m_MaxRequestSize); memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); m_RequestSize += Length; } const uint8_t* AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) { const uint8_t* Data = *Pos; *Pos += Length; return Data; } size_t AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) { if (Value == 0) { return 1; } return (FloorLog2_64(static_cast(Value)) / 7) + 1; } void AgentMessageChannel::WriteUnsignedVarInt(size_t Value) { const size_t ByteCount = MeasureUnsignedVarInt(Value); assert(m_RequestSize + ByteCount <= m_MaxRequestSize); uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; for (size_t i = 1; i < ByteCount; ++i) { Output[ByteCount - i] = static_cast(Value); Value >>= 8; } Output[0] = static_cast((0xFF << (9 - static_cast(ByteCount))) | static_cast(Value)); m_RequestSize += ByteCount; } size_t AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) { const uint8_t* Data = *Pos; const uint8_t FirstByte = Data[0]; const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast(FirstByte))) + 1 - 24; size_t Value = static_cast(FirstByte & (0xFF >> NumBytes)); for (size_t i = 1; i < NumBytes; ++i) { Value <<= 8; Value |= Data[i]; } *Pos += NumBytes; return Value; } size_t AgentMessageChannel::MeasureString(const char* Text) const { const size_t Length = strlen(Text); return MeasureUnsignedVarInt(Length) + Length; } void AgentMessageChannel::WriteString(const char* Text) { const size_t Length = strlen(Text); WriteUnsignedVarInt(Length); WriteFixedLengthBytes(reinterpret_cast(Text), Length); } void AgentMessageChannel::WriteString(std::string_view Text) { WriteUnsignedVarInt(Text.size()); WriteFixedLengthBytes(reinterpret_cast(Text.data()), Text.size()); } std::string_view AgentMessageChannel::ReadString(const uint8_t** Pos) { const size_t Length = ReadUnsignedVarInt(Pos); const char* Start = reinterpret_cast(ReadFixedLengthBytes(Pos, Length)); return std::string_view(Start, Length); } void AgentMessageChannel::WriteOptionalString(const char* Text) { // Optional strings use length+1 encoding: 0 means null/absent, // N>0 means a string of length N-1 follows. This matches the UE // FAgentMessageChannel serialization convention. if (!Text) { WriteUnsignedVarInt(0); } else { const size_t Length = strlen(Text); WriteUnsignedVarInt(Length + 1); WriteFixedLengthBytes(reinterpret_cast(Text), Length); } } } // namespace zen::horde