aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagent.cpp
blob: 819b2d0cb7c5daf77c63bd837b32968578e8389b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordeagent.h"
#include "hordetransportaes.h"

#include <zencore/basicfile.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/trace.h>

#include <cstring>
#include <unordered_map>

namespace zen::horde {

// Builds the full connection stack to the Horde agent described by Info:
// TCP transport -> (optional) AES transport -> multiplexed socket -> two message channels.
//
// Every failure path logs a warning and returns early without setting m_IsValid, so
// callers are expected to check IsValid() after construction.
HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
{
	ZEN_TRACE_CPU("HordeAgent::Connect");

	std::unique_ptr<ComputeTransport> Wire;
	{
		auto Tcp = std::make_unique<TcpComputeTransport>(Info);
		if (!Tcp->IsValid())
		{
			ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
			return;
		}

		// The nonce goes out in the clear before anything else is written, i.e. before the
		// optional AES layer is installed. The Horde agent uses it to match this connection
		// to its lease.
		// NOTE(review): the outcome of this Send is not checked here - confirm whether it can fail.
		Tcp->Send(Info.Nonce, sizeof(Info.Nonce));

		Wire = std::move(Tcp);
	}

	// Optionally wrap the raw TCP transport in an encrypting transport.
	if (Info.EncryptionMode == Encryption::AES)
	{
		Wire = std::make_unique<AesComputeTransport>(Info.Key, std::move(Wire));
		if (!Wire->IsValid())
		{
			ZEN_WARN("failed to create AES transport");
			return;
		}
	}

	// The socket takes ownership of the transport and multiplexes channels over it.
	m_Socket = std::make_unique<ComputeSocket>(std::move(Wire));

	// Channel 0: agent control channel (Attach/Fork handshake).
	// Channel 100: child I/O channel (file upload and remote execution).
	Ref<ComputeChannel> ControlChannel = m_Socket->CreateChannel(0);
	Ref<ComputeChannel> IoChannel	   = m_Socket->CreateChannel(100);

	if (!ControlChannel || !IoChannel)
	{
		ZEN_WARN("failed to create compute channels");
		return;
	}

	m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(ControlChannel));
	m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(IoChannel));

	// Only reached when every layer above was created successfully.
	m_IsValid = true;
}

// Closes whichever message channels were created before the members are destroyed.
HordeAgent::~HordeAgent()
{
	this->CloseConnection();
}

bool
HordeAgent::BeginCommunication()
{
	ZEN_TRACE_CPU("HordeAgent::BeginCommunication");

	if (!m_IsValid)
	{
		return false;
	}

	// Start the send/recv pump threads
	m_Socket->StartCommunication();

	// Wait for Attach on agent channel
	AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
	if (Type == AgentMessageType::None)
	{
		ZEN_WARN("timed out waiting for Attach on agent channel");
		return false;
	}
	if (Type != AgentMessageType::Attach)
	{
		ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
		return false;
	}

	// Fork tells the remote agent to create child channel 100 with a 4MB buffer.
	// After this, the agent will send an Attach on the child channel.
	m_AgentChannel->Fork(100, 4 * 1024 * 1024);

	// Wait for Attach on child channel
	Type = m_ChildChannel->ReadResponse(5000);
	if (Type == AgentMessageType::None)
	{
		ZEN_WARN("timed out waiting for Attach on child channel");
		return false;
	}
	if (Type != AgentMessageType::Attach)
	{
		ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
		return false;
	}

	return true;
}

