aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagent.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordeagent.cpp')
-rw-r--r--src/zenhorde/hordeagent.cpp297
1 files changed, 297 insertions, 0 deletions
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp
new file mode 100644
index 000000000..819b2d0cb
--- /dev/null
+++ b/src/zenhorde/hordeagent.cpp
@@ -0,0 +1,297 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordeagent.h"
+#include "hordetransportaes.h"
+
+#include <zencore/basicfile.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/trace.h>
+
+#include <cstring>
+#include <unordered_map>
+
+namespace zen::horde {
+
+// Connects to the leased machine described by Info and prepares the two message
+// channels used by the rest of this class.
+//
+// Every failure path logs a warning and returns early, before m_IsValid is set to
+// true; callers are expected to check IsValid() after construction.
+HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
+{
+    ZEN_TRACE_CPU("HordeAgent::Connect");
+
+    std::unique_ptr<ComputeTransport> Wire;
+    {
+        auto Tcp = std::make_unique<TcpComputeTransport>(Info);
+        if (!Tcp->IsValid())
+        {
+            ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
+            return;
+        }
+
+        // The nonce goes out first and in the clear, ahead of any encryption layer,
+        // so the Horde agent can tell which lease this connection belongs to.
+        Tcp->Send(Info.Nonce, sizeof(Info.Nonce));
+
+        Wire = std::move(Tcp);
+    }
+
+    // Optionally layer AES on top of the raw TCP transport.
+    const bool WantsAes = (Info.EncryptionMode == Encryption::AES);
+    if (WantsAes)
+    {
+        Wire = std::make_unique<AesComputeTransport>(Info.Key, std::move(Wire));
+        if (!Wire->IsValid())
+        {
+            ZEN_WARN("failed to create AES transport");
+            return;
+        }
+    }
+
+    // The socket multiplexes several channels over the single transport.
+    m_Socket = std::make_unique<ComputeSocket>(std::move(Wire));
+
+    // Channel 0: agent control channel (Attach/Fork handshake).
+    // Channel 100: child I/O channel (file upload and remote execution).
+    Ref<ComputeChannel> ControlChannel = m_Socket->CreateChannel(0);
+    Ref<ComputeChannel> ChildIoChannel = m_Socket->CreateChannel(100);
+
+    const bool ChannelMissing = !ControlChannel || !ChildIoChannel;
+    if (ChannelMissing)
+    {
+        ZEN_WARN("failed to create compute channels");
+        return;
+    }
+
+    m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(ControlChannel));
+    m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildIoChannel));
+
+    m_IsValid = true;
+}
+
+// Destruction closes the message channels via CloseConnection().
+HordeAgent::~HordeAgent()
+{
+    CloseConnection();
+}
+
+// Starts the socket's communication and performs the Attach/Fork handshake.
+//
+// Sequence: wait for Attach on the agent channel, ask the agent to fork child
+// channel 100, then wait for Attach on the child channel. Returns true only when
+// both Attach messages arrive; any timeout or unexpected message is logged and
+// yields false. Returns false immediately if construction did not succeed.
+bool
+HordeAgent::BeginCommunication()
+{
+    ZEN_TRACE_CPU("HordeAgent::BeginCommunication");
+
+    if (!m_IsValid)
+    {
+        return false;
+    }
+
+    constexpr int32_t AttachTimeoutMs = 5000;
+
+    // Start the send/recv pump threads
+    m_Socket->StartCommunication();
+
+    // Step 1: the agent announces itself on the control channel.
+    const AgentMessageType AgentReply = m_AgentChannel->ReadResponse(AttachTimeoutMs);
+    switch (AgentReply)
+    {
+        case AgentMessageType::Attach:
+            break;
+
+        case AgentMessageType::None:
+            {
+                ZEN_WARN("timed out waiting for Attach on agent channel");
+                return false;
+            }
+
+        default:
+            {
+                ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(AgentReply));
+                return false;
+            }
+    }
+
+    // Step 2: Fork asks the remote agent to create child channel 100 with a 4MB
+    // buffer; the agent then sends an Attach on that child channel.
+    m_AgentChannel->Fork(100, 4 * 1024 * 1024);
+
+    // Step 3: the child announces itself on its own channel.
+    const AgentMessageType ChildReply = m_ChildChannel->ReadResponse(AttachTimeoutMs);
+    switch (ChildReply)
+    {
+        case AgentMessageType::Attach:
+            break;
+
+        case AgentMessageType::None:
+            {
+                ZEN_WARN("timed out waiting for Attach on child channel");
+                return false;
+            }
+
+        default:
+            {
+                ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(ChildReply));
+                return false;
+            }
+    }
+
+    return true;
+}
+
+// Uploads the bundle identified by BundleLocator to the remote agent over the
+// child channel, serving "<locator>.blob" files from BundleDir on demand.
+//
+// The protocol is request-driven: after UploadFiles is sent, the agent issues a
+// ReadBlob request for each range it needs and we answer with Blob data, until
+// the agent sends WriteFilesResponse.
+//
+// Returns true once WriteFilesResponse is received. Returns false when:
+//  - the child channel was never created (construction failed),
+//  - a requested locator would resolve outside BundleDir,
+//  - a blob file cannot be opened,
+//  - the agent reports an exception or sends an unexpected message.
+//
+// NOTE: read timeouts are retried without an upper bound, so this call keeps
+// waiting for as long as the agent stays silent.
+bool
+HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
+{
+    ZEN_TRACE_CPU("HordeAgent::UploadBinaries");
+
+    // The constructor can bail out before the channels exist; do not dereference
+    // a null channel in that case.
+    if (!m_ChildChannel)
+    {
+        ZEN_WARN("cannot upload binaries: child channel was never created");
+        return false;
+    }
+
+    m_ChildChannel->UploadFiles("", BundleLocator.c_str());
+
+    // Blob files stay open for the duration of the upload, keyed by locator, so
+    // repeated range requests against the same blob reuse one handle.
+    std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;
+
+    auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
+        std::string Key(Locator);
+
+        if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
+        {
+            return It->second.get();
+        }
+
+        // The locator is supplied by the remote side. A rooted path would replace
+        // BundleDir when joined, and ".." components would climb out of it, so
+        // refuse both rather than read arbitrary local files.
+        const std::filesystem::path RelativePath(Key + ".blob");
+        bool                        EscapesBundleDir = RelativePath.has_root_path();
+        for (const auto& Part : RelativePath)
+        {
+            if (Part == std::filesystem::path(".."))
+            {
+                EscapesBundleDir = true;
+                break;
+            }
+        }
+        if (EscapesBundleDir)
+        {
+            ZEN_ERROR("rejecting blob locator '{}': it resolves outside the bundle directory", Key);
+            return nullptr;
+        }
+
+        const std::filesystem::path Path = BundleDir / RelativePath;
+        std::error_code             Ec;
+        auto                        File = std::make_unique<BasicFile>();
+        File->Open(Path, BasicFile::Mode::kRead, Ec);
+
+        if (Ec)
+        {
+            ZEN_ERROR("cannot read blob file: '{}'", Path);
+            return nullptr;
+        }
+
+        BasicFile* Ptr = File.get();
+        BlobFiles.emplace(std::move(Key), std::move(File));
+        return Ptr;
+    };
+
+    constexpr int32_t ReadResponseTimeoutMs = 1000;
+
+    for (;;)
+    {
+        bool TimedOut = false;
+
+        if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
+        {
+            if (TimedOut)
+            {
+                // Nothing arrived within the timeout; keep waiting.
+                continue;
+            }
+            // End of stream - check if it was a successful upload
+            if (Type == AgentMessageType::WriteFilesResponse)
+            {
+                return true;
+            }
+            else if (Type == AgentMessageType::Exception)
+            {
+                ExceptionInfo Ex;
+                m_ChildChannel->ReadException(Ex);
+                ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
+            }
+            else
+            {
+                ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
+            }
+            return false;
+        }
+
+        BlobRequest Req;
+        m_ChildChannel->ReadBlobRequest(Req);
+
+        BasicFile* File = FindOrOpenBlob(Req.Locator);
+        if (!File)
+        {
+            return false;
+        }
+
+        const uint64_t TotalSize = File->FileSize();
+        const uint64_t Offset    = static_cast<uint64_t>(Req.Offset);
+        if (Offset >= TotalSize)
+        {
+            // Answer with an empty blob so the agent is not left waiting, then
+            // carry on serving further requests.
+            ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
+            m_ChildChannel->Blob(nullptr, 0);
+            continue;
+        }
+
+        // Serve at most Req.Length bytes, clamped to what remains after Offset.
+        const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
+        m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
+    }
+}
+
+// Asks the remote agent, over the child channel, to launch a process.
+//
+// Exe, Args/NumArgs, WorkingDir and EnvVars/NumEnvVars are forwarded unchanged to
+// the channel. UseWine selects ExecuteProcessFlags::UseWine instead of
+// ExecuteProcessFlags::None. If the child channel was never created (construction
+// failed) the request is dropped with a warning instead of dereferencing null.
+void
+HordeAgent::Execute(const char*        Exe,
+                    const char* const* Args,
+                    size_t             NumArgs,
+                    const char*        WorkingDir,
+                    const char* const* EnvVars,
+                    size_t             NumEnvVars,
+                    bool               UseWine)
+{
+    ZEN_TRACE_CPU("HordeAgent::Execute");
+
+    if (!m_ChildChannel)
+    {
+        ZEN_WARN("cannot execute remote process: child channel was never created");
+        return;
+    }
+
+    m_ChildChannel
+        ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+}
+
+// Drains the messages currently available on the child channel.
+//
+// - ExecuteOutput: when LogOutput is true, the payload is logged with trailing
+//   CR/LF characters removed (empty output is skipped).
+// - ExecuteResult: logs the exit code, marks the agent invalid and returns false.
+// - Exception: logs the exception and records that an error occurred.
+// - Anything else is ignored.
+//
+// The loop ends when ReadResponse yields AgentMessageType::None. Returns true only
+// while the agent is still valid and no exception has been recorded; returns
+// false right away if the child channel was never created.
+bool
+HordeAgent::Poll(bool LogOutput)
+{
+    // The constructor can bail out before the channels exist.
+    if (!m_ChildChannel)
+    {
+        return false;
+    }
+
+    constexpr int32_t ReadResponseTimeoutMs = 100;
+    AgentMessageType  Type;
+
+    while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
+    {
+        switch (Type)
+        {
+            case AgentMessageType::ExecuteOutput:
+                {
+                    if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
+                    {
+                        const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
+                        size_t      ResponseSize = m_ChildChannel->GetResponseSize();
+
+                        // Trim trailing newlines
+                        while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
+                        {
+                            --ResponseSize;
+                        }
+
+                        if (ResponseSize > 0)
+                        {
+                            const std::string_view Output(ResponseData, ResponseSize);
+                            ZEN_INFO("[remote] {}", Output);
+                        }
+                    }
+                    break;
+                }
+
+            case AgentMessageType::ExecuteResult:
+                {
+                    const size_t ResultSize = m_ChildChannel->GetResponseSize();
+                    if (ResultSize == sizeof(int32_t))
+                    {
+                        // memcpy avoids an unaligned read from the response buffer.
+                        int32_t ExitCode;
+                        memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
+                        ZEN_INFO("remote process exited with code {}", ExitCode);
+                    }
+                    else
+                    {
+                        // Do not drop a malformed result silently.
+                        ZEN_WARN("remote process exited but sent a malformed exit code payload ({} bytes)", ResultSize);
+                    }
+                    m_IsValid = false;
+                    return false;
+                }
+
+            case AgentMessageType::Exception:
+                {
+                    ExceptionInfo Ex;
+                    m_ChildChannel->ReadException(Ex);
+                    ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
+                    m_HasErrors = true;
+                    break;
+                }
+
+            default:
+                break;
+        }
+    }
+
+    return m_IsValid && !m_HasErrors;
+}
+
+// Closes the child channel and then the agent channel. Channels that were never
+// created (null) are skipped, so this is safe after a failed construction.
+void
+HordeAgent::CloseConnection()
+{
+    for (auto* Channel : {m_ChildChannel.get(), m_AgentChannel.get()})
+    {
+        if (Channel != nullptr)
+        {
+            Channel->Close();
+        }
+    }
+}
+
+bool
+HordeAgent::IsValid() const
+{
+ return m_IsValid && !m_HasErrors;
+}
+
+} // namespace zen::horde