diff options
Diffstat (limited to 'src/zenhorde/hordeagentmessage.cpp')
| -rw-r--r-- | src/zenhorde/hordeagentmessage.cpp | 502 |
1 files changed, 284 insertions, 218 deletions
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp index 998134a96..31498972f 100644 --- a/src/zenhorde/hordeagentmessage.cpp +++ b/src/zenhorde/hordeagentmessage.cpp @@ -4,337 +4,403 @@ #include <zencore/intmath.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <cassert> #include <cstring> namespace zen::horde { -AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel)) +// --- AsyncAgentMessageChannel --- + +AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext) +: m_Socket(std::move(Socket)) +, m_ChannelId(ChannelId) +, m_IoContext(IoContext) +, m_TimeoutTimer(std::make_unique<asio::steady_timer>(IoContext)) { } -AgentMessageChannel::~AgentMessageChannel() = default; - -void -AgentMessageChannel::Close() +AsyncAgentMessageChannel::~AsyncAgentMessageChannel() { - CreateMessage(AgentMessageType::None, 0); - FlushMessage(); + if (m_TimeoutTimer) + { + m_TimeoutTimer->cancel(); + } } -void -AgentMessageChannel::Ping() +// --- Message building helpers --- + +std::vector<uint8_t> +AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload) { - CreateMessage(AgentMessageType::Ping, 0); - FlushMessage(); + std::vector<uint8_t> Buf; + Buf.reserve(MessageHeaderLength + ReservePayload); + Buf.push_back(static_cast<uint8_t>(Type)); + Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder + return Buf; } void -AgentMessageChannel::Fork(int ChannelId, int BufferSize) +AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg) { - CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); - WriteInt32(ChannelId); - WriteInt32(BufferSize); - FlushMessage(); + const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength); + memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t)); + m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg)); } void -AgentMessageChannel::Attach() +AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value) { - CreateMessage(AgentMessageType::Attach, 0); - FlushMessage(); + const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value); + Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int)); } -void -AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +int +AsyncAgentMessageChannel::ReadInt32(const uint8_t** Pos) { - CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); - WriteString(Path); - WriteString(Locator); - FlushMessage(); + int Value; + memcpy(&Value, *Pos, sizeof(int)); + *Pos += sizeof(int); + return Value; } void -AgentMessageChannel::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - ExecuteProcessFlags Flags) +AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length) { - size_t RequiredSize = 50 + strlen(Exe); - for (size_t i = 0; i < NumArgs; ++i) - { - RequiredSize += strlen(Args[i]) + 10; - } - if (WorkingDir) - { - RequiredSize += strlen(WorkingDir) + 10; - } - for (size_t i = 0; i < NumEnvVars; ++i) - { - RequiredSize += strlen(EnvVars[i]) + 20; - } - - CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); - WriteString(Exe); - - WriteUnsignedVarInt(NumArgs); - for (size_t i = 0; i < NumArgs; ++i) - { - WriteString(Args[i]); - } - - WriteOptionalString(WorkingDir); - - // ExecuteV2 protocol requires env vars as separate key/value pairs. - // Callers pass "KEY=VALUE" strings; we split on the first '=' here. - WriteUnsignedVarInt(NumEnvVars); - for (size_t i = 0; i < NumEnvVars; ++i) - { - const char* Eq = strchr(EnvVars[i], '='); - assert(Eq != nullptr); - - WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); - if (*(Eq + 1) == '\0') - { - WriteOptionalString(nullptr); - } - else - { - WriteOptionalString(Eq + 1); - } - } + Buf.insert(Buf.end(), Data, Data + Length); +} - WriteInt32(static_cast<int>(Flags)); - FlushMessage(); +const uint8_t* +AsyncAgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +{ + const uint8_t* Data = *Pos; + *Pos += Length; + return Data; } -void -AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +size_t +AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value) { - // Blob responses are chunked to fit within the compute buffer's chunk size. - // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). - const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; - for (size_t ChunkOffset = 0; ChunkOffset < Length;) + if (Value == 0) { - const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); - - CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); - WriteInt32(static_cast<int>(ChunkOffset)); - WriteInt32(static_cast<int>(Length)); - WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); - FlushMessage(); - - ChunkOffset += ChunkLength; + return 1; } + return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; } -AgentMessageType -AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +void +AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value) { - // Deferred advance: the previous response's buffer is only released when the next - // ReadResponse is called. This allows callers to read response data between calls - // without copying, since the pointer comes directly from the ring buffer. - if (m_ResponseData) - { - m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); - m_ResponseData = nullptr; - m_ResponseLength = 0; - } + const size_t ByteCount = MeasureUnsignedVarInt(Value); + const size_t StartPos = Buf.size(); + Buf.resize(StartPos + ByteCount); - const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); - if (!Header) + uint8_t* Output = Buf.data() + StartPos; + for (size_t i = 1; i < ByteCount; ++i) { - return AgentMessageType::None; + Output[ByteCount - i] = static_cast<uint8_t>(Value); + Value >>= 8; } + Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); +} - uint32_t Length; - memcpy(&Length, Header + 1, sizeof(uint32_t)); +size_t +AsyncAgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +{ + const uint8_t* Data = *Pos; + const uint8_t FirstByte = Data[0]; + const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; - Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); - if (!Header) + size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); + for (size_t i = 1; i < NumBytes; ++i) { - return AgentMessageType::None; + Value <<= 8; + Value |= Data[i]; } - m_ResponseType = static_cast<AgentMessageType>(Header[0]); - m_ResponseData = Header + MessageHeaderLength; - m_ResponseLength = Length; - - return m_ResponseType; + *Pos += NumBytes; + return Value; } void -AgentMessageChannel::ReadException(ExceptionInfo& Ex) +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text) { - assert(m_ResponseType == AgentMessageType::Exception); - const uint8_t* Pos = m_ResponseData; - Ex.Message = ReadString(&Pos); - Ex.Description = ReadString(&Pos); + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); } -int -AgentMessageChannel::ReadExecuteResult() +void +AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text) { - assert(m_ResponseType == AgentMessageType::ExecuteResult); - const uint8_t* Pos = m_ResponseData; - return ReadInt32(&Pos); + WriteUnsignedVarInt(Buf, Text.size()); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); } -void -AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +std::string_view +AsyncAgentMessageChannel::ReadString(const uint8_t** Pos) { - assert(m_ResponseType == AgentMessageType::ReadBlob); - const uint8_t* Pos = m_ResponseData; - Req.Locator = ReadString(&Pos); - Req.Offset = ReadUnsignedVarInt(&Pos); - Req.Length = ReadUnsignedVarInt(&Pos); + const size_t Length = ReadUnsignedVarInt(Pos); + const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); + return std::string_view(Start, Length); } void -AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text) { - m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); - m_RequestData[0] = static_cast<uint8_t>(Type); - m_MaxRequestSize = MaxLength; - m_RequestSize = 0; + if (!Text) + { + WriteUnsignedVarInt(Buf, 0); + } + else + { + const size_t Length = strlen(Text); + WriteUnsignedVarInt(Buf, Length + 1); + WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length); + } } +// --- Send methods --- + void -AgentMessageChannel::FlushMessage() +AsyncAgentMessageChannel::Close() { - const uint32_t Size = static_cast<uint32_t>(m_RequestSize); - memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); - m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize); - m_RequestSize = 0; - m_MaxRequestSize = 0; - m_RequestData = nullptr; + auto Msg = BeginMessage(AgentMessageType::None, 0); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteInt32(int Value) +AsyncAgentMessageChannel::Ping() { - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int)); + auto Msg = BeginMessage(AgentMessageType::Ping, 0); + FinalizeAndSend(std::move(Msg)); } -int -AgentMessageChannel::ReadInt32(const uint8_t** Pos) +void +AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize) { - int Value; - memcpy(&Value, *Pos, sizeof(int)); - *Pos += sizeof(int); - return Value; + auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(Msg, ChannelId); + WriteInt32(Msg, BufferSize); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length) +AsyncAgentMessageChannel::Attach() { - assert(m_RequestSize + Length <= m_MaxRequestSize); - memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length); - m_RequestSize += Length; + auto Msg = BeginMessage(AgentMessageType::Attach, 0); + FinalizeAndSend(std::move(Msg)); } -const uint8_t* -AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length) +void +AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator) { - const uint8_t* Data = *Pos; - *Pos += Length; - return Data; + auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Msg, Path); + WriteString(Msg, Locator); + FinalizeAndSend(std::move(Msg)); } -size_t -AgentMessageChannel::MeasureUnsignedVarInt(size_t Value) +void +AsyncAgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) { - if (Value == 0) + size_t ReserveSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) { - return 1; + ReserveSize += strlen(Args[i]) + 10; } - return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1; + if (WorkingDir) + { + ReserveSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + ReserveSize += strlen(EnvVars[i]) + 20; + } + + auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize); + WriteString(Msg, Exe); + + WriteUnsignedVarInt(Msg, NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Msg, Args[i]); + } + + WriteOptionalString(Msg, WorkingDir); + + WriteUnsignedVarInt(Msg, NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(Msg, nullptr); + } + else + { + WriteOptionalString(Msg, Eq + 1); + } + } + + WriteInt32(Msg, static_cast<int>(Flags)); + FinalizeAndSend(std::move(Msg)); } void -AgentMessageChannel::WriteUnsignedVarInt(size_t Value) +AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length) { - const size_t ByteCount = MeasureUnsignedVarInt(Value); - assert(m_RequestSize + ByteCount <= m_MaxRequestSize); + static constexpr size_t MaxBlobChunkSize = 512 * 1024; - uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize; - for (size_t i = 1; i < ByteCount; ++i) + for (size_t ChunkOffset = 0; ChunkOffset < Length;) { - Output[ByteCount - i] = static_cast<uint8_t>(Value); - Value >>= 8; - } - Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value)); + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize); - m_RequestSize += ByteCount; + auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(Msg, static_cast<int>(ChunkOffset)); + WriteInt32(Msg, static_cast<int>(Length)); + WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength); + FinalizeAndSend(std::move(Msg)); + + ChunkOffset += ChunkLength; + } } -size_t -AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos) +// --- Async response reading --- + +void +AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler) { - const uint8_t* Data = *Pos; - const uint8_t FirstByte = Data[0]; - const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24; + // If frames are already queued, dispatch immediately + if (!m_IncomingFrames.empty()) + { + std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front()); + m_IncomingFrames.pop_front(); - size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes)); - for (size_t i = 1; i < NumBytes; ++i) + if (Frame.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]); + const uint8_t* Data = Frame.data() + MessageHeaderLength; + size_t Size = Frame.size() - MessageHeaderLength; + asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable { + // The Frame is captured to keep Data pointer valid + Handler(Type, Data, Size); + }); + } + else + { + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + } + return; + } + + if (m_Detached) { - Value <<= 8; - Value |= Data[i]; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); }); + return; } - *Pos += NumBytes; - return Value; + // No frames queued — store pending handler and arm timeout + m_PendingHandler = std::move(Handler); + + if (TimeoutMs >= 0) + { + m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs)); + m_TimeoutTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; // Cancelled — frame arrived before timeout + } + + if (m_PendingHandler) + { + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } + }); + } } -size_t -AgentMessageChannel::MeasureString(const char* Text) const +void +AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data) { - const size_t Length = strlen(Text); - return MeasureUnsignedVarInt(Length) + Length; + if (m_PendingHandler) + { + // Cancel the timeout timer + m_TimeoutTimer->cancel(); + + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + + if (Data.size() >= MessageHeaderLength) + { + AgentMessageType Type = static_cast<AgentMessageType>(Data[0]); + const uint8_t* Payload = Data.data() + MessageHeaderLength; + size_t PayloadSize = Data.size() - MessageHeaderLength; + Handler(Type, Payload, PayloadSize); + } + else + { + Handler(AgentMessageType::None, nullptr, 0); + } + } + else + { + m_IncomingFrames.push_back(std::move(Data)); + } } void -AgentMessageChannel::WriteString(const char* Text) +AsyncAgentMessageChannel::OnDetach() { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); + m_Detached = true; + + if (m_PendingHandler) + { + m_TimeoutTimer->cancel(); + AsyncResponseHandler Handler = std::move(m_PendingHandler); + m_PendingHandler = nullptr; + Handler(AgentMessageType::None, nullptr, 0); + } } +// --- Response parsing helpers --- + void -AgentMessageChannel::WriteString(std::string_view Text) +AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t /*Size*/, ExceptionInfo& Ex) { - WriteUnsignedVarInt(Text.size()); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + const uint8_t* Pos = Data; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); } -std::string_view -AgentMessageChannel::ReadString(const uint8_t** Pos) +int +AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t /*Size*/) { - const size_t Length = ReadUnsignedVarInt(Pos); - const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length)); - return std::string_view(Start, Length); + const uint8_t* Pos = Data; + return ReadInt32(&Pos); } void -AgentMessageChannel::WriteOptionalString(const char* Text) +AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t /*Size*/, BlobRequest& Req) { - // Optional strings use length+1 encoding: 0 means null/absent, - // N>0 means a string of length N-1 follows. This matches the UE - // FAgentMessageChannel serialization convention. - if (!Text) - { - WriteUnsignedVarInt(0); - } - else - { - const size_t Length = strlen(Text); - WriteUnsignedVarInt(Length + 1); - WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length); - } + const uint8_t* Pos = Data; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); } } // namespace zen::horde |