diff options
Diffstat (limited to 'src/zenhorde/hordeagent.cpp')
| -rw-r--r-- | src/zenhorde/hordeagent.cpp | 551 |
1 files changed, 359 insertions, 192 deletions
diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp index 819b2d0cb..275f5bd4c 100644 --- a/src/zenhorde/hordeagent.cpp +++ b/src/zenhorde/hordeagent.cpp @@ -8,290 +8,457 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <cstring> -#include <unordered_map> namespace zen::horde { -HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) -{ - ZEN_TRACE_CPU("HordeAgent::Connect"); +// --- AsyncHordeAgent --- - auto Transport = std::make_unique<TcpComputeTransport>(Info); - if (!Transport->IsValid()) +static const char* +GetStateName(AsyncHordeAgent::State S) +{ + switch (S) { - ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); - return; + case AsyncHordeAgent::State::Idle: + return "idle"; + case AsyncHordeAgent::State::Connecting: + return "connect"; + case AsyncHordeAgent::State::WaitAgentAttach: + return "agent-attach"; + case AsyncHordeAgent::State::SentFork: + return "fork"; + case AsyncHordeAgent::State::WaitChildAttach: + return "child-attach"; + case AsyncHordeAgent::State::Uploading: + return "upload"; + case AsyncHordeAgent::State::Executing: + return "execute"; + case AsyncHordeAgent::State::Polling: + return "poll"; + case AsyncHordeAgent::State::Done: + return "done"; + default: + return "unknown"; } +} - // The 64-byte nonce is always sent unencrypted as the first thing on the wire. - // The Horde agent uses this to identify which lease this connection belongs to. - Transport->Send(Info.Nonce, sizeof(Info.Nonce)); +AsyncHordeAgent::AsyncHordeAgent(asio::io_context& IoContext) : m_IoContext(IoContext), m_Log(zen::logging::Get("horde.agent.async")) +{ +} - std::unique_ptr<ComputeTransport> FinalTransport = std::move(Transport); - if (Info.EncryptionMode == Encryption::AES) +AsyncHordeAgent::~AsyncHordeAgent() +{ + Cancel(); +} + +void +AsyncHordeAgent::Start(AsyncAgentConfig Config, AsyncAgentCompletionHandler OnDone) +{ + m_Config = std::move(Config); + m_OnDone = std::move(OnDone); + m_State = State::Connecting; + DoConnect(); +} + +void +AsyncHordeAgent::Cancel() +{ + m_Cancelled = true; + if (m_Socket) { - FinalTransport = std::make_unique<AesComputeTransport>(Info.Key, std::move(FinalTransport)); - if (!FinalTransport->IsValid()) - { - ZEN_WARN("failed to create AES transport"); - return; - } + m_Socket->Close(); + } + else if (m_Transport) + { + m_Transport->Close(); } +} - // Create multiplexed socket and channels - m_Socket = std::make_unique<ComputeSocket>(std::move(FinalTransport)); +void +AsyncHordeAgent::DoConnect() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoConnect"); - // Channel 0 is the agent control channel (handles Attach/Fork handshake). - // Channel 100 is the child I/O channel (handles file upload and remote execution). - Ref<ComputeChannel> AgentComputeChannel = m_Socket->CreateChannel(0); - Ref<ComputeChannel> ChildComputeChannel = m_Socket->CreateChannel(100); + m_TcpTransport = std::make_unique<AsyncTcpComputeTransport>(m_IoContext); + + auto Self = shared_from_this(); + m_TcpTransport->AsyncConnect(m_Config.Machine, [this, Self](const std::error_code& Ec) { OnConnected(Ec); }); +} - if (!AgentComputeChannel || !ChildComputeChannel) +void +AsyncHordeAgent::OnConnected(const std::error_code& Ec) +{ + if (Ec || m_Cancelled) { - ZEN_WARN("failed to create compute channels"); + if (Ec) + { + ZEN_WARN("connect failed: {}", Ec.message()); + } + Finish(false); return; } - m_AgentChannel = std::make_unique<AgentMessageChannel>(std::move(AgentComputeChannel)); - m_ChildChannel = std::make_unique<AgentMessageChannel>(std::move(ChildComputeChannel)); + // Optionally wrap with AES encryption + std::unique_ptr<AsyncComputeTransport> FinalTransport = std::move(m_TcpTransport); + if (m_Config.Machine.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique<AsyncAesComputeTransport>(m_Config.Machine.Key, std::move(FinalTransport), m_IoContext); + } + m_Transport = std::move(FinalTransport); + + // Create the multiplexed socket and register channels + m_Socket = std::make_shared<AsyncComputeSocket>(std::move(m_Transport), m_IoContext); + + m_AgentChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 0, m_IoContext); + m_ChildChannel = std::make_unique<AsyncAgentMessageChannel>(m_Socket, 100, m_IoContext); + + m_Socket->RegisterChannel( + 0, + [this](std::vector<uint8_t> Data) { m_AgentChannel->OnFrame(std::move(Data)); }, + [this]() { m_AgentChannel->OnDetach(); }); - m_IsValid = true; + m_Socket->RegisterChannel( + 100, + [this](std::vector<uint8_t> Data) { m_ChildChannel->OnFrame(std::move(Data)); }, + [this]() { m_ChildChannel->OnDetach(); }); + + m_Socket->StartRecvPump(); + + m_State = State::WaitAgentAttach; + DoWaitAgentAttach(); } -HordeAgent::~HordeAgent() +void +AsyncHordeAgent::DoWaitAgentAttach() { - CloseConnection(); + auto Self = shared_from_this(); + m_AgentChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnAgentResponse(Type, Data, Size); + }); } -bool -HordeAgent::BeginCommunication() +void +AsyncHordeAgent::OnAgentResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) { - ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); - - if (!m_IsValid) + if (m_Cancelled) { - return false; + Finish(false); + return; } - // Start the send/recv pump threads - m_Socket->StartCommunication(); - - // Wait for Attach on agent channel - AgentMessageType Type = m_AgentChannel->ReadResponse(5000); if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on agent channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - // Fork tells the remote agent to create child channel 100 with a 4MB buffer. - // After this, the agent will send an Attach on the child channel. + m_State = State::SentFork; + DoSendFork(); +} + +void +AsyncHordeAgent::DoSendFork() +{ m_AgentChannel->Fork(100, 4 * 1024 * 1024); - // Wait for Attach on child channel - Type = m_ChildChannel->ReadResponse(5000); + m_State = State::WaitChildAttach; + DoWaitChildAttach(); +} + +void +AsyncHordeAgent::DoWaitChildAttach() +{ + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(5000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnChildAttachResponse(Type, Data, Size); + }); +} + +void +AsyncHordeAgent::OnChildAttachResponse(AgentMessageType Type, const uint8_t* /*Data*/, size_t /*Size*/) +{ + if (m_Cancelled) + { + Finish(false); + return; + } + if (Type == AgentMessageType::None) { ZEN_WARN("timed out waiting for Attach on child channel"); - return false; + Finish(false); + return; } + if (Type != AgentMessageType::Attach) { ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast<int>(Type)); - return false; + Finish(false); + return; } - return true; + m_State = State::Uploading; + m_CurrentBundleIndex = 0; + DoUploadNext(); } -bool -HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +void +AsyncHordeAgent::DoUploadNext() { - ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); - - m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + if (m_Cancelled) + { + Finish(false); + return; + } - std::unordered_map<std::string, std::unique_ptr<BasicFile>> BlobFiles; + if (m_CurrentBundleIndex >= m_Config.Bundles.size()) + { + // All bundles uploaded — proceed to execute + m_State = State::Executing; + DoExecute(); + return; + } - auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { - std::string Key(Locator); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + m_ChildChannel->UploadFiles("", Locator.c_str()); - if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) - { - return It->second.get(); - } + // Enter the ReadBlob/Blob upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const std::filesystem::path Path = BundleDir / (Key + ".blob"); - std::error_code Ec; - auto File = std::make_unique<BasicFile>(); - File->Open(Path, BasicFile::Mode::kRead, Ec); +void +AsyncHordeAgent::OnUploadResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) +{ + if (m_Cancelled) + { + Finish(false); + return; + } - if (Ec) + if (Type == AgentMessageType::None) + { + if (m_ChildChannel->IsDetached()) { - ZEN_ERROR("cannot read blob file: '{}'", Path); - return nullptr; + ZEN_WARN("connection lost during upload"); + Finish(false); + return; } + // Timeout — retry read + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); + return; + } - BasicFile* Ptr = File.get(); - BlobFiles.emplace(std::move(Key), std::move(File)); - return Ptr; - }; + if (Type == AgentMessageType::WriteFilesResponse) + { + // This bundle upload is done — move to next + ++m_CurrentBundleIndex; + DoUploadNext(); + return; + } - // The upload protocol is request-driven: we send WriteFiles, then the remote agent - // sends ReadBlob requests for each blob it needs. We respond with Blob data until - // the agent sends WriteFilesResponse indicating the upload is complete. - constexpr int32_t ReadResponseTimeoutMs = 1000; + if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + AsyncAgentMessageChannel::ReadException(Data, Size, Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + return; + } - for (;;) + if (Type != AgentMessageType::ReadBlob) { - bool TimedOut = false; + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); + Finish(false); + return; + } - if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) - { - if (TimedOut) - { - continue; - } - // End of stream - check if it was a successful upload - if (Type == AgentMessageType::WriteFilesResponse) - { - return true; - } - else if (Type == AgentMessageType::Exception) - { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); - } - else - { - ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast<int>(Type)); - } - return false; - } + // Handle ReadBlob request + BlobRequest Req; + AsyncAgentMessageChannel::ReadBlobRequest(Data, Size, Req); - BlobRequest Req; - m_ChildChannel->ReadBlobRequest(Req); + const auto& [Locator, BundleDir] = m_Config.Bundles[m_CurrentBundleIndex]; + const std::filesystem::path BlobPath = BundleDir / (std::string(Req.Locator) + ".blob"); - BasicFile* File = FindOrOpenBlob(Req.Locator); - if (!File) - { - return false; - } + std::error_code FsEc; + BasicFile File; + File.Open(BlobPath, BasicFile::Mode::kRead, FsEc); - // Read from offset to end of file - const uint64_t TotalSize = File->FileSize(); - const uint64_t Offset = static_cast<uint64_t>(Req.Offset); - if (Offset >= TotalSize) - { - ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); - m_ChildChannel->Blob(nullptr, 0); - continue; - } + if (FsEc) + { + ZEN_ERROR("cannot read blob file: '{}'", BlobPath); + Finish(false); + return; + } + + const uint64_t TotalSize = File.FileSize(); + const uint64_t Offset = static_cast<uint64_t>(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("blob request beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + } + else + { + const IoBuffer FileData = File.ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast<const uint8_t*>(FileData.GetData()), FileData.GetSize()); + } + + // Continue the upload loop + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(1000, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnUploadResponse(Type, Data, Size); + }); +} - const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); - m_ChildChannel->Blob(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize()); +void +AsyncHordeAgent::DoExecute() +{ + ZEN_TRACE_CPU("AsyncHordeAgent::DoExecute"); + + std::vector<const char*> ArgPtrs; + ArgPtrs.reserve(m_Config.Args.size()); + for (const std::string& Arg : m_Config.Args) + { + ArgPtrs.push_back(Arg.c_str()); } + + m_ChildChannel->Execute(m_Config.Executable.c_str(), + ArgPtrs.data(), + ArgPtrs.size(), + nullptr, + nullptr, + 0, + m_Config.UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + m_Config.Machine.GetConnectionAddress(), + m_Config.Machine.GetConnectionPort(), + m_Config.Machine.LeaseId); + + m_State = State::Polling; + DoPoll(); } void -HordeAgent::Execute(const char* Exe, - const char* const* Args, - size_t NumArgs, - const char* WorkingDir, - const char* const* EnvVars, - size_t NumEnvVars, - bool UseWine) +AsyncHordeAgent::DoPoll() { - ZEN_TRACE_CPU("HordeAgent::Execute"); - m_ChildChannel - ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); + if (m_Cancelled) + { + Finish(false); + return; + } + + auto Self = shared_from_this(); + m_ChildChannel->AsyncReadResponse(100, [this, Self](AgentMessageType Type, const uint8_t* Data, size_t Size) { + OnPollResponse(Type, Data, Size); + }); } -bool -HordeAgent::Poll(bool LogOutput) +void +AsyncHordeAgent::OnPollResponse(AgentMessageType Type, const uint8_t* Data, size_t Size) { - constexpr int32_t ReadResponseTimeoutMs = 100; - AgentMessageType Type; + if (m_Cancelled) + { + Finish(false); + return; + } - while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + switch (Type) { - switch (Type) - { - case AgentMessageType::ExecuteOutput: - { - if (LogOutput && m_ChildChannel->GetResponseSize() > 0) - { - const char* ResponseData = static_cast<const char*>(m_ChildChannel->GetResponseData()); - size_t ResponseSize = m_ChildChannel->GetResponseSize(); - - // Trim trailing newlines - while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) - { - --ResponseSize; - } - - if (ResponseSize > 0) - { - const std::string_view Output(ResponseData, ResponseSize); - ZEN_INFO("[remote] {}", Output); - } - } - break; - } + case AgentMessageType::None: + if (m_ChildChannel->IsDetached()) + { + ZEN_WARN("connection lost during execution"); + Finish(false); + } + else + { + // Timeout — poll again + DoPoll(); + } + break; - case AgentMessageType::ExecuteResult: - { - if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) - { - int32_t ExitCode; - memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); - ZEN_INFO("remote process exited with code {}", ExitCode); - } - m_IsValid = false; - return false; - } + case AgentMessageType::ExecuteOutput: + // Silently consume remote stdout (matching LogOutput=false in provisioner) + DoPoll(); + break; - case AgentMessageType::Exception: + case AgentMessageType::ExecuteResult: + { + int32_t ExitCode = -1; + if (Size == sizeof(int32_t)) { - ExceptionInfo Ex; - m_ChildChannel->ReadException(Ex); - ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); - m_HasErrors = true; - break; + memcpy(&ExitCode, Data, sizeof(int32_t)); } + ZEN_INFO("remote process exited with code {} (lease={})", ExitCode, m_Config.Machine.LeaseId); + Finish(ExitCode == 0, ExitCode); + } + break; - default: - break; - } - } + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + AsyncAgentMessageChannel::ReadException(Data, Size, Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + Finish(false); + } + break; - return m_IsValid && !m_HasErrors; + default: + DoPoll(); + break; + } } void -HordeAgent::CloseConnection() +AsyncHordeAgent::Finish(bool Success, int32_t ExitCode) { - if (m_ChildChannel) + if (m_State == State::Done) { - m_ChildChannel->Close(); + return; // Already finished } - if (m_AgentChannel) + + if (!Success) { - m_AgentChannel->Close(); + ZEN_WARN("agent failed during {} (lease={})", GetStateName(m_State), m_Config.Machine.LeaseId); } -} -bool -HordeAgent::IsValid() const -{ - return m_IsValid && !m_HasErrors; + m_State = State::Done; + + if (m_Socket) + { + m_Socket->Close(); + } + + if (m_OnDone) + { + AsyncAgentResult Result; + Result.Success = Success; + Result.ExitCode = ExitCode; + Result.CoreCount = m_Config.Machine.LogicalCores; + + auto Handler = std::move(m_OnDone); + m_OnDone = nullptr; + Handler(Result); + } } } // namespace zen::horde |