aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagent.cpp
blob: 480b4b985d95f6f585320cd635b7f07186fa0d85 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordeagent.h"
#include "hordetransportaes.h"

#include <zencore/basicfile.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/trace.h>

#include <cstring>
#include <unordered_map>

namespace zen::horde {

// Builds the connection stack to the machine described by Info: a TCP transport,
// an optional AES layer on top of it, the multiplexed compute socket, and the two
// message channels (agent control + child I/O). m_IsValid is set to true only after
// every stage succeeded; each failure logs a warning and returns early without
// setting it.
HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
{
	ZEN_TRACE_CPU("HordeAgent::Connect");

	std::unique_ptr<ComputeTransport> Wire;
	{
		auto Tcp = std::make_unique<TcpComputeTransport>(Info);
		if (!Tcp->IsValid())
		{
			ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
			return;
		}

		// The nonce always goes out first and unencrypted, straight over TCP; the
		// Horde agent uses it to work out which lease this connection belongs to.
		Tcp->Send(Info.Nonce, sizeof(Info.Nonce));

		Wire = std::move(Tcp);
	}

	// Optionally wrap the TCP transport; the AES transport takes ownership of it.
	const bool UseAes = (Info.EncryptionMode == Encryption::AES);
	if (UseAes)
	{
		Wire = std::make_unique<AesComputeTransport>(Info.Key, std::move(Wire));
		if (!Wire->IsValid())
		{
			ZEN_WARN("failed to create AES transport");
			return;
		}
	}

	// The socket multiplexes all channels over the (possibly encrypted) transport.
	m_Socket = std::make_unique<ComputeSocket>(std::move(Wire));

	// Channel 0   : agent control channel (Attach/Fork handshake).
	// Channel 100 : child I/O channel (file upload and remote execution).
	Ref<ComputeChannel> ControlChannel = m_Socket->CreateChannel(0);
	Ref<ComputeChannel> ChildIoChannel = m_Socket->CreateChannel(100);

	if (!ControlChannel || !ChildIoChannel)
	{
		ZEN_WARN("failed to create compute channels");
		return;
	}

	m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(ControlChannel));
	m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildIoChannel));

	m_IsValid = true;
}

// Closes whichever message channels were created (see CloseConnection, which
// tolerates channels that were never set up).
HordeAgent::~HordeAgent()
{
	CloseConnection();
}

bool
HordeAgent::BeginCommunication()
{
	ZEN_TRACE_CPU("HordeAgent::BeginCommunication");

	if (!m_IsValid)
	{
		return false;
	}

	// Start the send/recv pump threads
	m_Socket->StartCommunication();

	// Wait for Attach on agent channel
	AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
	if (Type == AgentMessageType::None)
	{
		ZEN_WARN("timed out waiting for Attach on agent channel");
		return false;
	}
	if (Type != AgentMessageType::Attach)
	{
		ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
		return false;
	}

	// Fork tells the remote agent to create child channel 100 with a 4MB buffer.
	// After this, the agent will send an Attach on the child channel.
	m_AgentChannel->Fork(100, 4 * 1024 * 1024);

	// Wait for Attach on child channel
	Type = m_ChildChannel->ReadResponse(5000);
	if (Type == AgentMessageType::None)
	{
		ZEN_WARN("timed out waiting for Attach on child channel");
		return false;
	}
	if (Type != AgentMessageType::Attach)
	{
		ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
		return false;
	}

	return true;
}

