about | summary | refs | log | tree | commit | diff
path: root/src/zenhorde/hordeagent.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordeagent.cpp')
-rw-r--r-- src/zenhorde/hordeagent.cpp 551
1 files changed, 359 insertions, 192 deletions
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp
index 819b2d0cb..275f5bd4c 100644
--- a/src/zenhorde/hordeagent.cpp
+++ b/src/zenhorde/hordeagent.cpp
@@ -8,290 +8,457 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <cstring>
-#include <unordered_map>
namespace zen::horde {
-HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info)
-{
- ZEN_TRACE_CPU("HordeAgent::Connect");
+// --- AsyncHordeAgent ---
- auto Transport = std::make_unique<TcpComputeTransport>(Info);
- if (!Transport->IsValid())
+// Returns a short, static, human-readable label for an agent state.
+// Finish() uses it to report which stage the agent failed in. Any value
+// without an explicit case maps to "unknown".
+static const char*
+GetStateName(AsyncHordeAgent::State S)
+{
+ switch (S)
{
- ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort());
- return;
+ case AsyncHordeAgent::State::Idle:
+ return "idle";
+ case AsyncHordeAgent::State::Connecting:
+ return "connect";
+ case AsyncHordeAgent::State::WaitAgentAttach:
+ return "agent-attach";
+ case AsyncHordeAgent::State::SentFork:
+ return "fork";
+ case AsyncHordeAgent::State::WaitChildAttach:
+ return "child-attach";
+ case AsyncHordeAgent::State::Uploading:
+ return "upload";
+ case AsyncHordeAgent::State::Executing:
+ return "execute";
+ case AsyncHordeAgent::State::Polling:
+ return "poll";
+ case AsyncHordeAgent::State::Done:
+ return "done";
+ default:
+ return "unknown";
}
+}
- // The 64-byte nonce is always sent unencrypted as the first thing on the wire.
- // The Horde agent uses this to identify which lease this connection belongs to.
- Transport->Send(Info.Nonce, sizeof(Info.Nonce));
+// Binds the agent to the io_context it will run on and resolves its logger
+// ("horde.agent.async"). No connection work is started here.
+AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext)
+: m_IoContext(IoContext)
+, m_Log(zen::logging::Get("horde.agent.async"))
+{
+}
- std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport);
- if (Info.EncryptionMode == Encryption::AES)
+// Destruction routes through Cancel(), which sets the cancelled flag and
+// closes whatever connection object is currently held.
+AsyncHordeAgent::~AsyncHordeAgent()
+{
+ Cancel();
+}
+
+// Stores the run configuration and the completion handler, then starts the
+// state machine: the state becomes Connecting and the connect is issued.
+// DoConnect() calls shared_from_this(), so the agent is presumably expected to
+// be owned by a std::shared_ptr before Start() is called — TODO confirm.
+void
+AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone)
+{
+ // Both arguments are taken by value and moved into the members.
+ m_OnDone = std::move(OnDone);
+ m_Config = std::move(Config);
+
+ m_State = State::Connecting;
+ DoConnect();
+}
+
+// Requests cancellation: sets m_Cancelled (checked by every completion
+// handler) and closes whichever connection object is currently live so that
+// pending asynchronous operations are released.
+//
+//  - After OnConnected() the multiplexed socket owns the transport, so the
+//    socket is closed.
+//  - While the connect is still in flight only m_TcpTransport exists, so it is
+//    closed here; otherwise the pending connect would not be aborted.
+//  - m_Transport is only populated briefly inside OnConnected() before it is
+//    moved into the socket; the branch is kept as a safety net.
+void
+AsyncHordeAgent::Cancel()
+{
+ m_Cancelled = true;
+ if (m_Socket)
{
- FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport));
- if (!FinalTransport->IsValid())
- {
- ZEN_WARN("failed to create AES transport");
- return;
- }
+ m_Socket->Close();
+ }
+ else if (m_TcpTransport)
+ {
+ // Connect still pending: close the raw TCP transport.
+ m_TcpTransport->Close();
+ }
+ else if (m_Transport)
+ {
+ m_Transport->Close();
}
+}
- // Create multiplexed socket and channels
- m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport));
+// Creates the TCP transport and starts the asynchronous connect to the
+// configured machine. The completion lambda captures a shared_ptr to this
+// agent (Self), which keeps the agent alive for as long as the handler is
+// held; the result is delivered to OnConnected().
+void
+AsyncHordeAgent::DoConnect()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect");
- // Channel 0 is the agent control channel (handles Attach/Fork handshake).
- // Channel 100 is the child I/O channel (handles file upload and remote execution).
- Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0);
- Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100);
+ m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext);
+
+ auto Self = shared_from_this();
+ m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); });
+}
- if (!AgentComputeChannel || !ChildComputeChannel)
+// Completion handler for the asynchronous connect started in DoConnect().
+// On error or cancellation it finishes the agent as failed. Otherwise it
+// optionally wraps the TCP transport in AES, builds the multiplexed socket,
+// creates the agent (0) and child (100) message channels, registers their
+// frame/detach callbacks, starts the receive pump and begins waiting for the
+// Attach message on the agent channel.
+void
+AsyncHordeAgent::OnConnected(const std::error_code& Ec)
+{
+ if (Ec || m_Cancelled)
{
- ZEN_WARN("failed to create compute channels");
+ if (Ec)
+ {
+ ZEN_WARN("connect failed: {}", Ec.message());
+ }
+ Finish(false);
return;
}
- m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel));
- m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel));
+ // Optionally wrap with AES encryption
+ std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport);
+ if (m_Config.Machine.EncryptionMode == Encryption::AES)
+ {
+ FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext);
+ }
+ m_Transport = std::move(FinalTransport);
+
+ // m_Transport is handed straight to the socket below, so it is empty again once this function returns.
+ // Create the multiplexed socket and register channels
+ m_Socket = std::make_shared<AsyncComputeSocket>(std::move(m_Transport), m_IoContext);
+
+ m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext);
+ m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext);
+
+ // NOTE(review): the channel callbacks capture a raw `this` rather than Self — confirm the socket never invokes them after this agent is destroyed.
+ m_Socket->RegisterChannel(
+ 0,
+ [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); },
+ [this]() { m_AgentChannel->OnDetach(); });
- m_IsValid = true;
+ m_Socket->RegisterChannel(
+ 100,
+ [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); },
+ [this]() { m_ChildChannel->OnDetach(); });
+
+ m_Socket->StartRecvPump();
+
+ m_State = State::WaitAgentAttach;
+ DoWaitAgentAttach();
}
-HordeAgent::~HordeAgent()
+// Arms one asynchronous read on the agent channel with a timeout argument of
+// 5000 (presumably milliseconds — TODO confirm) and forwards the result to
+// OnAgentResponse(). The lambda holds Self so the agent outlives the read.
+void
+AsyncHordeAgent::DoWaitAgentAttach()
{
- CloseConnection();
+ auto Self = shared_from_this();
+ m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnAgentResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::BeginCommunication()
+// Handles the first message read from the agent channel. Cancellation, or any
+// message other than Attach (including None, which is logged as a timeout),
+// fails the agent. On Attach the state moves to SentFork and the Fork request
+// is sent. The payload arguments are unused.
+void
+AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
{
- ZEN_TRACE_CPU("HordeAgent::BeginCommunication");
-
- if (!m_IsValid)
+ if (m_Cancelled)
{
- return false;
+ Finish(false);
+ return;
}
- // Start the send/recv pump threads
- m_Socket->StartCommunication();
-
- // Wait for Attach on agent channel
- AgentMessageType Type = m_AgentChannel->ReadResponse(5000);
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on agent channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- // Fork tells the remote agent to create child channel 100 with a 4MB buffer.
- // After this, the agent will send an Attach on the child channel.
+ m_State = State::SentFork;
+ DoSendFork();
+}
+
+// Sends Fork on the agent channel for channel id 100 with a size argument of
+// 4 * 1024 * 1024 (4 MiB), then switches to WaitChildAttach and arms the read
+// for the child channel's Attach.
+void
+AsyncHordeAgent::DoSendFork()
+{
m_AgentChannel->Fork(100, 4 * 1024 * 1024);
- // Wait for Attach on child channel
- Type = m_ChildChannel->ReadResponse(5000);
+ m_State = State::WaitChildAttach;
+ DoWaitChildAttach();
+}
+
+// Arms one asynchronous read on the child channel (timeout argument 5000) and
+// forwards whatever arrives to OnChildAttachResponse().
+void
+AsyncHordeAgent::DoWaitChildAttach()
+{
+ // The handler owns a shared_ptr to this agent so the agent outlives the read.
+ auto KeepAlive = shared_from_this();
+ auto OnResponse = [this, KeepAlive](AgentMessageType MsgType, const uint8_t* Payload, size_t PayloadSize) {
+ OnChildAttachResponse(MsgType, Payload, PayloadSize);
+ };
+
+ m_ChildChannel->AsyncReadResponse(5000, std::move(OnResponse));
+}
+
+// Handles the first message read from the child channel. Cancellation, or any
+// message other than Attach (including None, which is logged as a timeout),
+// fails the agent. On Attach the state moves to Uploading and the bundle
+// upload starts from index 0. The payload arguments are unused.
+void
+AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
if (Type == AgentMessageType::None)
{
ZEN_WARN("timed out waiting for Attach on child channel");
- return false;
+ Finish(false);
+ return;
}
+
if (Type != AgentMessageType::Attach)
{
ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type));
- return false;
+ Finish(false);
+ return;
}
- return true;
+ m_State = State::Uploading;
+ m_CurrentBundleIndex = 0;
+ DoUploadNext();
}
-bool
-HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator)
+// Starts the upload of the bundle at m_CurrentBundleIndex, or moves on to
+// execution once every entry of m_Config.Bundles has been handled. Each bundle
+// entry is unpacked as (Locator, BundleDir); only the locator is used here,
+// while OnUploadResponse() uses the directory to resolve blob files.
+void
+AsyncHordeAgent::DoUploadNext()
{
- ZEN_TRACE_CPU("HordeAgent::UploadBinaries");
-
- m_ChildChannel->UploadFiles("", BundleLocator.c_str());
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles;
+ if (m_CurrentBundleIndex >= m_Config.Bundles.size())
+ {
+ // All bundles uploaded — proceed to execute
+ m_State = State::Executing;
+ DoExecute();
+ return;
+ }
- auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* {
- std::string Key(Locator);
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ m_ChildChannel->UploadFiles("", Locator.c_str());
- if (auto It = BlobFiles.find(Key); It != BlobFiles.end())
- {
- return It->second.get();
- }
+ // Enter the ReadBlob/Blob upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- const std::filesystem::path Path = BundleDir / (Key + ".blob");
- std::error_code Ec;
- auto File = std::make_unique<BasicFile>();
- File->Open(Path, BasicFile::Mode::kRead, Ec);
+// Handles one message of the upload exchange for the current bundle:
+//  - None: fail if the child channel is detached, otherwise re-arm the read.
+//  - WriteFilesResponse: the bundle is complete; advance to the next one.
+//  - Exception, or any other non-ReadBlob type: log and fail.
+//  - ReadBlob: read the requested range from "<BundleDir>/<Locator>.blob",
+//    answer with a Blob message and re-arm the read.
+void
+AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
+{
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- if (Ec)
+ if (Type == AgentMessageType::None)
+ {
+ if (m_ChildChannel->IsDetached())
{
- ZEN_ERROR("cannot read blob file: '{}'", Path);
- return nullptr;
+ ZEN_WARN("connection lost during upload");
+ Finish(false);
+ return;
}
+ // Timeout — retry read
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+ return;
+ }
- BasicFile* Ptr = File.get();
- BlobFiles.emplace(std::move(Key), std::move(File));
- return Ptr;
- };
+ if (Type == AgentMessageType::WriteFilesResponse)
+ {
+ // This bundle upload is done — move to next
+ ++m_CurrentBundleIndex;
+ DoUploadNext();
+ return;
+ }
- // The upload protocol is request-driven: we send WriteFiles, then the remote agent
- // sends ReadBlob requests for each blob it needs. We respond with Blob data until
- // the agent sends WriteFilesResponse indicating the upload is complete.
- constexpr int32_t ReadResponseTimeoutMs = 1000;
+ if (Type == AgentMessageType::Exception)
+ {
+ ExceptionInfo Ex;
+ AsyncAgentMessageChannel::ReadException(Data, Size, Ex);
+ ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
+ Finish(false);
+ return;
+ }
- for (;;)
+ if (Type != AgentMessageType::ReadBlob)
{
- bool TimedOut = false;
+ ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
+ Finish(false);
+ return;
+ }
- if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob)
- {
- if (TimedOut)
- {
- continue;
- }
- // End of stream - check if it was a successful upload
- if (Type == AgentMessageType::WriteFilesResponse)
- {
- return true;
- }
- else if (Type == AgentMessageType::Exception)
- {
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description);
- }
- else
- {
- ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type));
- }
- return false;
- }
+ // Handle ReadBlob request
+ BlobRequest Req;
+ AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req);
- BlobRequest Req;
- m_ChildChannel->ReadBlobRequest(Req);
+ // NOTE(review): Req.Locator comes from the remote request and is joined onto BundleDir unchecked — confirm it cannot escape the bundle directory.
+ const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex];
+ const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob");
- BasicFile* File = FindOrOpenBlob(Req.Locator);
- if (!File)
- {
- return false;
- }
+ // The blob file is opened afresh for every ReadBlob request.
+ std::error_code FsEc;
+ BasicFile File;
+ File.Open(BlobPath, BasicFile::Mode::kRead, FsEc);
- // Read from offset to end of file
- const uint64_t TotalSize = File->FileSize();
- const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
- if (Offset >= TotalSize)
- {
- ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
- m_ChildChannel->Blob(nullptr, 0);
- continue;
- }
+ if (FsEc)
+ {
+ ZEN_ERROR("cannot read blob file: '{}'", BlobPath);
+ Finish(false);
+ return;
+ }
+
+ const uint64_t TotalSize = File.FileSize();
+ const uint64_t Offset = static_cast<uint64_t>(Req.Offset);
+ if (Offset >= TotalSize)
+ {
+ ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize);
+ // An empty Blob is still sent so the exchange continues after the error is logged.
+ m_ChildChannel->Blob(nullptr, 0);
+ }
+ else
+ {
+ // The length is clamped so the read never runs past the end of the file.
+ const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
+ m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize());
+ }
+
+ // Continue the upload loop
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnUploadResponse(Type, Data, Size);
+ });
+}
- const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset));
- m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize());
+// Sends the Execute request for the configured executable and arguments on the
+// child channel, logs where the process was started, then switches to Polling.
+void
+AsyncHordeAgent::DoExecute()
+{
+ ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute");
+
+ const auto& Machine = m_Config.Machine;
+ const auto Flags = m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None;
+
+ // Borrowed C-string pointers; m_Config.Args keeps ownership of the storage.
+ std::vector<const char*> Argv;
+ Argv.reserve(m_Config.Args.size());
+ for (const std::string& Value : m_Config.Args)
+ {
+ Argv.push_back(Value.c_str());
}
+
+ // The three arguments after the argument list are passed as nullptr, nullptr, 0.
+ m_ChildChannel->Execute(m_Config.Executable.c_str(), Argv.data(), Argv.size(), nullptr, nullptr, 0, Flags);
+
+ ZEN_INFO("remote execution started on [{}:{}] lease={}", Machine.GetConnectionAddress(), Machine.GetConnectionPort(), Machine.LeaseId);
+
+ m_State = State::Polling;
+ DoPoll();
}
+// Arms one read on the child channel with a timeout argument of 100
+// (presumably milliseconds — TODO confirm) unless the agent was cancelled, in
+// which case it finishes as failed. The result goes to OnPollResponse().
void
-HordeAgent::Execute(const char* Exe,
- const char* const* Args,
- size_t NumArgs,
- const char* WorkingDir,
- const char* const* EnvVars,
- size_t NumEnvVars,
- bool UseWine)
+AsyncHordeAgent::DoPoll()
{
- ZEN_TRACE_CPU("HordeAgent::Execute");
- m_ChildChannel
- ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None);
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
+
+ auto Self = shared_from_this();
+ m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) {
+ OnPollResponse(Type, Data, Size);
+ });
}
-bool
-HordeAgent::Poll(bool LogOutput)
+// Handles one message received while the remote process runs:
+//  - None: fail if the child channel is detached, otherwise poll again.
+//  - ExecuteOutput and unrecognised types: ignored; poll again.
+//  - ExecuteResult: decode the int32 exit code (stays -1 if the payload size
+//    is not sizeof(int32_t)) and finish; success means exit code 0.
+//  - Exception: log it and finish as failed.
+void
+AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size)
{
- constexpr int32_t ReadResponseTimeoutMs = 100;
- AgentMessageType Type;
+ if (m_Cancelled)
+ {
+ Finish(false);
+ return;
+ }
- while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None)
+ switch (Type)
{
- switch (Type)
- {
- case AgentMessageType::ExecuteOutput:
- {
- if (LogOutput && m_ChildChannel->GetResponseSize() > 0)
- {
- const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData());
- size_t ResponseSize = m_ChildChannel->GetResponseSize();
-
- // Trim trailing newlines
- while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r'))
- {
- --ResponseSize;
- }
-
- if (ResponseSize > 0)
- {
- const std::string_view Output(ResponseData, ResponseSize);
- ZEN_INFO("[remote] {}", Output);
- }
- }
- break;
- }
+ case AgentMessageType::None:
+ if (m_ChildChannel->IsDetached())
+ {
+ ZEN_WARN("connection lost during execution");
+ Finish(false);
+ }
+ else
+ {
+ // Timeout — poll again
+ DoPoll();
+ }
+ break;
- case AgentMessageType::ExecuteResult:
- {
- if (m_ChildChannel->GetResponseSize() == sizeof(int32_t))
- {
- int32_t ExitCode;
- memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t));
- ZEN_INFO("remote process exited with code {}", ExitCode);
- }
- m_IsValid = false;
- return false;
- }
+ case AgentMessageType::ExecuteOutput:
+ // Silently consume remote stdout (matching LogOutput=false in provisioner)
+ DoPoll();
+ break;
- case AgentMessageType::Exception:
+ case AgentMessageType::ExecuteResult:
+ {
+ int32_t ExitCode = -1;
+ if (Size == sizeof(int32_t))
{
- ExceptionInfo Ex;
- m_ChildChannel->ReadException(Ex);
- ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
- m_HasErrors = true;
- break;
+ memcpy(&ExitCode, Data, sizeof(int32_t));
}
+ ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId);
+ Finish(ExitCode == 0, ExitCode);
+ }
+ break;
- default:
- break;
- }
- }
+ case AgentMessageType::Exception:
+ {
+ ExceptionInfo Ex;
+ AsyncAgentMessageChannel::ReadException(Data, Size, Ex);
+ ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description);
+ Finish(false);
+ }
+ break;
- return m_IsValid && !m_HasErrors;
+ default:
+ DoPoll();
+ break;
+ }
}
+// Terminal transition of the state machine. The first call logs a warning on
+// failure (naming the state that was active), marks the agent Done, closes the
+// socket if one exists and invokes the completion handler once with the
+// result. Later calls return immediately.
+// Several callers pass only Success, so ExitCode presumably has a default in
+// the declaration — TODO confirm against the header.
void
-HordeAgent::CloseConnection()
+AsyncHordeAgent::Finish(bool Success, int32_t ExitCode)
{
- if (m_ChildChannel)
+ if (m_State == State::Done)
{
- m_ChildChannel->Close();
+ return; // Already finished
}
- if (m_AgentChannel)
+
+ if (!Success)
{
- m_AgentChannel->Close();
+ ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId);
}
-}
-bool
-HordeAgent::IsValid() const
-{
- return m_IsValid && !m_HasErrors;
+ m_State = State::Done;
+
+ if (m_Socket)
+ {
+ m_Socket->Close();
+ }
+
+ if (m_OnDone)
+ {
+ AsyncAgentResult Result;
+ Result.Success = Success;
+ Result.ExitCode = ExitCode;
+ Result.CoreCount = m_Config.Machine.LogicalCores;
+
+ // Move the handler out before calling it so m_OnDone is already cleared while it runs.
+ auto Handler = std::move(m_OnDone);
+ m_OnDone = nullptr;
+ Handler(Result);
+ }
}
} // namespace zen::horde