aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagentmessage.h
blob: fb7c5ed29543cdb873f6041054061db234eeee3f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

#include <zenbase/zenbase.h>

#include "hordecomputesocket.h"

#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <system_error>
#include <vector>

// Forward declaration only: io_context is used by reference in this header, so the
// full asio headers are not needed for it here.
// NOTE(review): the m_TimeoutTimer member below names asio::steady_timer, which in asio
// is a typedef of a basic_waitable_timer specialization and cannot be forward-declared
// like this. It is presumably made visible through hordecomputesocket.h; confirm.
namespace asio { class io_context; }

namespace zen::horde {

/** One-byte opcodes for messages exchanged over the agent/child channels.
 *  The byte values mirror UE's EAgentMessageType and must not be renumbered.
 *  Enumerators are listed in ascending byte order. Values not listed here
 *  (e.g. 0x05-0x0F) are not named by this enum. */
enum class AgentMessageType : uint8_t
{
	None			   = 0x00,
	Ping			   = 0x01,
	Exception		   = 0x02,
	Fork			   = 0x03,
	Attach			   = 0x04,
	WriteFiles		   = 0x10,
	WriteFilesResponse = 0x11,
	DeleteFiles		   = 0x12,
	ExecuteOutput	   = 0x17,
	ExecuteResult	   = 0x18,
	ReadBlob		   = 0x20,
	ReadBlobResponse   = 0x21,
	ExecuteV2		   = 0x22,
};

/** Option values carried by an ExecuteV2 request (one byte on the wire). */
enum class ExecuteProcessFlags : uint8_t
{
	None	= 0x00,	 ///< No special execution options
	UseWine = 0x01,	 ///< Run the executable under Wine on Linux agents
};

/** Fields decoded from an Exception message payload.
 *  Both members are non-owning string views. When filled by
 *  AsyncAgentMessageChannel::ReadException they presumably point into the parsed
 *  payload bytes, so that buffer must outlive this struct. */
struct ExceptionInfo
{
	std::string_view Message{};
	std::string_view Description{};
};

/** Fields decoded from a ReadBlob message payload.
 *  Locator is a non-owning view, presumably into the parsed payload bytes, so that
 *  buffer must outlive this struct. Offset and Length default to zero. */
struct BlobRequest
{
	std::string_view Locator{};
	size_t			 Offset{0};	 ///< Presumably the starting byte within the blob; confirm
	size_t			 Length{0};	 ///< Presumably the number of bytes requested; confirm
};

/** Callback invoked when AsyncAgentMessageChannel::AsyncReadResponse completes.
 *  @param Type  Type of the received message, or AgentMessageType::None if the read timed out.
 *  @param Data  Pointer to the message payload bytes (not a container; a raw view).
 *  @param Size  Number of payload bytes available at Data.
 *  The Data/Size view refers to storage owned by the channel and is only valid until the
 *  next AsyncReadResponse call. Copy the bytes if they are needed after that. */
using AsyncResponseHandler = std::function<void(AgentMessageType Type, const uint8_t* Data, size_t Size)>;

/** Async channel for sending and receiving agent messages over an AsyncComputeSocket.
 *
 *  Send methods build messages into vectors and submit them via AsyncComputeSocket.
 *  Receives are delivered via the socket's FrameHandler callback and queued internally.
 *  AsyncReadResponse checks the queue and invokes the handler, with optional timeout.
 *
 *  All operations must be externally serialized (e.g. via the socket's strand).
 *  The channel is neither copyable nor movable (copy is deleted and no move is declared).
 *  It stores the io_context by reference, so that context must outlive the channel.
 */
class AsyncAgentMessageChannel
{
public:
	/** @param Socket     Socket this channel sends over; shared ownership is retained.
	 *  @param ChannelId  Identifier of this channel on the socket.
	 *  @param IoContext  Stored by reference; presumably used to create the timeout timer. */
	AsyncAgentMessageChannel(std::shared_ptr<AsyncComputeSocket> Socket, int ChannelId, asio::io_context& IoContext);
	~AsyncAgentMessageChannel();

	AsyncAgentMessageChannel(const AsyncAgentMessageChannel&) = delete;
	AsyncAgentMessageChannel& operator=(const AsyncAgentMessageChannel&) = delete;

	// --- Requests (fire-and-forget sends) ---