bool
HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
{
	ZEN_TRACE_CPU("HordeAgent::UploadBinaries");

	m_ChildChannel->UploadFiles("", BundleLocator.c_str());

	std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;

	auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
		std::string Key(Locator);

		if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
		{
			return It->second.get();
		}

		const std::filesystem::path Path = BundleDir / (Key + ".blob");
		std::error_code				Ec;
		auto						File = std::make_unique<BasicFile>();
		File->Open(Path, BasicFile::Mode::kRead, Ec);

		if (Ec)
		{
			ZEN_ERROR("cannot read blob file: '{}'", Path);
			return nullptr;
		}

		BasicFile* Ptr = File.get();
		BlobFiles.emplace(std::move(Key), std::move(File));
		return Ptr;
	};

	// The upload protocol is request-driven: we send WriteFiles, then the remote agent
	// sends ReadBlob requests for each blob it needs. We respond with Blob data until
	// the agent sends WriteFilesResponse indicating the upload is complete.
	constexpr int32_t ReadResponseTimeoutMs = 1000;

	for (;;)
	{
		bool TimedOut = false;

		if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
		{
			if (TimedOut)
			{
				continue;
			}
			// End of stream - check if it was a successful upload
			if (Type == AgentMessageType::WriteFilesResponse)
			{
				return true;
			}
			else if (Type == AgentMessageType::Exception)
			{
				ExceptionInfo Ex;
				m_ChildChannel->ReadException(Ex);
				ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
			}
			else
			{
				ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
			}
			return false;
		}

		BlobRequest Req;
		m_ChildChannel->ReadBlobRequest(Req);

		BasicFile* File = FindOrOpenBlob(Req.Locator);
		if (!File)
		{
			return false;
		}

		// Read from offset to end of file
		const uint64_t TotalSize = File->FileSize();
		const uint64_t Offset	 = static_cast<uint64_t>(Req.Offset);
		if (Offset >= TotalSize)
		{
			ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
			m_ChildChannel->Blob(nullptr, 0);
			continue;
		}

		const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
		m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
	}
}

// Forwards a process launch request to the remote agent over the child channel.
//
// Exe        - executable to run.
// Args       - array of NumArgs argument strings.
// WorkingDir - working directory for the process.
// EnvVars    - array of NumEnvVars environment variable strings.
// UseWine    - selects ExecuteProcessFlags::UseWine when true, ExecuteProcessFlags::None otherwise.
//
// All arguments are passed through unchanged; nothing is returned here. Poll() is where
// ExecuteOutput / ExecuteResult messages are read back from the child channel.
void
HordeAgent::Execute(const char*		   Exe,
					const char* const* Args,
					size_t			   NumArgs,
					const char*		   WorkingDir,
					const char* const* EnvVars,
					size_t			   NumEnvVars,
					bool			   UseWine)
{
	ZEN_TRACE_CPU("HordeAgent::Execute");

	const auto LaunchFlags = UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None;

	m_ChildChannel->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, LaunchFlags);
}

bool
HordeAgent::Poll(bool LogOutput)
{
	constexpr int32_t ReadResponseTimeoutMs = 100;
	AgentMessageType  Type;

	while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
	{
		switch (Type)
		{
			case AgentMessageType::ExecuteOutput:
				{
					if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
					{
						const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
						size_t		ResponseSize = m_ChildChannel->GetResponseSize();

						// Trim trailing newlines
						while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
						{
							--ResponseSize;
						}

						if (ResponseSize > 0)
						{
							const std::string_view Output(ResponseData, ResponseSize);
							ZEN_INFO("[remote] {}", Output);
						}
					}
					break;
				}

			case AgentMessageType::ExecuteResult:
				{
					if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
					{
						int32_t ExitCode;
						memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
						ZEN_INFO("remote process exited with code {}", ExitCode);
					}
					m_IsValid = false;
					return false;
				}

			case AgentMessageType::Exception:
				{
					ExceptionInfo Ex;
					m_ChildChannel->ReadException(Ex);
					ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
					m_HasErrors = true;
					break;
				}

			default:
				break;
		}
	}

	return m_IsValid && !m_HasErrors;
}

// Closes the child channel and then the agent channel, skipping any that was never created
// (e.g. when construction bailed out early).
void
HordeAgent::CloseConnection()
{
	const auto CloseIfPresent = [](auto& Channel) {
		if (Channel)
		{
			Channel->Close();
		}
	};

	CloseIfPresent(m_ChildChannel);
	CloseIfPresent(m_AgentChannel);
}

// Reports whether the agent is usable: the connection was established (and has not been
// invalidated by Poll seeing an ExecuteResult) and no Exception message has been recorded.
bool
HordeAgent::IsValid() const
{
	if (!m_IsValid)
	{
		return false;
	}
	return !m_HasErrors;
}

}  // namespace zen::horde