bool
HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
{
	ZEN_TRACE_CPU("HordeAgent::UploadBinaries");

	m_ChildChannel->UploadFiles("", BundleLocator.c_str());

	std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;

	auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
		std::string Key(Locator);

		if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
		{
			return It->second.get();
		}

		const std::filesystem::path Path = BundleDir / (Key + ".blob");
		std::error_code				Ec;
		auto						File = std::make_unique<BasicFile>();
		File->Open(Path, BasicFile::Mode::kRead, Ec);

		if (Ec)
		{
			ZEN_ERROR("cannot read blob file: '{}'", Path);
			return nullptr;
		}

		BasicFile* Ptr = File.get();
		BlobFiles.emplace(std::move(Key), std::move(File));
		return Ptr;
	};

	// The upload protocol is request-driven: we send WriteFiles, then the remote agent
	// sends ReadBlob requests for each blob it needs. We respond with Blob data until
	// the agent sends WriteFilesResponse indicating the upload is complete.
	constexpr int32_t ReadResponseTimeoutMs = 1000;

	for (;;)
	{
		bool TimedOut = false;

		if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
		{
			if (TimedOut)
			{
				continue;
			}
			// End of stream - check if it was a successful upload
			if (Type == AgentMessageType::WriteFilesResponse)
			{
				return true;
			}
			else if (Type == AgentMessageType::Exception)
			{
				ExceptionInfo Ex;
				m_ChildChannel->ReadException(Ex);
				ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
			}
			else
			{
				ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
			}
			return false;
		}

		BlobRequest Req;
		m_ChildChannel->ReadBlobRequest(Req);

		BasicFile* File = FindOrOpenBlob(Req.Locator);
		if (!File)
		{
			return false;
		}

		// Read from offset to end of file
		const uint64_t TotalSize = File->FileSize();
		const uint64_t Offset	 = static_cast<uint64_t>(Req.Offset);
		if (Offset >= TotalSize)
		{
			ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
			m_ChildChannel->Blob(nullptr, 0);
			continue;
		}

		const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
		m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
	}
}

// Reads, hashes, compresses and uploads each file in FilePaths, in order.
//
// Each file is uploaded under its bare filename (directory components dropped).
// Files that cannot be opened because they do not exist are treated as optional
// and skipped; empty files are skipped with a warning. Any other open failure,
// a compression failure, or a failed upload aborts the batch and returns false.
// Returns true when every file was either uploaded or skipped.
bool
HordeAgent::UploadCompressedFiles(const std::vector<std::filesystem::path>& FilePaths)
{
	ZEN_TRACE_CPU("HordeAgent::UploadCompressedFiles");

	for (const std::filesystem::path& FilePath : FilePaths)
	{
		std::error_code Ec;
		BasicFile		File;
		File.Open(FilePath, BasicFile::Mode::kRead, Ec);
		if (Ec)
		{
			// Use the non-throwing overload: we are already on an error path, and the
			// throwing overload raises filesystem_error when the status query itself
			// fails. A failed query leaves ExistsEc set and is reported as an open
			// failure rather than mistaken for a missing optional file.
			std::error_code ExistsEc;
			const bool		Exists = std::filesystem::exists(FilePath, ExistsEc);
			if (!Exists && !ExistsEc)
			{
				ZEN_DEBUG("skipping missing optional file: '{}'", FilePath);
				continue;
			}
			ZEN_ERROR("failed to open file for compressed upload: '{}'", FilePath);
			return false;
		}

		IoBuffer RawData = File.ReadAll();
		if (RawData.GetSize() == 0)
		{
			ZEN_WARN("empty file, skipping: '{}'", FilePath);
			continue;
		}

		// Hash the raw bytes before handing the buffer over to the compressor
		const IoHash		   Hash = IoHash::HashBuffer(RawData.GetData(), RawData.GetSize());
		const SharedBuffer	   Shared(std::move(RawData));
		const CompressedBuffer Compressed = CompressedBuffer::Compress(Shared);

		if (!Compressed)
		{
			ZEN_ERROR("failed to compress file: '{}'", FilePath);
			return false;
		}

		// Only the filename is used as the sandbox-relative destination
		const std::string SandboxPath = FilePath.filename().string();
		if (!UploadCompressedFile(SandboxPath, Compressed, Hash))
		{
			ZEN_ERROR("failed to upload compressed file: '{}'", FilePath);
			return false;
		}

		ZEN_INFO("uploaded '{}' ({} bytes compressed)", SandboxPath, Compressed.GetCompressedSize());
	}

	return true;
}