	void Close();
	void Ping();
	void Fork(int ChannelId, int BufferSize);
	void Attach();
	void UploadFiles(const char* Path, const char* Locator);
	void Execute(const char*		 Exe,
				 const char* const*	 Args,
				 size_t				 NumArgs,
				 const char*		 WorkingDir,
				 const char* const*	 EnvVars,
				 size_t				 NumEnvVars,
				 ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
	void Blob(const uint8_t* Data, size_t Length);

	// --- Async response reading ---

	/** Read the next response. If a frame is already queued, the handler is posted immediately.
	 *  Otherwise waits up to TimeoutMs for a frame to arrive. On timeout, invokes the handler
	 *  with AgentMessageType::None. */
	void AsyncReadResponse(int32_t TimeoutMs, AsyncResponseHandler Handler);

	/** Called by the socket's FrameHandler when a frame arrives for this channel. */
	void OnFrame(std::vector<uint8_t> Data);

	/** Called by the socket's DetachHandler. */
	void OnDetach();

	/** Returns true if the channel has been detached (connection lost). */
	bool IsDetached() const { return m_Detached; }

	// --- Response parsing helpers ---

	/** Parse an Exception message payload. Returns false on malformed/truncated input. */
	[[nodiscard]] static bool ReadException(const uint8_t* Data, size_t Size, ExceptionInfo& Ex);

	/** Parse an ExecuteResult message payload. Returns false on malformed/truncated input. */
	[[nodiscard]] static bool ReadExecuteResult(const uint8_t* Data, size_t Size, int32_t& OutExitCode);

	/** Parse a ReadBlob message payload. Returns false on malformed/truncated input or
	 *  if the Locator contains characters that would not be safe to use as a path component. */
	[[nodiscard]] static bool ReadBlobRequest(const uint8_t* Data, size_t Size, BlobRequest& Req);

private:
	// Size in bytes of the per-message header written by BeginMessage.
	// NOTE(review): presumably 1 type byte + 4-byte length; confirm against the .cpp.
	static constexpr size_t MessageHeaderLength = 5;

	// Message building helpers
	std::vector<uint8_t> BeginMessage(AgentMessageType Type, size_t ReservePayload);
	void				 FinalizeAndSend(std::vector<uint8_t> Msg);

	/** Bounds-checked reader cursor. All Read* helpers set ParseError instead of reading past End. */
	struct ReadCursor
	{
		const uint8_t* Pos		  = nullptr;
		const uint8_t* End		  = nullptr;
		bool		   ParseError = false;

		/** Returns true if at least N unread bytes remain. Otherwise it latches ParseError
		 *  and returns false. Once ParseError is set, every later check fails. */
		[[nodiscard]] bool CheckAvailable(size_t N)
		{
			// Guard Pos > End explicitly. Without it, End - Pos is negative and the cast
			// to size_t wraps to a huge value, so the length check would wrongly pass.
			if (ParseError || Pos > End || static_cast<size_t>(End - Pos) < N)
			{
				ParseError = true;
				return false;
			}
			return true;
		}
	};

	static void WriteInt32(std::vector<uint8_t>& Buf, int Value);
	static int	ReadInt32(ReadCursor& C);

	static void			  WriteFixedLengthBytes(std::vector<uint8_t>& Buf, const uint8_t* Data, size_t Length);
	static const uint8_t* ReadFixedLengthBytes(ReadCursor& C, size_t Length);

	static size_t MeasureUnsignedVarInt(size_t Value);
	static void	  WriteUnsignedVarInt(std::vector<uint8_t>& Buf, size_t Value);
	static size_t ReadUnsignedVarInt(ReadCursor& C);

	static void				WriteString(std::vector<uint8_t>& Buf, const char* Text);
	static void				WriteString(std::vector<uint8_t>& Buf, std::string_view Text);
	static std::string_view ReadString(ReadCursor& C);

	static void WriteOptionalString(std::vector<uint8_t>& Buf, const char* Text);

	std::shared_ptr<AsyncComputeSocket> m_Socket;
	int									m_ChannelId;
	asio::io_context&					m_IoContext;

	std::deque<std::vector<uint8_t>>	m_IncomingFrames;
	AsyncResponseHandler				m_PendingHandler;
	std::unique_ptr<asio::steady_timer> m_TimeoutTimer;
	bool								m_Detached = false;
};

}  // namespace zen::horde