aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-04-02 11:13:25 +0200
committerStefan Boberg <[email protected]>2026-04-02 11:13:25 +0200
commit04640f9a2ddf11b98c7f2bd3413e7e1fb055feda (patch)
tree439c99d87f50fc4971de2bf5cdf7bbb376ed345a /src
parentImprove OidcToken auth diagnostics and use --HordeUrl for Horde servers (diff)
downloadzen-sb/compute-oidc-auth.tar.xz
zen-sb/compute-oidc-auth.zip
Fix relay mode: case-insensitive mode parsing, AES send/recv deadlock, and endpoint routingsb/compute-oidc-auth
- Make FromString for ConnectionMode and Encryption case-insensitive so PascalCase values from the Horde API (e.g. "Relay") are recognized. - Split AesComputeTransport's single mutex into separate send/recv mutexes to prevent deadlock where the recv thread blocks on TCP while holding the lock, starving the send thread from sending Fork. - Add MachineInfo::GetZenServiceEndpoint() to resolve the relay-mapped address and port for the Zen service, used by the provisioner for both its own health-check endpoint and the remote zenserver's announce URL. - Add --announce-url CLI option so the provisioner can tell the remote zenserver which externally-visible URL to announce to the orchestrator (instead of its unreachable private IP in relay mode). - Log connection mode in machine-assigned message for diagnostics.
Diffstat (limited to 'src')
-rw-r--r--src/zenhorde/hordeclient.cpp3
-rw-r--r--src/zenhorde/hordeconfig.cpp13
-rw-r--r--src/zenhorde/hordeprovisioner.cpp24
-rw-r--r--src/zenhorde/hordetransportaes.cpp20
-rw-r--r--src/zenhorde/hordetransportaes.h6
-rw-r--r--src/zenhorde/include/zenhorde/hordeclient.h14
-rw-r--r--src/zenserver/compute/computeserver.cpp12
-rw-r--r--src/zenserver/compute/computeserver.h2
8 files changed, 72 insertions, 22 deletions
diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp
index c57332082..00edc8f1c 100644
--- a/src/zenhorde/hordeclient.cpp
+++ b/src/zenhorde/hordeclient.cpp
@@ -348,9 +348,10 @@ HordeClient::RequestMachine(const std::string& RequestBody, const std::string& C
OutMachine.LeaseId = LeaseIdVal.string_value();
}
- ZEN_INFO("Horde machine assigned [{}:{}] cores={} pool={} lease={}",
+ ZEN_INFO("Horde machine assigned [{}:{}] mode={} cores={} pool={} lease={}",
OutMachine.GetConnectionAddress(),
OutMachine.GetConnectionPort(),
+ ToString(OutMachine.Mode),
OutMachine.LogicalCores,
OutMachine.Pool,
OutMachine.LeaseId);
diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp
index d14d44c01..9f6125c64 100644
--- a/src/zenhorde/hordeconfig.cpp
+++ b/src/zenhorde/hordeconfig.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include <zencore/logging.h>
+#include <zencore/string.h>
#include <zenhorde/hordeconfig.h>
namespace zen::horde {
@@ -55,37 +56,39 @@ ToString(Encryption Enc)
bool
FromString(ConnectionMode& OutMode, std::string_view Str)
{
- if (Str == "direct")
+ if (StrCaseCompare(Str, "direct") == 0)
{
OutMode = ConnectionMode::Direct;
return true;
}
- if (Str == "tunnel")
+ if (StrCaseCompare(Str, "tunnel") == 0)
{
OutMode = ConnectionMode::Tunnel;
return true;
}
- if (Str == "relay")
+ if (StrCaseCompare(Str, "relay") == 0)
{
OutMode = ConnectionMode::Relay;
return true;
}
+ ZEN_WARN("unrecognized Horde connection mode: '{}'", Str);
return false;
}
bool
FromString(Encryption& OutEnc, std::string_view Str)
{
- if (Str == "none")
+ if (StrCaseCompare(Str, "none") == 0)
{
OutEnc = Encryption::None;
return true;
}
- if (Str == "aes")
+ if (StrCaseCompare(Str, "aes") == 0)
{
OutEnc = Encryption::AES;
return true;
}
+ ZEN_WARN("unrecognized Horde encryption mode: '{}'", Str);
return false;
}
diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp
index 55f5ed583..5ad9d46d0 100644
--- a/src/zenhorde/hordeprovisioner.cpp
+++ b/src/zenhorde/hordeprovisioner.cpp
@@ -624,6 +624,21 @@ HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
ArgStrings.emplace_back(IdArg.ToView());
}
+ // In relay mode, the remote zenserver's local address is not reachable from the
+ // orchestrator. Pass the relay-visible endpoint so it announces the correct URL.
+ if (Machine.Mode == ConnectionMode::Relay)
+ {
+ const auto [Addr, Port] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
+ if (Addr.find(':') != std::string::npos)
+ {
+ ArgStrings.push_back(fmt::format("--announce-url=http://[{}]:{}", Addr, Port));
+ }
+ else
+ {
+ ArgStrings.push_back(fmt::format("--announce-url=http://{}:{}", Addr, Port));
+ }
+ }
+
std::vector<const char*> Args;
Args.reserve(ArgStrings.size());
for (const std::string& Arg : ArgStrings)
@@ -646,15 +661,16 @@ HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper)
Machine.GetConnectionPort(),
Machine.LeaseId);
- MachineCoreCount = Machine.LogicalCores;
+ MachineCoreCount = Machine.LogicalCores;
+ const auto [EndpointAddr, EndpointPort] = Machine.GetZenServiceEndpoint(m_Config.ZenServicePort);
// Bracket IPv6 addresses so the URL is valid (e.g. http://[::1]:8558)
- if (Machine.Ip.find(':') != std::string::npos)
+ if (EndpointAddr.find(':') != std::string::npos)
{
- Wrapper.RemoteEndpoint = fmt::format("http://[{}]:{}", Machine.Ip, m_Config.ZenServicePort);
+ Wrapper.RemoteEndpoint = fmt::format("http://[{}]:{}", EndpointAddr, EndpointPort);
}
else
{
- Wrapper.RemoteEndpoint = fmt::format("http://{}:{}", Machine.Ip, m_Config.ZenServicePort);
+ Wrapper.RemoteEndpoint = fmt::format("http://{}:{}", EndpointAddr, EndpointPort);
}
Wrapper.CoreCount = MachineCoreCount;
Wrapper.LeaseId = Machine.LeaseId;
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
index 505b6bde7..dd71651af 100644
--- a/src/zenhorde/hordetransportaes.cpp
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -286,23 +286,23 @@ AesComputeTransport::Send(const void* Data, size_t Size)
return 0;
}
- std::lock_guard<std::mutex> Lock(m_Lock);
+ std::lock_guard<std::mutex> Lock(m_SendMutex);
const int32_t DataLength = static_cast<int32_t>(Size);
const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
- if (m_EncryptBuffer.size() < MessageLength)
+ if (m_SendBuffer.size() < MessageLength)
{
- m_EncryptBuffer.resize(MessageLength);
+ m_SendBuffer.resize(MessageLength);
}
- const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
+ const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_SendBuffer.data(), Data, DataLength);
if (EncryptedLen == 0)
{
return 0;
}
- if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
+ if (!m_Inner->SendMessage(m_SendBuffer.data(), static_cast<size_t>(EncryptedLen)))
{
return 0;
}
@@ -323,7 +323,7 @@ AesComputeTransport::Recv(void* Data, size_t Size)
// and returned on subsequent Recv calls without another decryption round-trip.
ZEN_TRACE_CPU("AesComputeTransport::Recv");
- std::lock_guard<std::mutex> Lock(m_Lock);
+ std::lock_guard<std::mutex> Lock(m_RecvMutex);
if (!m_RemainingData.empty())
{
@@ -366,12 +366,12 @@ AesComputeTransport::Recv(void* Data, size_t Size)
// Receive ciphertext + tag
const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
- if (m_EncryptBuffer.size() < MessageLength)
+ if (m_RecvBuffer.size() < MessageLength)
{
- m_EncryptBuffer.resize(MessageLength);
+ m_RecvBuffer.resize(MessageLength);
}
- if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
+ if (!m_Inner->RecvMessage(m_RecvBuffer.data(), MessageLength))
{
return 0;
}
@@ -382,7 +382,7 @@ AesComputeTransport::Recv(void* Data, size_t Size)
// We need a temporary buffer for decryption if we can't decrypt directly into output
std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
- const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
+ const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_RecvBuffer.data(), Header.DataLength);
if (Decrypted == 0)
{
return 0;
diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h
index efcad9835..7d0b3d27c 100644
--- a/src/zenhorde/hordetransportaes.h
+++ b/src/zenhorde/hordetransportaes.h
@@ -42,10 +42,12 @@ private:
std::unique_ptr<CryptoContext> m_Crypto;
std::unique_ptr<ComputeTransport> m_Inner;
- std::vector<uint8_t> m_EncryptBuffer;
+ std::vector<uint8_t> m_SendBuffer;
+ std::vector<uint8_t> m_RecvBuffer;
std::vector<uint8_t> m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv
size_t m_RemainingOffset = 0;
- std::mutex m_Lock;
+ std::mutex m_SendMutex;
+ std::mutex m_RecvMutex;
bool m_IsClosed = false;
};
diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h
index 15b2ba9af..5c119dd19 100644
--- a/src/zenhorde/include/zenhorde/hordeclient.h
+++ b/src/zenhorde/include/zenhorde/hordeclient.h
@@ -66,6 +66,20 @@ struct MachineInfo
return Port;
}
+ /** Return the address and port for the Zen service endpoint, accounting for relay port mapping. */
+ std::pair<const std::string&, uint16_t> GetZenServiceEndpoint(uint16_t DefaultPort) const
+ {
+ if (Mode == ConnectionMode::Relay)
+ {
+ auto It = Ports.find("ZenPort");
+ if (It != Ports.end())
+ {
+ return {ConnectionAddress, It->second.Port};
+ }
+ }
+ return {Ip, DefaultPort};
+ }
+
bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; }
};
diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp
index 8f6913fce..20cc4cd65 100644
--- a/src/zenserver/compute/computeserver.cpp
+++ b/src/zenserver/compute/computeserver.cpp
@@ -69,6 +69,13 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options)
Options.add_option("compute",
"",
+ "announce-url",
+ "Override URL announced to the coordinator (e.g. relay-visible endpoint)",
+ cxxopts::value<std::string>(m_ServerOptions.AnnounceUrl)->default_value(""),
+ "");
+
+ Options.add_option("compute",
+ "",
"idms",
"Enable IDMS cloud detection; optionally specify a custom probe endpoint",
cxxopts::value<std::string>(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"),
@@ -387,6 +394,7 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ
}
m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint;
+ m_AnnounceUrl = ServerConfig.AnnounceUrl;
m_InstanceId = ServerConfig.InstanceId;
m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket;
@@ -612,6 +620,10 @@ ZenComputeServer::GetInstanceId() const
std::string
ZenComputeServer::GetAnnounceUrl() const
{
+ if (!m_AnnounceUrl.empty())
+ {
+ return m_AnnounceUrl;
+ }
return m_Http->GetServiceUri(nullptr);
}
diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h
index 83205a2e4..9dc83edb9 100644
--- a/src/zenserver/compute/computeserver.h
+++ b/src/zenserver/compute/computeserver.h
@@ -49,6 +49,7 @@ struct ZenComputeServerConfig : public ZenServerConfig
std::string UpstreamNotificationEndpoint;
std::string InstanceId; // For use in notifications
std::string CoordinatorEndpoint;
+ std::string AnnounceUrl; ///< Override for self-announced URL (e.g. relay-visible endpoint)
std::string IdmsEndpoint;
int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2)
bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link
@@ -147,6 +148,7 @@ private:
# endif
SystemMetricsTracker m_MetricsTracker;
std::string m_CoordinatorEndpoint;
+ std::string m_AnnounceUrl;
std::string m_InstanceId;
asio::steady_timer m_AnnounceTimer{m_IoContext};