From 04640f9a2ddf11b98c7f2bd3413e7e1fb055feda Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Thu, 2 Apr 2026 11:13:25 +0200 Subject: Fix relay mode: case-insensitive mode parsing, AES send/recv deadlock, and endpoint routing - Make FromString for ConnectionMode and Encryption case-insensitive so PascalCase values from the Horde API (e.g. "Relay") are recognized. - Split AesComputeTransport's single mutex into separate send/recv mutexes to prevent deadlock where the recv thread blocks on TCP while holding the lock, starving the send thread from sending Fork. - Add MachineInfo::GetZenServiceEndpoint() to resolve the relay-mapped address and port for the Zen service, used by the provisioner for both its own health-check endpoint and the remote zenserver's announce URL. - Add --announce-url CLI option so the provisioner can tell the remote zenserver which externally-visible URL to announce to the orchestrator (instead of its unreachable private IP in relay mode). - Log connection mode in machine-assigned message for diagnostics. --- src/zenhorde/hordeclient.cpp | 3 ++- src/zenhorde/hordeconfig.cpp | 13 ++++++++----- src/zenhorde/hordeprovisioner.cpp | 24 ++++++++++++++++++++---- src/zenhorde/hordetransportaes.cpp | 20 ++++++++++---------- src/zenhorde/hordetransportaes.h | 6 ++++-- src/zenhorde/include/zenhorde/hordeclient.h | 14 ++++++++++++++ src/zenserver/compute/computeserver.cpp | 12 ++++++++++++ src/zenserver/compute/computeserver.h | 2 ++ 8 files changed, 72 insertions(+), 22 deletions(-) (limited to 'src') diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp index c57332082..00edc8f1c 100644 --- a/src/zenhorde/hordeclient.cpp +++ b/src/zenhorde/hordeclient.cpp @@ -348,9 +348,10 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C OutMachine.LeaseId = LeaseIdVal.string_value(); } - ZEN_INFO("Horde machine assigned [{}:{}] cores={} pool={} lease={}", + ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}", OutMachine.GetConnectionAddress(), OutMachine.GetConnectionPort(), + ToString(OutMachine.Mode), OutMachine.LogicalCores, OutMachine.Pool, OutMachine.LeaseId); diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp index d14d44c01..9f6125c64 100644 --- a/src/zenhorde/hordeconfig.cpp +++ b/src/zenhorde/hordeconfig.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include +#include #include namespace zen::horde { @@ -55,37 +56,39 @@ ToString(Encryption Enc) bool FromString(ConnectionMode& OutMode, std::string_view Str) { - if (Str == "direct") + if (StrCaseCompare(Str, "direct") == 0) { OutMode = ConnectionMode::Direct; return true; } - if (Str == "tunnel") + if (StrCaseCompare(Str, "tunnel") == 0) { OutMode = ConnectionMode::Tunnel; return true; } - if (Str == "relay") + if (StrCaseCompare(Str, "relay") == 0) { OutMode = ConnectionMode::Relay; return true; } + ZEN_WARN("unrecognized Horde connection mode: '{}'", Str); return false; } bool FromString(Encryption& OutEnc, std::string_view Str) { - if (Str == "none") + if (StrCaseCompare(Str, "none") == 0) { OutEnc = Encryption::None; return true; } - if (Str == "aes") + if (StrCaseCompare(Str, "aes") == 0) { OutEnc = Encryption::AES; return true; } + ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str); return false; } diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp index 55f5ed583..5ad9d46d0 100644 --- a/src/zenhorde/hordeprovisioner.cpp +++ b/src/zenhorde/hordeprovisioner.cpp @@ -624,6 +624,21 @@ HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) ArgStrings.emplace_back(IdArg.ToView()); } + // In relay mode, the remote zenserver's local address is not reachable from the + // orchestrator. Pass the relay-visible endpoint so it announces the correct URL. + if (Machine.Mode == ConnectionMode::Relay) + { + const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (Addr.find(':') != std::string::npos) + { + ArgStrings.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port)); + } + else + { + ArgStrings.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port)); + } + } + std::vector Args; Args.reserve(ArgStrings.size()); for (const std::string& Arg : ArgStrings) @@ -646,15 +661,16 @@ HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) Machine.GetConnectionPort(), Machine.LeaseId); - MachineCoreCount = Machine.LogicalCores; + MachineCoreCount = Machine.LogicalCores; + const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); // Bracket IPv6 addresses so the URL is valid (e.g. http://[::1]:8558) - if (Machine.Ip.find(':') != std::string::npos) + if (EndpointAddr.find(':') != std::string::npos) { - Wrapper.RemoteEndpoint = fmt::format("http://[{}]:{}", Machine.Ip, m_Config.ZenServicePort); + Wrapper.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort); } else { - Wrapper.RemoteEndpoint = fmt::format("http://{}:{}", Machine.Ip, m_Config.ZenServicePort); + Wrapper.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort); } Wrapper.CoreCount = MachineCoreCount; Wrapper.LeaseId = Machine.LeaseId; diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..dd71651af 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -286,23 +286,23 @@ AesComputeTransport::Send(const void* Data, size_t Size) return 0; } - std::lock_guard Lock(m_Lock); + std::lock_guard Lock(m_SendMutex); const int32_t DataLength = static_cast(Size); const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; - if (m_EncryptBuffer.size() < MessageLength) + if (m_SendBuffer.size() < MessageLength) { - m_EncryptBuffer.resize(MessageLength); + m_SendBuffer.resize(MessageLength); } - const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_SendBuffer.data(), Data, DataLength); if (EncryptedLen == 0) { return 0; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast(EncryptedLen))) + if (!m_Inner->SendMessage(m_SendBuffer.data(), static_cast(EncryptedLen))) { return 0; } @@ -323,7 +323,7 @@ AesComputeTransport::Recv(void* Data, size_t Size) // and returned on subsequent Recv calls without another decryption round-trip. ZEN_TRACE_CPU("AesComputeTransport::Recv"); - std::lock_guard Lock(m_Lock); + std::lock_guard Lock(m_RecvMutex); if (!m_RemainingData.empty()) { @@ -366,12 +366,12 @@ AesComputeTransport::Recv(void* Data, size_t Size) // Receive ciphertext + tag const size_t MessageLength = static_cast(Header.DataLength) + TagBytes; - if (m_EncryptBuffer.size() < MessageLength) + if (m_RecvBuffer.size() < MessageLength) { - m_EncryptBuffer.resize(MessageLength); + m_RecvBuffer.resize(MessageLength); } - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + if (!m_Inner->RecvMessage(m_RecvBuffer.data(), MessageLength)) { return 0; } @@ -382,7 +382,7 @@ AesComputeTransport::Recv(void* Data, size_t Size) // We need a temporary buffer for decryption if we can't decrypt directly into output std::vector DecryptedBuf(static_cast(Header.DataLength)); - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_RecvBuffer.data(), Header.DataLength); if (Decrypted == 0) { return 0; diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h index efcad9835..7d0b3d27c 100644 --- a/src/zenhorde/hordetransportaes.h +++ b/src/zenhorde/hordetransportaes.h @@ -42,10 +42,12 @@ private: std::unique_ptr m_Crypto; std::unique_ptr m_Inner; - std::vector m_EncryptBuffer; + std::vector m_SendBuffer; + std::vector m_RecvBuffer; std::vector m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv size_t m_RemainingOffset = 0; - std::mutex m_Lock; + std::mutex m_SendMutex; + std::mutex m_RecvMutex; bool m_IsClosed = false; }; diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h index 15b2ba9af..5c119dd19 100644 --- a/src/zenhorde/include/zenhorde/hordeclient.h +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -66,6 +66,20 @@ struct MachineInfo return Port; } + /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */ + std::pair GetZenServiceEndpoint(uint16_t DefaultPort) const + { + if (Mode == ConnectionMode::Relay) + { + auto It = Ports.find("ZenPort"); + if (It != Ports.end()) + { + return {ConnectionAddress, It->second.Port}; + } + } + return {Ip, DefaultPort}; + } + bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } }; diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 8f6913fce..20cc4cd65 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -67,6 +67,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.CoordinatorEndpoint)->default_value(""), ""); + Options.add_option("compute", + "", + "announce-url", + "Override URL announced to the coordinator (e.g. relay-visible endpoint)", + cxxopts::value(m_ServerOptions.AnnounceUrl)->default_value(""), + ""); + Options.add_option("compute", "", "idms", @@ -387,6 +394,7 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ } m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_AnnounceUrl = ServerConfig.AnnounceUrl; m_InstanceId = ServerConfig.InstanceId; m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; @@ -612,6 +620,10 @@ ZenComputeServer::GetInstanceId() const std::string ZenComputeServer::GetAnnounceUrl() const { + if (!m_AnnounceUrl.empty()) + { + return m_AnnounceUrl; + } return m_Http->GetServiceUri(nullptr); } diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 83205a2e4..9dc83edb9 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -49,6 +49,7 @@ struct ZenComputeServerConfig : public ZenServerConfig std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications std::string CoordinatorEndpoint; + std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint) std::string IdmsEndpoint; int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link @@ -147,6 +148,7 @@ private: # endif SystemMetricsTracker m_MetricsTracker; std::string m_CoordinatorEndpoint; + std::string m_AnnounceUrl; std::string m_InstanceId; asio::steady_timer m_AnnounceTimer{m_IoContext}; -- cgit v1.2.3