aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagentmessage.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordeagentmessage.cpp')
-rw-r--r--src/zenhorde/hordeagentmessage.cpp502
1 files changed, 284 insertions, 218 deletions
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
index 998134a96..31498972f 100644
--- a/src/zenhorde/hordeagentmessage.cpp
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -4,337 +4,403 @@
#include <zencore/intmath.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <cassert>
#include <cstring>
namespace zen::horde {
-AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
+// --- AsyncAgentMessageChannel ---
+
+AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext)
+: m_Socket(std::move(Socket))
+, m_ChannelId(ChannelId)
+, m_IoContext(IoContext)
+, m_TimeoutTimer(std::make_unique<asio::steady_timer>(IoContext))
{
}
-AgentMessageChannel::~AgentMessageChannel() = default;
-
-void
-AgentMessageChannel::Close()
+AsyncAgentMessageChannel::~AsyncAgentMessageChannel()
{
- CreateMessage(AgentMessageType::None, 0);
- FlushMessage();
+ if (m_TimeoutTimer)
+ {
+ m_TimeoutTimer->cancel();
+ }
}
-void
-AgentMessageChannel::Ping()
+// --- Message building helpers ---
+
+std::vector<uint8_t>
+AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload)
{
- CreateMessage(AgentMessageType::Ping, 0);
- FlushMessage();
+ std::vector<uint8_t> Buf;
+ Buf.reserve(MessageHeaderLength + ReservePayload);
+ Buf.push_back(static_cast<uint8_t>(Type));
+ Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder
+ return Buf;
}
void
-AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg)
{
- CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
- WriteInt32(ChannelId);
- WriteInt32(BufferSize);
- FlushMessage();
+ const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength);
+ memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t));
+ m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg));
}
void
-AgentMessageChannel::Attach()
+AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value)
{
- CreateMessage(AgentMessageType::Attach, 0);
- FlushMessage();
+ const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value);
+ Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int));
}
-void
-AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+int
+AsyncAgentMessageChannel::ReadInt32(const uint8_t** Pos)
{
- CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
- WriteString(Path);
- WriteString(Locator);
- FlushMessage();
+ int Value;
+ memcpy(&Value, *Pos, sizeof(int));
+ *Pos += sizeof(int);
+ return Value;
}
void
-AgentMessageChannel::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- ExecuteProcessFlags Flags)
+AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length)
{
- size_t RequiredSize = 50 + strlen(Exe);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- RequiredSize += strlen(Args[i]) + 10;
- }
- if (WorkingDir)
- {
- RequiredSize += strlen(WorkingDir) + 10;
- }
- for (size_t i = 0; i < NumEnvVars; ++i)
- {
- RequiredSize += strlen(EnvVars[i]) + 20;
- }
-
- CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
- WriteString(Exe);
-
- WriteUnsignedVarInt(NumArgs);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- WriteString(Args[i]);
- }
-
- WriteOptionalString(WorkingDir);
-
- // ExecuteV2 protocol requires env vars as separate key/value pairs.
- // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
- WriteUnsignedVarInt(NumEnvVars);
- for (size_t i = 0; i < NumEnvVars; ++i)
- {
- const char* Eq = strchr(EnvVars[i], '=');
- assert(Eq != nullptr);
-
- WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
- if (*(Eq + 1) == '\0')
- {
- WriteOptionalString(nullptr);
- }
- else
- {
- WriteOptionalString(Eq + 1);
- }
- }
+ Buf.insert(Buf.end(), Data, Data + Length);
+}
- WriteInt32(static_cast<int>(Flags));
- FlushMessage();
+const uint8_t*
+AsyncAgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+{
+ const uint8_t* Data = *Pos;
+ *Pos += Length;
+ return Data;
}
-void
-AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+size_t
+AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
{
- // Blob responses are chunked to fit within the compute buffer's chunk size.
- // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
- const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
- for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ if (Value == 0)
{
- const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
-
- CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
- WriteInt32(static_cast<int>(ChunkOffset));
- WriteInt32(static_cast<int>(Length));
- WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
- FlushMessage();
-
- ChunkOffset += ChunkLength;
+ return 1;
}
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
-AgentMessageType
-AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+void
+AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value)
{
- // Deferred advance: the previous response's buffer is only released when the next
- // ReadResponse is called. This allows callers to read response data between calls
- // without copying, since the pointer comes directly from the ring buffer.
- if (m_ResponseData)
- {
- m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
- m_ResponseData = nullptr;
- m_ResponseLength = 0;
- }
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ const size_t StartPos = Buf.size();
+ Buf.resize(StartPos + ByteCount);
- const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
- if (!Header)
+ uint8_t* Output = Buf.data() + StartPos;
+ for (size_t i = 1; i < ByteCount; ++i)
{
- return AgentMessageType::None;
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
}
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+}
- uint32_t Length;
- memcpy(&Length, Header + 1, sizeof(uint32_t));
+size_t
+AsyncAgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+{
+ const uint8_t* Data = *Pos;
+ const uint8_t FirstByte = Data[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
- Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
- if (!Header)
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
{
- return AgentMessageType::None;
+ Value <<= 8;
+ Value |= Data[i];
}
- m_ResponseType = static_cast<AgentMessageType>(Header[0]);
- m_ResponseData = Header + MessageHeaderLength;
- m_ResponseLength = Length;
-
- return m_ResponseType;
+ *Pos += NumBytes;
+ return Value;
}
void
-AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text)
{
- assert(m_ResponseType == AgentMessageType::Exception);
- const uint8_t* Pos = m_ResponseData;
- Ex.Message = ReadString(&Pos);
- Ex.Description = ReadString(&Pos);
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
}
-int
-AgentMessageChannel::ReadExecuteResult()
+void
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text)
{
- assert(m_ResponseType == AgentMessageType::ExecuteResult);
- const uint8_t* Pos = m_ResponseData;
- return ReadInt32(&Pos);
+ WriteUnsignedVarInt(Buf, Text.size());
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
}
-void
-AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+std::string_view
+AsyncAgentMessageChannel::ReadString(const uint8_t** Pos)
{
- assert(m_ResponseType == AgentMessageType::ReadBlob);
- const uint8_t* Pos = m_ResponseData;
- Req.Locator = ReadString(&Pos);
- Req.Offset = ReadUnsignedVarInt(&Pos);
- Req.Length = ReadUnsignedVarInt(&Pos);
+ const size_t Length = ReadUnsignedVarInt(Pos);
+ const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
+ return std::string_view(Start, Length);
}
void
-AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text)
{
- m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
- m_RequestData[0] = static_cast<uint8_t>(Type);
- m_MaxRequestSize = MaxLength;
- m_RequestSize = 0;
+ if (!Text)
+ {
+ WriteUnsignedVarInt(Buf, 0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length + 1);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
+ }
}
+// --- Send methods ---
+
void
-AgentMessageChannel::FlushMessage()
+AsyncAgentMessageChannel::Close()
{
- const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
- memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
- m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
- m_RequestSize = 0;
- m_MaxRequestSize = 0;
- m_RequestData = nullptr;
+ auto Msg = BeginMessage(AgentMessageType::None, 0);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteInt32(int Value)
+AsyncAgentMessageChannel::Ping()
{
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+ auto Msg = BeginMessage(AgentMessageType::Ping, 0);
+ FinalizeAndSend(std::move(Msg));
}
-int
-AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize)
{
- int Value;
- memcpy(&Value, *Pos, sizeof(int));
- *Pos += sizeof(int);
- return Value;
+ auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(Msg, ChannelId);
+ WriteInt32(Msg, BufferSize);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::Attach()
{
- assert(m_RequestSize + Length <= m_MaxRequestSize);
- memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
- m_RequestSize += Length;
+ auto Msg = BeginMessage(AgentMessageType::Attach, 0);
+ FinalizeAndSend(std::move(Msg));
}
-const uint8_t*
-AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+void
+AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
{
- const uint8_t* Data = *Pos;
- *Pos += Length;
- return Data;
+ auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Msg, Path);
+ WriteString(Msg, Locator);
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+void
+AsyncAgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
{
- if (Value == 0)
+ size_t ReserveSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- return 1;
+ ReserveSize += strlen(Args[i]) + 10;
}
- return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
+ if (WorkingDir)
+ {
+ ReserveSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ ReserveSize += strlen(EnvVars[i]) + 20;
+ }
+
+ auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize);
+ WriteString(Msg, Exe);
+
+ WriteUnsignedVarInt(Msg, NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
+ {
+ WriteString(Msg, Args[i]);
+ }
+
+ WriteOptionalString(Msg, WorkingDir);
+
+ WriteUnsignedVarInt(Msg, NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ assert(Eq != nullptr);
+
+ WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(Msg, nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Msg, Eq + 1);
+ }
+ }
+
+ WriteInt32(Msg, static_cast<int>(Flags));
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
+AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
{
- const size_t ByteCount = MeasureUnsignedVarInt(Value);
- assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+ static constexpr size_t MaxBlobChunkSize = 512 * 1024;
- uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
- for (size_t i = 1; i < ByteCount; ++i)
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
{
- Output[ByteCount - i] = static_cast<uint8_t>(Value);
- Value >>= 8;
- }
- Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize);
- m_RequestSize += ByteCount;
+ auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(Msg, static_cast<int>(ChunkOffset));
+ WriteInt32(Msg, static_cast<int>(Length));
+ WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength);
+ FinalizeAndSend(std::move(Msg));
+
+ ChunkOffset += ChunkLength;
+ }
}
-size_t
-AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+// --- Async response reading ---
+
+void
+AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler)
{
- const uint8_t* Data = *Pos;
- const uint8_t FirstByte = Data[0];
- const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+ // If frames are already queued, dispatch immediately
+ if (!m_IncomingFrames.empty())
+ {
+ std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front());
+ m_IncomingFrames.pop_front();
- size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
- for (size_t i = 1; i < NumBytes; ++i)
+ if (Frame.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]);
+ const uint8_t* Data = Frame.data() + MessageHeaderLength;
+ size_t Size = Frame.size() - MessageHeaderLength;
+ asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable {
+ // The Frame is captured to keep Data pointer valid
+ Handler(Type, Data, Size);
+ });
+ }
+ else
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ }
+ return;
+ }
+
+ if (m_Detached)
{
- Value <<= 8;
- Value |= Data[i];
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ return;
}
- *Pos += NumBytes;
- return Value;
+ // No frames queued — store pending handler and arm timeout
+ m_PendingHandler = std::move(Handler);
+
+ if (TimeoutMs >= 0)
+ {
+ m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_TimeoutTimer->async_wait([this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return; // Cancelled — frame arrived before timeout
+ }
+
+ if (m_PendingHandler)
+ {
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ });
+ }
}
-size_t
-AgentMessageChannel::MeasureString(const char* Text) const
+void
+AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data)
{
- const size_t Length = strlen(Text);
- return MeasureUnsignedVarInt(Length) + Length;
+ if (m_PendingHandler)
+ {
+ // Cancel the timeout timer
+ m_TimeoutTimer->cancel();
+
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+
+ if (Data.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Data[0]);
+ const uint8_t* Payload = Data.data() + MessageHeaderLength;
+ size_t PayloadSize = Data.size() - MessageHeaderLength;
+ Handler(Type, Payload, PayloadSize);
+ }
+ else
+ {
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }
+ else
+ {
+ m_IncomingFrames.push_back(std::move(Data));
+ }
}
void
-AgentMessageChannel::WriteString(const char* Text)
+AsyncAgentMessageChannel::OnDetach()
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ m_Detached = true;
+
+ if (m_PendingHandler)
+ {
+ m_TimeoutTimer->cancel();
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
}
+// --- Response parsing helpers ---
+
void
-AgentMessageChannel::WriteString(std::string_view Text)
+AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t /*Size*/, ExceptionInfo& Ex)
{
- WriteUnsignedVarInt(Text.size());
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ const uint8_t* Pos = Data;
+ Ex.Message = ReadString(&Pos);
+ Ex.Description = ReadString(&Pos);
}
-std::string_view
-AgentMessageChannel::ReadString(const uint8_t** Pos)
+int
+AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t /*Size*/)
{
- const size_t Length = ReadUnsignedVarInt(Pos);
- const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
- return std::string_view(Start, Length);
+ const uint8_t* Pos = Data;
+ return ReadInt32(&Pos);
}
void
-AgentMessageChannel::WriteOptionalString(const char* Text)
+AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t /*Size*/, BlobRequest& Req)
{
- // Optional strings use length+1 encoding: 0 means null/absent,
- // N>0 means a string of length N-1 follows. This matches the UE
- // FAgentMessageChannel serialization convention.
- if (!Text)
- {
- WriteUnsignedVarInt(0);
- }
- else
- {
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length + 1);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
- }
+ const uint8_t* Pos = Data;
+ Req.Locator = ReadString(&Pos);
+ Req.Offset = ReadUnsignedVarInt(&Pos);
+ Req.Length = ReadUnsignedVarInt(&Pos);
}
} // namespace zen::horde