aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagentmessage.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordeagentmessage.cpp')
-rw-r--r--src/zenhorde/hordeagentmessage.cpp581
1 files changed, 370 insertions, 211 deletions
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
index 998134a96..bef1bdda8 100644
--- a/src/zenhorde/hordeagentmessage.cpp
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -4,337 +4,496 @@
#include <zencore/intmath.h>
-#include <cassert>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <zencore/except_fmt.h>
+#include <zencore/logging.h>
+
#include <cstring>
+#include <limits>
namespace zen::horde {
-AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
-{
-}
-
-AgentMessageChannel::~AgentMessageChannel() = default;
+// --- AsyncAgentMessageChannel ---
-void
-AgentMessageChannel::Close()
+AsyncAgentMessageChannel::AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext)
+: m_Socket(std::move(Socket))
+, m_ChannelId(ChannelId)
+, m_IoContext(IoContext)
+, m_TimeoutTimer(std::make_unique<asio::steady_timer>(m_Socket->GetStrand()))
{
- CreateMessage(AgentMessageType::None, 0);
- FlushMessage();
}
-void
-AgentMessageChannel::Ping()
+AsyncAgentMessageChannel::~AsyncAgentMessageChannel()
{
- CreateMessage(AgentMessageType::Ping, 0);
- FlushMessage();
+ if (m_TimeoutTimer)
+ {
+ m_TimeoutTimer->cancel();
+ }
}
-void
-AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+// --- Message building helpers ---
+
+std::vector<uint8_t>
+AsyncAgentMessageChannel::BeginMessage(AgentMessageType Type, size_t ReservePayload)
{
- CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
- WriteInt32(ChannelId);
- WriteInt32(BufferSize);
- FlushMessage();
+ std::vector<uint8_t> Buf;
+ Buf.reserve(MessageHeaderLength + ReservePayload);
+ Buf.push_back(static_cast<uint8_t>(Type));
+ Buf.resize(MessageHeaderLength); // 1 byte type + 4 bytes length placeholder
+ return Buf;
}
void
-AgentMessageChannel::Attach()
+AsyncAgentMessageChannel::FinalizeAndSend(std::vector<uint8_t> Msg)
{
- CreateMessage(AgentMessageType::Attach, 0);
- FlushMessage();
+ const uint32_t PayloadSize = static_cast<uint32_t>(Msg.size() - MessageHeaderLength);
+ memcpy(&Msg[1], &PayloadSize, sizeof(uint32_t));
+ m_Socket->AsyncSendFrame(m_ChannelId, std::move(Msg));
}
void
-AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+AsyncAgentMessageChannel::WriteInt32(std::vector<uint8_t>& Buf, int Value)
{
- CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
- WriteString(Path);
- WriteString(Locator);
- FlushMessage();
+ const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(&Value);
+ Buf.insert(Buf.end(), Ptr, Ptr + sizeof(int));
}
-void
-AgentMessageChannel::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- ExecuteProcessFlags Flags)
+int
+AsyncAgentMessageChannel::ReadInt32(ReadCursor& C)
{
- size_t RequiredSize = 50 + strlen(Exe);
- for (size_t i = 0; i < NumArgs; ++i)
- {
- RequiredSize += strlen(Args[i]) + 10;
- }
- if (WorkingDir)
- {
- RequiredSize += strlen(WorkingDir) + 10;
- }
- for (size_t i = 0; i < NumEnvVars; ++i)
+ if (!C.CheckAvailable(sizeof(int32_t)))
{
- RequiredSize += strlen(EnvVars[i]) + 20;
+ return 0;
}
+ int32_t Value;
+ memcpy(&Value, C.Pos, sizeof(int32_t));
+ C.Pos += sizeof(int32_t);
+ return Value;
+}
- CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
- WriteString(Exe);
+void
+AsyncAgentMessageChannel::WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length)
+{
+ Buf.insert(Buf.end(), Data, Data + Length);
+}
- WriteUnsignedVarInt(NumArgs);
- for (size_t i = 0; i < NumArgs; ++i)
+const uint8_t*
+AsyncAgentMessageChannel::ReadFixedLengthBytes(ReadCursor& C, size_t Length)
+{
+ if (!C.CheckAvailable(Length))
{
- WriteString(Args[i]);
+ return nullptr;
}
+ const uint8_t* Data = C.Pos;
+ C.Pos += Length;
+ return Data;
+}
- WriteOptionalString(WorkingDir);
-
- // ExecuteV2 protocol requires env vars as separate key/value pairs.
- // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
- WriteUnsignedVarInt(NumEnvVars);
- for (size_t i = 0; i < NumEnvVars; ++i)
+size_t
+AsyncAgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+{
+ if (Value == 0)
{
- const char* Eq = strchr(EnvVars[i], '=');
- assert(Eq != nullptr);
-
- WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
- if (*(Eq + 1) == '\0')
- {
- WriteOptionalString(nullptr);
- }
- else
- {
- WriteOptionalString(Eq + 1);
- }
+ return 1;
}
-
- WriteInt32(static_cast<int>(Flags));
- FlushMessage();
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
}
void
-AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value)
{
- // Blob responses are chunked to fit within the compute buffer's chunk size.
- // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
- const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
- for (size_t ChunkOffset = 0; ChunkOffset < Length;)
- {
- const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
-
- CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
- WriteInt32(static_cast<int>(ChunkOffset));
- WriteInt32(static_cast<int>(Length));
- WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
- FlushMessage();
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ const size_t StartPos = Buf.size();
+ Buf.resize(StartPos + ByteCount);
- ChunkOffset += ChunkLength;
+ uint8_t* Output = Buf.data() + StartPos;
+ for (size_t i = 1; i < ByteCount; ++i)
+ {
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
}
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
}
-AgentMessageType
-AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+size_t
+AsyncAgentMessageChannel::ReadUnsignedVarInt(ReadCursor& C)
{
- // Deferred advance: the previous response's buffer is only released when the next
- // ReadResponse is called. This allows callers to read response data between calls
- // without copying, since the pointer comes directly from the ring buffer.
- if (m_ResponseData)
+ // Need at least the leading byte to determine the encoded length.
+ if (!C.CheckAvailable(1))
{
- m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
- m_ResponseData = nullptr;
- m_ResponseLength = 0;
+ return 0;
}
- const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
- if (!Header)
+ const uint8_t FirstByte = C.Pos[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+
+ // The encoded length implied by the leading 0xFF-run may be 1..9 bytes; ensure the remaining bytes are in-bounds.
+ if (!C.CheckAvailable(NumBytes))
{
- return AgentMessageType::None;
+ return 0;
}
- uint32_t Length;
- memcpy(&Length, Header + 1, sizeof(uint32_t));
-
- Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
- if (!Header)
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
{
- return AgentMessageType::None;
+ Value <<= 8;
+ Value |= C.Pos[i];
}
- m_ResponseType = static_cast<AgentMessageType>(Header[0]);
- m_ResponseData = Header + MessageHeaderLength;
- m_ResponseLength = Length;
-
- return m_ResponseType;
+ C.Pos += NumBytes;
+ return Value;
}
void
-AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, const char* Text)
{
- assert(m_ResponseType == AgentMessageType::Exception);
- const uint8_t* Pos = m_ResponseData;
- Ex.Message = ReadString(&Pos);
- Ex.Description = ReadString(&Pos);
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
}
-int
-AgentMessageChannel::ReadExecuteResult()
+void
+AsyncAgentMessageChannel::WriteString(std::vector<uint8_t>& Buf, std::string_view Text)
{
- assert(m_ResponseType == AgentMessageType::ExecuteResult);
- const uint8_t* Pos = m_ResponseData;
- return ReadInt32(&Pos);
+ WriteUnsignedVarInt(Buf, Text.size());
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
}
-void
-AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+std::string_view
+AsyncAgentMessageChannel::ReadString(ReadCursor& C)
{
- assert(m_ResponseType == AgentMessageType::ReadBlob);
- const uint8_t* Pos = m_ResponseData;
- Req.Locator = ReadString(&Pos);
- Req.Offset = ReadUnsignedVarInt(&Pos);
- Req.Length = ReadUnsignedVarInt(&Pos);
+ const size_t Length = ReadUnsignedVarInt(C);
+ const uint8_t* Start = ReadFixedLengthBytes(C, Length);
+ if (C.ParseError || !Start)
+ {
+ return {};
+ }
+ return std::string_view(reinterpret_cast<const char*>(Start), Length);
}
void
-AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+AsyncAgentMessageChannel::WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text)
{
- m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
- m_RequestData[0] = static_cast<uint8_t>(Type);
- m_MaxRequestSize = MaxLength;
- m_RequestSize = 0;
+ if (!Text)
+ {
+ WriteUnsignedVarInt(Buf, 0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Buf, Length + 1);
+ WriteFixedLengthBytes(Buf, reinterpret_cast<const uint8_t*>(Text), Length);
+ }
}
+// --- Send methods ---
+
void
-AgentMessageChannel::FlushMessage()
+AsyncAgentMessageChannel::Close()
{
- const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
- memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
- m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
- m_RequestSize = 0;
- m_MaxRequestSize = 0;
- m_RequestData = nullptr;
+ auto Msg = BeginMessage(AgentMessageType::None, 0);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteInt32(int Value)
+AsyncAgentMessageChannel::Ping()
{
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+ auto Msg = BeginMessage(AgentMessageType::Ping, 0);
+ FinalizeAndSend(std::move(Msg));
}
-int
-AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Fork(int ChannelId, int BufferSize)
{
- int Value;
- memcpy(&Value, *Pos, sizeof(int));
- *Pos += sizeof(int);
- return Value;
+ auto Msg = BeginMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(Msg, ChannelId);
+ WriteInt32(Msg, BufferSize);
+ FinalizeAndSend(std::move(Msg));
}
void
-AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+AsyncAgentMessageChannel::Attach()
{
- assert(m_RequestSize + Length <= m_MaxRequestSize);
- memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
- m_RequestSize += Length;
+ auto Msg = BeginMessage(AgentMessageType::Attach, 0);
+ FinalizeAndSend(std::move(Msg));
}
-const uint8_t*
-AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+void
+AsyncAgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
{
- const uint8_t* Data = *Pos;
- *Pos += Length;
- return Data;
+ auto Msg = BeginMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Msg, Path);
+ WriteString(Msg, Locator);
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+void
+AsyncAgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
{
- if (Value == 0)
+ size_t ReserveSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- return 1;
+ ReserveSize += strlen(Args[i]) + 10;
+ }
+ if (WorkingDir)
+ {
+ ReserveSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ ReserveSize += strlen(EnvVars[i]) + 20;
}
- return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
-}
-void
-AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
-{
- const size_t ByteCount = MeasureUnsignedVarInt(Value);
- assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+ auto Msg = BeginMessage(AgentMessageType::ExecuteV2, ReserveSize);
+ WriteString(Msg, Exe);
- uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
- for (size_t i = 1; i < ByteCount; ++i)
+ WriteUnsignedVarInt(Msg, NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
{
- Output[ByteCount - i] = static_cast<uint8_t>(Value);
- Value >>= 8;
+ WriteString(Msg, Args[i]);
}
- Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
- m_RequestSize += ByteCount;
+ WriteOptionalString(Msg, WorkingDir);
+
+ WriteUnsignedVarInt(Msg, NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ if (Eq == nullptr)
+ {
+ // assert() would be compiled out in release and leave *(Eq+1) as UB -
+ // refuse to build the message for a malformed KEY=VALUE string instead.
+ throw zen::runtime_error("horde agent env var at index {} missing '=' separator", i);
+ }
+
+ WriteString(Msg, std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(Msg, nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Msg, Eq + 1);
+ }
+ }
+
+ WriteInt32(Msg, static_cast<int>(Flags));
+ FinalizeAndSend(std::move(Msg));
}
-size_t
-AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+void
+AsyncAgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
{
- const uint8_t* Data = *Pos;
- const uint8_t FirstByte = Data[0];
- const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+ static constexpr size_t MaxBlobChunkSize = 512 * 1024;
- size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
- for (size_t i = 1; i < NumBytes; ++i)
+ // The Horde ReadBlobResponse wire format encodes both the chunk Offset and the total
+ // Length as int32. Lengths of 2 GiB or more would wrap to negative and confuse the
+ // remote parser. Refuse the send rather than produce a protocol violation.
+ if (Length > static_cast<size_t>(std::numeric_limits<int32_t>::max()))
{
- Value <<= 8;
- Value |= Data[i];
+ throw zen::runtime_error("horde ReadBlobResponse length {} exceeds int32 wire limit", Length);
}
- *Pos += NumBytes;
- return Value;
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ {
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxBlobChunkSize);
+
+ auto Msg = BeginMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(Msg, static_cast<int32_t>(ChunkOffset));
+ WriteInt32(Msg, static_cast<int32_t>(Length));
+ WriteFixedLengthBytes(Msg, Data + ChunkOffset, ChunkLength);
+ FinalizeAndSend(std::move(Msg));
+
+ ChunkOffset += ChunkLength;
+ }
}
-size_t
-AgentMessageChannel::MeasureString(const char* Text) const
+// --- Async response reading ---
+
+void
+AsyncAgentMessageChannel::AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler)
{
- const size_t Length = strlen(Text);
- return MeasureUnsignedVarInt(Length) + Length;
+ // Serialize all access to m_IncomingFrames / m_PendingHandler / m_TimeoutTimer onto
+ // the socket's strand; OnFrame/OnDetach also run on that strand. Without this, the
+ // timer wait completion would run on a bare io_context thread (3 concurrent run()
+ // loops in the provisioner) and race with OnFrame on m_PendingHandler.
+ asio::dispatch(m_Socket->GetStrand(), [this, TimeoutMs, Handler = std::move(Handler)]() mutable {
+ if (!m_IncomingFrames.empty())
+ {
+ std::vector<uint8_t> Frame = std::move(m_IncomingFrames.front());
+ m_IncomingFrames.pop_front();
+
+ if (Frame.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Frame[0]);
+ const uint8_t* Data = Frame.data() + MessageHeaderLength;
+ size_t Size = Frame.size() - MessageHeaderLength;
+ asio::post(m_IoContext, [Handler = std::move(Handler), Type, Frame = std::move(Frame), Data, Size]() mutable {
+ // The Frame is captured to keep Data pointer valid
+ Handler(Type, Data, Size);
+ });
+ }
+ else
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ }
+ return;
+ }
+
+ if (m_Detached)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(AgentMessageType::None, nullptr, 0); });
+ return;
+ }
+
+ // No frames queued - store pending handler and arm timeout
+ m_PendingHandler = std::move(Handler);
+
+ if (TimeoutMs >= 0)
+ {
+ m_TimeoutTimer->expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_TimeoutTimer->async_wait(asio::bind_executor(m_Socket->GetStrand(), [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return; // Cancelled - frame arrived before timeout
+ }
+
+ // Already on the strand: safe to mutate m_PendingHandler.
+ if (m_PendingHandler)
+ {
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }));
+ }
+ });
}
void
-AgentMessageChannel::WriteString(const char* Text)
+AsyncAgentMessageChannel::OnFrame(std::vector<uint8_t> Data)
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ if (m_PendingHandler)
+ {
+ // Cancel the timeout timer
+ m_TimeoutTimer->cancel();
+
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+
+ if (Data.size() >= MessageHeaderLength)
+ {
+ AgentMessageType Type = static_cast<AgentMessageType>(Data[0]);
+ const uint8_t* Payload = Data.data() + MessageHeaderLength;
+ size_t PayloadSize = Data.size() - MessageHeaderLength;
+ Handler(Type, Payload, PayloadSize);
+ }
+ else
+ {
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
+ }
+ else
+ {
+ m_IncomingFrames.push_back(std::move(Data));
+ }
}
void
-AgentMessageChannel::WriteString(std::string_view Text)
+AsyncAgentMessageChannel::OnDetach()
{
- WriteUnsignedVarInt(Text.size());
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ m_Detached = true;
+
+ if (m_PendingHandler)
+ {
+ m_TimeoutTimer->cancel();
+ AsyncResponseHandler Handler = std::move(m_PendingHandler);
+ m_PendingHandler = nullptr;
+ Handler(AgentMessageType::None, nullptr, 0);
+ }
}
-std::string_view
-AgentMessageChannel::ReadString(const uint8_t** Pos)
+// --- Response parsing helpers ---
+
+bool
+AsyncAgentMessageChannel::ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex)
{
- const size_t Length = ReadUnsignedVarInt(Pos);
- const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
- return std::string_view(Start, Length);
+ ReadCursor C{Data, Data + Size, false};
+ Ex.Message = ReadString(C);
+ Ex.Description = ReadString(C);
+ if (C.ParseError)
+ {
+ Ex = {};
+ return false;
+ }
+ return true;
}
-void
-AgentMessageChannel::WriteOptionalString(const char* Text)
+bool
+AsyncAgentMessageChannel::ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode)
{
- // Optional strings use length+1 encoding: 0 means null/absent,
- // N>0 means a string of length N-1 follows. This matches the UE
- // FAgentMessageChannel serialization convention.
- if (!Text)
+ ReadCursor C{Data, Data + Size, false};
+ OutExitCode = ReadInt32(C);
+ return !C.ParseError;
+}
+
+static bool
+IsSafeLocator(std::string_view Locator)
+{
+ // Reject empty, overlong, path-separator-containing, parent-relative, absolute, or
+ // control-character-containing locators. The locator is used as a filename component
+ // joined with a trusted BundleDir, so the only safe characters are a restricted
+ // filename alphabet.
+ if (Locator.empty() || Locator.size() > 255)
{
- WriteUnsignedVarInt(0);
+ return false;
}
- else
+ if (Locator == "." || Locator == "..")
{
- const size_t Length = strlen(Text);
- WriteUnsignedVarInt(Length + 1);
- WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ return false;
+ }
+ for (char Ch : Locator)
+ {
+ const unsigned char U = static_cast<unsigned char>(Ch);
+ if (U < 0x20 || U == 0x7F)
+ {
+ return false; // control / NUL / DEL
+ }
+ if (Ch == '/' || Ch == '\\' || Ch == ':')
+ {
+ return false; // path separators / drive letters
+ }
+ }
+ // Disallow leading/trailing dot or whitespace (Windows quirks + hidden-file dodges)
+ if (Locator.front() == '.' || Locator.front() == ' ' || Locator.back() == '.' || Locator.back() == ' ')
+ {
+ return false;
+ }
+ return true;
+}
+
+bool
+AsyncAgentMessageChannel::ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req)
+{
+ ReadCursor C{Data, Data + Size, false};
+ Req.Locator = ReadString(C);
+ Req.Offset = ReadUnsignedVarInt(C);
+ Req.Length = ReadUnsignedVarInt(C);
+ if (C.ParseError || !IsSafeLocator(Req.Locator))
+ {
+ Req = {};
+ return false;
}
+ return true;
}
} // namespace zen::horde