diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/zenhorde/hordeclient.cpp | 3 | ||||
| -rw-r--r-- | src/zenhorde/hordeconfig.cpp | 13 | ||||
| -rw-r--r-- | src/zenhorde/hordeprovisioner.cpp | 24 | ||||
| -rw-r--r-- | src/zenhorde/hordetransportaes.cpp | 20 | ||||
| -rw-r--r-- | src/zenhorde/hordetransportaes.h | 6 | ||||
| -rw-r--r-- | src/zenhorde/include/zenhorde/hordeclient.h | 14 | ||||
| -rw-r--r-- | src/zenserver/compute/computeserver.cpp | 12 | ||||
| -rw-r--r-- | src/zenserver/compute/computeserver.h | 2 |
8 files changed, 72 insertions, 22 deletions
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp index c57332082..00edc8f1c 100644 --- a/src/zenhorde/hordeclient.cpp +++ b/src/zenhorde/hordeclient.cpp @@ -348,9 +348,10 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C OutMachine.LeaseId = LeaseIdVal.string_value(); } - ZEN_INFO("Horde machine assigned [{}:{}] cores={} pool={} lease={}", + ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}", OutMachine.GetConnectionAddress(), OutMachine.GetConnectionPort(), + ToString(OutMachine.Mode), OutMachine.LogicalCores, OutMachine.Pool, OutMachine.LeaseId); diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp index d14d44c01..9f6125c64 100644 --- a/src/zenhorde/hordeconfig.cpp +++ b/src/zenhorde/hordeconfig.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include <zencore/logging.h> +#include <zencore/string.h> #include <zenhorde/hordeconfig.h> namespace zen::horde { @@ -55,37 +56,39 @@ ToString(Encryption Enc) bool FromString(ConnectionMode& OutMode, std::string_view Str) { - if (Str == "direct") + if (StrCaseCompare(Str, "direct") == 0) { OutMode = ConnectionMode::Direct; return true; } - if (Str == "tunnel") + if (StrCaseCompare(Str, "tunnel") == 0) { OutMode = ConnectionMode::Tunnel; return true; } - if (Str == "relay") + if (StrCaseCompare(Str, "relay") == 0) { OutMode = ConnectionMode::Relay; return true; } + ZEN_WARN("unrecognized Horde connection mode: '{}'", Str); return false; } bool FromString(Encryption& OutEnc, std::string_view Str) { - if (Str == "none") + if (StrCaseCompare(Str, "none") == 0) { OutEnc = Encryption::None; return true; } - if (Str == "aes") + if (StrCaseCompare(Str, "aes") == 0) { OutEnc = Encryption::AES; return true; } + ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str); return false; } diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp index 55f5ed583..5ad9d46d0 100644 --- a/src/zenhorde/hordeprovisioner.cpp +++ b/src/zenhorde/hordeprovisioner.cpp @@ -624,6 +624,21 @@ HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) ArgStrings.emplace_back(IdArg.ToView()); } + // In relay mode, the remote zenserver's local address is not reachable from the + // orchestrator. Pass the relay-visible endpoint so it announces the correct URL. + if (Machine.Mode == ConnectionMode::Relay) + { + const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); + if (Addr.find(':') != std::string::npos) + { + ArgStrings.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port)); + } + else + { + ArgStrings.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port)); + } + } + std::vector<const char*> Args; Args.reserve(ArgStrings.size()); for (const std::string& Arg : ArgStrings) @@ -646,15 +661,16 @@ HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) Machine.GetConnectionPort(), Machine.LeaseId); - MachineCoreCount = Machine.LogicalCores; + MachineCoreCount = Machine.LogicalCores; + const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort); // Bracket IPv6 addresses so the URL is valid (e.g. http://[::1]:8558) - if (Machine.Ip.find(':') != std::string::npos) + if (EndpointAddr.find(':') != std::string::npos) { - Wrapper.RemoteEndpoint = fmt::format("http://[{}]:{}", Machine.Ip, m_Config.ZenServicePort); + Wrapper.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort); } else { - Wrapper.RemoteEndpoint = fmt::format("http://{}:{}", Machine.Ip, m_Config.ZenServicePort); + Wrapper.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort); } Wrapper.CoreCount = MachineCoreCount; Wrapper.LeaseId = Machine.LeaseId; diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..dd71651af 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -286,23 +286,23 @@ AesComputeTransport::Send(const void* Data, size_t Size) return 0; } - std::lock_guard<std::mutex> Lock(m_Lock); + std::lock_guard<std::mutex> Lock(m_SendMutex); const int32_t DataLength = static_cast<int32_t>(Size); const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; - if (m_EncryptBuffer.size() < MessageLength) + if (m_SendBuffer.size() < MessageLength) { - m_EncryptBuffer.resize(MessageLength); + m_SendBuffer.resize(MessageLength); } - const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_SendBuffer.data(), Data, DataLength); if (EncryptedLen == 0) { return 0; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) + if (!m_Inner->SendMessage(m_SendBuffer.data(), static_cast<size_t>(EncryptedLen))) { return 0; } @@ -323,7 +323,7 @@ AesComputeTransport::Recv(void* Data, size_t Size) // and returned on subsequent Recv calls without another decryption round-trip. ZEN_TRACE_CPU("AesComputeTransport::Recv"); - std::lock_guard<std::mutex> Lock(m_Lock); + std::lock_guard<std::mutex> Lock(m_RecvMutex); if (!m_RemainingData.empty()) { @@ -366,12 +366,12 @@ AesComputeTransport::Recv(void* Data, size_t Size) // Receive ciphertext + tag const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; - if (m_EncryptBuffer.size() < MessageLength) + if (m_RecvBuffer.size() < MessageLength) { - m_EncryptBuffer.resize(MessageLength); + m_RecvBuffer.resize(MessageLength); } - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + if (!m_Inner->RecvMessage(m_RecvBuffer.data(), MessageLength)) { return 0; } @@ -382,7 +382,7 @@ AesComputeTransport::Recv(void* Data, size_t Size) // We need a temporary buffer for decryption if we can't decrypt directly into output std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_RecvBuffer.data(), Header.DataLength); if (Decrypted == 0) { return 0; diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h index efcad9835..7d0b3d27c 100644 --- a/src/zenhorde/hordetransportaes.h +++ b/src/zenhorde/hordetransportaes.h @@ -42,10 +42,12 @@ private: std::unique_ptr<CryptoContext> m_Crypto; std::unique_ptr<ComputeTransport> m_Inner; - std::vector<uint8_t> m_EncryptBuffer; + std::vector<uint8_t> m_SendBuffer; + std::vector<uint8_t> m_RecvBuffer; std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv size_t m_RemainingOffset = 0; - std::mutex m_Lock; + std::mutex m_SendMutex; + std::mutex m_RecvMutex; bool m_IsClosed = false; }; diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h index 15b2ba9af..5c119dd19 100644 --- a/src/zenhorde/include/zenhorde/hordeclient.h +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -66,6 +66,20 @@ struct MachineInfo return Port; } + /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */ + std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const + { + if (Mode == ConnectionMode::Relay) + { + auto It = Ports.find("ZenPort"); + if (It != Ports.end()) + { + return {ConnectionAddress, It->second.Port}; + } + } + return {Ip, DefaultPort}; + } + bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } }; diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 8f6913fce..20cc4cd65 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -69,6 +69,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) Options.add_option("compute", "", + "announce-url", + "Override URL announced to the coordinator (e.g. relay-visible endpoint)", + cxxopts::value<std::string>(m_ServerOptions.AnnounceUrl)->default_value(""), + ""); + + Options.add_option("compute", + "", "idms", "Enable IDMS cloud detection; optionally specify a custom probe endpoint", cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"), @@ -387,6 +394,7 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ } m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_AnnounceUrl = ServerConfig.AnnounceUrl; m_InstanceId = ServerConfig.InstanceId; m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; @@ -612,6 +620,10 @@ ZenComputeServer::GetInstanceId() const std::string ZenComputeServer::GetAnnounceUrl() const { + if (!m_AnnounceUrl.empty()) + { + return m_AnnounceUrl; + } return m_Http->GetServiceUri(nullptr); } diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 83205a2e4..9dc83edb9 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -49,6 +49,7 @@ struct ZenComputeServerConfig : public ZenServerConfig std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications std::string CoordinatorEndpoint; + std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint) std::string IdmsEndpoint; int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link @@ -147,6 +148,7 @@ private: # endif SystemMetricsTracker m_MetricsTracker; std::string m_CoordinatorEndpoint; + std::string m_AnnounceUrl; std::string m_InstanceId; asio::steady_timer m_AnnounceTimer{m_IoContext}; |