aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagentmessage.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordeagentmessage.h')
-rw-r--r--src/zenhorde/hordeagentmessage.h153
1 files changed, 87 insertions, 66 deletions
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h
index 38c4375fd..fb7c5ed29 100644
--- a/src/zenhorde/hordeagentmessage.h
+++ b/src/zenhorde/hordeagentmessage.h
@@ -4,14 +4,22 @@
#include <zenbase/zenbase.h>
-#include "hordecomputechannel.h"
+#include "hordecomputesocket.h"
#include <cstddef>
#include <cstdint>
+#include <deque>
+#include <functional>
+#include <memory>
#include <string>
#include <string_view>
+#include <system_error>
#include <vector>
+namespace asio {
+class io_context;
+} // namespace asio
+
namespace zen::horde {
/** Agent message types matching the UE EAgentMessageType byte values.
@@ -55,45 +63,34 @@ struct BlobRequest
size_t Length = 0;
};
-/** Channel for sending and receiving agent messages over a ComputeChannel.
+/** Handler for async response reads. Receives the message type and a view of the payload data.
+ * The payload vector is valid until the next AsyncReadResponse call. */
+using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>;
+
+/** Async channel for sending and receiving agent messages over an AsyncComputeSocket.
*
- * Implements the Horde agent message protocol, matching the UE
- * FAgentMessageChannel serialization format exactly. Messages are framed as
- * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8;
- * integers use variable-length encoding.
+ * Send methods build messages into vectors and submit them via AsyncComputeSocket.
+ * Receives are delivered via the socket's FrameHandler callback and queued internally.
+ * AsyncReadResponse checks the queue and invokes the handler, with optional timeout.
*
- * The protocol has two directions:
- * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob
- * - Responses (remote -> initiator): ReadResponse returns the type, then call the
- * appropriate Read* method to parse the payload.
+ * All operations must be externally serialized (e.g. via the socket's strand).
*/
-class AgentMessageChannel
+class AsyncAgentMessageChannel
{
public:
- explicit AgentMessageChannel(Ref<ComputeChannel> Channel);
- ~AgentMessageChannel();
+ AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext);
+ ~AsyncAgentMessageChannel();
- AgentMessageChannel(const AgentMessageChannel&) = delete;
- AgentMessageChannel& operator=(const AgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete;
+ AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete;
- // --- Requests (Initiator -> Remote) ---
+ // --- Requests (fire-and-forget sends) ---
- /** Close the channel. */
void Close();
-
- /** Send a keepalive ping. */
void Ping();
-
- /** Fork communication to a new channel with the given ID and buffer size. */
void Fork(int ChannelId, int BufferSize);
-
- /** Send an attach request (used during channel setup handshake). */
void Attach();
-
- /** Request the remote agent to write files from the given bundle locator. */
void UploadFiles(const char* Path, const char* Locator);
-
- /** Execute a process on the remote machine. */
void Execute(const char* Exe,
const char* const* Args,
size_t NumArgs,
@@ -101,61 +98,85 @@ public:
const char* const* EnvVars,
size_t NumEnvVars,
ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
-
- /** Send blob data in response to a ReadBlob request. */
void Blob(const uint8_t* Data, size_t Length);
- // --- Responses (Remote -> Initiator) ---
-
- /** Read the next response message. Returns the message type, or None on timeout.
- * After this returns, use GetResponseData()/GetResponseSize() or the typed
- * Read* methods to access the payload. */
- AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr);
+ // --- Async response reading ---
- const void* GetResponseData() const { return m_ResponseData; }
- size_t GetResponseSize() const { return m_ResponseLength; }
+ /** Read the next response. If a frame is already queued, the handler is posted immediately.
+ * Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler
+ * with AgentMessageType::None. */
+ void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler);
- /** Parse an Exception response payload. */
- void ReadException(ExceptionInfo& Ex);
+ /** Called by the socket's FrameHandler when a frame arrives for this channel. */
+ void OnFrame(std::vector<uint8_t> Data);
- /** Parse an ExecuteResult response payload. Returns the exit code. */
- int ReadExecuteResult();
+ /** Called by the socket's DetachHandler. */
+ void OnDetach();
- /** Parse a ReadBlob response payload into a BlobRequest. */
- void ReadBlobRequest(BlobRequest& Req);
+ /** Returns true if the channel has been detached (connection lost). */
+ bool IsDetached() const { return m_Detached; }
-private:
- static constexpr size_t MessageHeaderLength = 5; ///< [type(1B)][length(4B)]
+ // --- Response parsing helpers ---
- Ref<ComputeChannel> m_Channel;
+ /** Parse an Exception message payload. Returns false on malformed/truncated input. */
+ [[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex);
- uint8_t* m_RequestData = nullptr;
- size_t m_RequestSize = 0;
- size_t m_MaxRequestSize = 0;
+ /** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. */
+ [[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode);
- AgentMessageType m_ResponseType = AgentMessageType::None;
- const uint8_t* m_ResponseData = nullptr;
- size_t m_ResponseLength = 0;
+ /** Parse a ReadBlob message payload. Returns false on malformed/truncated input or
+ * if the Locator contains characters that would not be safe to use as a path component. */
+ [[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req);
- void CreateMessage(AgentMessageType Type, size_t MaxLength);
- void FlushMessage();
+private:
+ static constexpr size_t MessageHeaderLength = 5;
+
+ // Message building helpers
+ std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload);
+ void FinalizeAndSend(std::vector<uint8_t> Msg);
+
+ /** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */
+ struct ReadCursor
+ {
+ const uint8_t* Pos = nullptr;
+ const uint8_t* End = nullptr;
+ bool ParseError = false;
+
+ [[nodiscard]] bool CheckAvailable(size_t N)
+ {
+ if (ParseError || static_cast<size_t>(End - Pos) < N)
+ {
+ ParseError = true;
+ return false;
+ }
+ return true;
+ }
+ };
+
+ static void WriteInt32(std::vector<uint8_t>& Buf, int Value);
+ static int ReadInt32(ReadCursor& C);
+
+ static void WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length);
+ static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length);
- void WriteInt32(int Value);
- static int ReadInt32(const uint8_t** Pos);
+ static size_t MeasureUnsignedVarInt(size_t Value);
+ static void WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value);
+ static size_t ReadUnsignedVarInt(ReadCursor& C);
- void WriteFixedLengthBytes(const uint8_t* Data, size_t Length);
- static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length);
+ static void WriteString(std::vector<uint8_t>& Buf, const char* Text);
+ static void WriteString(std::vector<uint8_t>& Buf, std::string_view Text);
+ static std::string_view ReadString(ReadCursor& C);
- static size_t MeasureUnsignedVarInt(size_t Value);
- void WriteUnsignedVarInt(size_t Value);
- static size_t ReadUnsignedVarInt(const uint8_t** Pos);
+ static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text);
- size_t MeasureString(const char* Text) const;
- void WriteString(const char* Text);
- void WriteString(std::string_view Text);
- static std::string_view ReadString(const uint8_t** Pos);
+ std::shared_ptr<AsyncComputeSocket> m_Socket;
+ int m_ChannelId;
+ asio::io_context& m_IoContext;
- void WriteOptionalString(const char* Text);
+ std::deque<std::vector<uint8_t>> m_IncomingFrames;
+ AsyncResponseHandler m_PendingHandler;
+ std::unique_ptr<asio::steady_timer> m_TimeoutTimer;
+ bool m_Detached = false;
};
} // namespace zen::horde