diff options
Diffstat (limited to 'src/zenhorde/hordeagentmessage.h')
| -rw-r--r-- | src/zenhorde/hordeagentmessage.h | 153 |
1 files changed, 87 insertions, 66 deletions
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h index 38c4375fd..fb7c5ed29 100644 --- a/src/zenhorde/hordeagentmessage.h +++ b/src/zenhorde/hordeagentmessage.h @@ -4,14 +4,22 @@ #include <zenbase/zenbase.h> -#include "hordecomputechannel.h" +#include "hordecomputesocket.h" #include <cstddef> #include <cstdint> +#include <deque> +#include <functional> +#include <memory> #include <string> #include <string_view> +#include <system_error> #include <vector> +namespace asio { +class io_context; +} // namespace asio + namespace zen::horde { /** Agent message types matching the UE EAgentMessageType byte values. @@ -55,45 +63,34 @@ struct BlobRequest size_t Length = 0; }; -/** Channel for sending and receiving agent messages over a ComputeChannel. +/** Handler for async response reads. Receives the message type and a view of the payload data. + * The payload vector is valid until the next AsyncReadResponse call. */ +using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>; + +/** Async channel for sending and receiving agent messages over an AsyncComputeSocket. * - * Implements the Horde agent message protocol, matching the UE - * FAgentMessageChannel serialization format exactly. Messages are framed as - * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8; - * integers use variable-length encoding. + * Send methods build messages into vectors and submit them via AsyncComputeSocket. + * Receives are delivered via the socket's FrameHandler callback and queued internally. + * AsyncReadResponse checks the queue and invokes the handler, with optional timeout. * - * The protocol has two directions: - * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob - * - Responses (remote -> initiator): ReadResponse returns the type, then call the - * appropriate Read* method to parse the payload. + * All operations must be externally serialized (e.g. via the socket's strand). */ -class AgentMessageChannel +class AsyncAgentMessageChannel { public: - explicit AgentMessageChannel(Ref<ComputeChannel> Channel); - ~AgentMessageChannel(); + AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext); + ~AsyncAgentMessageChannel(); - AgentMessageChannel(const AgentMessageChannel&) = delete; - AgentMessageChannel& operator=(const AgentMessageChannel&) = delete; + AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete; + AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete; - // --- Requests (Initiator -> Remote) --- + // --- Requests (fire-and-forget sends) --- - /** Close the channel. */ void Close(); - - /** Send a keepalive ping. */ void Ping(); - - /** Fork communication to a new channel with the given ID and buffer size. */ void Fork(int ChannelId, int BufferSize); - - /** Send an attach request (used during channel setup handshake). */ void Attach(); - - /** Request the remote agent to write files from the given bundle locator. */ void UploadFiles(const char* Path, const char* Locator); - - /** Execute a process on the remote machine. */ void Execute(const char* Exe, const char* const* Args, size_t NumArgs, @@ -101,61 +98,85 @@ public: const char* const* EnvVars, size_t NumEnvVars, ExecuteProcessFlags Flags = ExecuteProcessFlags::None); - - /** Send blob data in response to a ReadBlob request. */ void Blob(const uint8_t* Data, size_t Length); - // --- Responses (Remote -> Initiator) --- - - /** Read the next response message. Returns the message type, or None on timeout. - * After this returns, use GetResponseData()/GetResponseSize() or the typed - * Read* methods to access the payload. */ - AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr); + // --- Async response reading --- - const void* GetResponseData() const { return m_ResponseData; } - size_t GetResponseSize() const { return m_ResponseLength; } + /** Read the next response. If a frame is already queued, the handler is posted immediately. + * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler + * with AgentMessageType::None. */ + void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler); - /** Parse an Exception response payload. */ - void ReadException(ExceptionInfo& Ex); + /** Called by the socket's FrameHandler when a frame arrives for this channel. */ + void OnFrame(std::vector<uint8_t> Data); - /** Parse an ExecuteResult response payload. Returns the exit code. */ - int ReadExecuteResult(); + /** Called by the socket's DetachHandler. */ + void OnDetach(); - /** Parse a ReadBlob response payload into a BlobRequest. */ - void ReadBlobRequest(BlobRequest& Req); + /** Returns true if the channel has been detached (connection lost). */ + bool IsDetached() const { return m_Detached; } -private: - static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)] + // --- Response parsing helpers --- - Ref<ComputeChannel> m_Channel; + /** Parse an Exception message payload. Returns false on malformed/truncated input. */ + [[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex); - uint8_t* m_RequestData = nullptr; - size_t m_RequestSize = 0; - size_t m_MaxRequestSize = 0; + /** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. */ + [[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode); - AgentMessageType m_ResponseType = AgentMessageType::None; - const uint8_t* m_ResponseData = nullptr; - size_t m_ResponseLength = 0; + /** Parse a ReadBlob message payload. Returns false on malformed/truncated input or + * if the Locator contains characters that would not be safe to use as a path component. */ + [[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req); - void CreateMessage(AgentMessageType Type, size_t MaxLength); - void FlushMessage(); +private: + static constexpr size_t MessageHeaderLength = 5; + + // Message building helpers + std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload); + void FinalizeAndSend(std::vector<uint8_t> Msg); + + /** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */ + struct ReadCursor + { + const uint8_t* Pos = nullptr; + const uint8_t* End = nullptr; + bool ParseError = false; + + [[nodiscard]] bool CheckAvailable(size_t N) + { + if (ParseError || static_cast<size_t>(End - Pos) < N) + { + ParseError = true; + return false; + } + return true; + } + }; + + static void WriteInt32(std::vector<uint8_t>& Buf, int Value); + static int ReadInt32(ReadCursor& C); + + static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length); + static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length); - void WriteInt32(int Value); - static int ReadInt32(const uint8_t** Pos); + static size_t MeasureUnsignedVarInt(size_t Value); + static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value); + static size_t ReadUnsignedVarInt(ReadCursor& C); - void WriteFixedLengthBytes(const uint8_t* Data, size_t Length); - static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length); + static void WriteString(std::vector<uint8_t>& Buf, const char* Text); + static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text); + static std::string_view ReadString(ReadCursor& C); - static size_t MeasureUnsignedVarInt(size_t Value); - void WriteUnsignedVarInt(size_t Value); - static size_t ReadUnsignedVarInt(const uint8_t** Pos); + static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text); - size_t MeasureString(const char* Text) const; - void WriteString(const char* Text); - void WriteString(std::string_view Text); - static std::string_view ReadString(const uint8_t** Pos); + std::shared_ptr<AsyncComputeSocket> m_Socket; + int m_ChannelId; + asio::io_context& m_IoContext; - void WriteOptionalString(const char* Text); + std::deque<std::vector<uint8_t>> m_IncomingFrames; + AsyncResponseHandler m_PendingHandler; + std::unique_ptr<asio::steady_timer> m_TimeoutTimer; + bool m_Detached = false; }; } // namespace zen::horde |