bool
HordeAgent::UploadCompressedFile(std::string_view SandboxRelativePath, const CompressedBuffer& Compressed, const IoHash& UncompressedHash)
{
	ZEN_TRACE_CPU("HordeAgent::UploadCompressedFile");

	const int64_t CompressedSize = static_cast<int64_t>(Compressed.GetCompressedSize());

	m_ChildChannel->SendWriteCompressedFile(SandboxRelativePath, CompressedSize, UncompressedHash);

	// Wait for initial response - the agent tells us whether it needs the data
	constexpr int32_t ResponseTimeoutMs = 5000;
	AgentMessageType  Type				= m_ChildChannel->ReadResponse(ResponseTimeoutMs);

	if (Type == AgentMessageType::Exception)
	{
		ExceptionInfo Ex;
		m_ChildChannel->ReadException(Ex);
		ZEN_ERROR("WriteCompressedFile exception: {} - {}", Ex.Message, Ex.Description);
		return false;
	}
	if (Type != AgentMessageType::WriteCompressedFileResponse)
	{
		ZEN_ERROR("expected WriteCompressedFileResponse, got 0x{:02x}", static_cast<int>(Type));
		return false;
	}

	bool NeedData = m_ChildChannel->ReadWriteCompressedFileResponse();
	if (!NeedData)
	{
		// Cache hit - agent already has this content
		return true;
	}

	// Send the compressed data in chunks
	static constexpr size_t MaxChunkSize = 524288;	// 512 KB

	const CompositeBuffer& CompressedData = Compressed.GetCompressed();
	const uint64_t		   TotalSize	  = CompressedData.GetSize();

	for (uint64_t Offset = 0; Offset < TotalSize;)
	{
		const size_t ChunkSize = static_cast<size_t>(std::min(static_cast<uint64_t>(MaxChunkSize), TotalSize - Offset));

		// Copy the chunk from the composite buffer into a contiguous staging buffer
		UniqueBuffer ChunkBuffer = UniqueBuffer::Alloc(ChunkSize);
		CompressedData.CopyTo(MutableMemoryView(ChunkBuffer.GetData(), ChunkSize), Offset);

		m_ChildChannel->SendWriteCompressedFileData(static_cast<int32_t>(Offset),
													static_cast<const uint8_t*>(ChunkBuffer.GetData()),
													ChunkSize);
		Offset += ChunkSize;
	}

	// Wait for final confirmation
	Type = m_ChildChannel->ReadResponse(ResponseTimeoutMs);
	if (Type == AgentMessageType::Exception)
	{
		ExceptionInfo Ex;
		m_ChildChannel->ReadException(Ex);
		ZEN_ERROR("WriteCompressedFile data exception: {} - {}", Ex.Message, Ex.Description);
		return false;
	}
	if (Type != AgentMessageType::WriteCompressedFileResponse)
	{
		ZEN_ERROR("expected final WriteCompressedFileResponse, got 0x{:02x}", static_cast<int>(Type));
		return false;
	}

	return true;
}

void
HordeAgent::Execute(const char*		   Exe,
					const char* const* Args,
					size_t			   NumArgs,
					const char*		   WorkingDir,
					const char* const* EnvVars,
					size_t			   NumEnvVars,
					bool			   UseWine)
{
	ZEN_TRACE_CPU("HordeAgent::Execute");
	m_ChildChannel
		->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
}

bool
HordeAgent::Poll(bool LogOutput)
{
	constexpr int32_t ReadResponseTimeoutMs = 100;
	AgentMessageType  Type;

	while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
	{
		switch (Type)
		{
			case AgentMessageType::ExecuteOutput:
				{
					if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
					{
						const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
						size_t		ResponseSize = m_ChildChannel->GetResponseSize();

						// Trim trailing newlines
						while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
						{
							--ResponseSize;
						}

						if (ResponseSize > 0)
						{
							const std::string_view Output(ResponseData, ResponseSize);
							ZEN_INFO("[remote] {}", Output);
						}
					}
					break;
				}

			case AgentMessageType::ExecuteResult:
				{
					if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
					{
						int32_t ExitCode;
						memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
						ZEN_INFO("remote process exited with code {}", ExitCode);
					}
					m_IsValid = false;
					return false;
				}

			case AgentMessageType::Exception:
				{
					ExceptionInfo Ex;
					m_ChildChannel->ReadException(Ex);
					ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
					m_HasErrors = true;
					break;
				}

			default:
				break;
		}
	}

	return m_IsValid && !m_HasErrors;
}

// Closes the child channel and then the agent channel. Channels that were never
// created (null) are skipped, so this is safe to call on a partially
// constructed agent.
void
HordeAgent::CloseConnection()
{
	auto CloseIfPresent = [](auto& Channel) {
		if (Channel)
		{
			Channel->Close();
		}
	};

	// Child first, then the agent control channel
	CloseIfPresent(m_ChildChannel);
	CloseIfPresent(m_AgentChannel);
}

// Reports whether the agent is usable: m_IsValid must be set and no error may
// have been recorded in m_HasErrors.
bool
HordeAgent::IsValid() const
{
	if (!m_IsValid)
	{
		return false;
	}
	return !m_HasErrors;
}

}  // namespace zen::horde