aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordetransportaes.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-04 14:13:46 +0100
committerGitHub Enterprise <[email protected]>2026-03-04 14:13:46 +0100
commit0763d09a81e5a1d3df11763a7ec75e7860c9510a (patch)
tree074575ba6ea259044a179eab0bb396d37268fb09 /src/zenhorde/hordetransportaes.cpp
parentnative xmake toolchain definition for UE-clang (#805) (diff)
downloadzen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.tar.xz
zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.zip
compute orchestration (#763)
- Added local process runners for Linux/Wine, Mac with some sandboxing support - Horde & Nomad provisioning for development and testing - Client session queues with lifecycle management (active/draining/cancelled), automatic retry with configurable limits, and manual reschedule API - Improved web UI for orchestrator, compute, and hub dashboards with WebSocket push updates - Some security hardening - Improved scalability and `zen exec` command Still experimental - compute support is disabled by default
Diffstat (limited to 'src/zenhorde/hordetransportaes.cpp')
-rw-r--r--src/zenhorde/hordetransportaes.cpp425
1 files changed, 425 insertions, 0 deletions
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
new file mode 100644
index 000000000..986dd3705
--- /dev/null
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -0,0 +1,425 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordetransportaes.h"
+
+#include <zencore/logging.h>
+#include <zencore/trace.h>
+
+#include <algorithm>
+#include <cstring>
+#include <random>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+# include <bcrypt.h>
+# pragma comment(lib, "Bcrypt.lib")
+#else
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <openssl/evp.h>
+# include <openssl/err.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#endif
+
+namespace zen::horde {
+
+struct AesComputeTransport::CryptoContext
+{
+ uint8_t Key[KeySize] = {};
+ uint8_t EncryptNonce[NonceBytes] = {};
+ uint8_t DecryptNonce[NonceBytes] = {};
+ bool HasErrors = false;
+
+#if !ZEN_PLATFORM_WINDOWS
+ EVP_CIPHER_CTX* EncCtx = nullptr;
+ EVP_CIPHER_CTX* DecCtx = nullptr;
+#endif
+
+ CryptoContext(const uint8_t (&InKey)[KeySize])
+ {
+ memcpy(Key, InKey, KeySize);
+
+ // The encrypt nonce is randomly initialized and then deterministically mutated
+ // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
+ // the wire (each received message carries its own nonce in the header).
+ std::random_device Rd;
+ std::mt19937 Gen(Rd());
+ std::uniform_int_distribution<int> Dist(0, 255);
+ for (auto& Byte : EncryptNonce)
+ {
+ Byte = static_cast<uint8_t>(Dist(Gen));
+ }
+
+#if !ZEN_PLATFORM_WINDOWS
+ // Drain any stale OpenSSL errors
+ while (ERR_get_error() != 0)
+ {
+ }
+
+ EncCtx = EVP_CIPHER_CTX_new();
+ EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
+
+ DecCtx = EVP_CIPHER_CTX_new();
+ EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
+#endif
+ }
+
+ ~CryptoContext()
+ {
+#if ZEN_PLATFORM_WINDOWS
+ SecureZeroMemory(Key, sizeof(Key));
+ SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+ SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+#else
+ OPENSSL_cleanse(Key, sizeof(Key));
+ OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+ OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+ if (EncCtx)
+ {
+ EVP_CIPHER_CTX_free(EncCtx);
+ }
+ if (DecCtx)
+ {
+ EVP_CIPHER_CTX_free(DecCtx);
+ }
+#endif
+ }
+
+ void UpdateNonce()
+ {
+ uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
+ N32[0]++;
+ N32[1]--;
+ N32[2] = N32[0] ^ N32[1];
+ }
+
+ // Returns total encrypted message size, or 0 on failure
+ // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
+ int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
+ {
+ UpdateNonce();
+
+ // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
+ // caching but has some overhead. For our use case (relatively large, infrequent messages)
+ // this is acceptable.
+#if ZEN_PLATFORM_WINDOWS
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+
+ BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = EncryptNonce;
+ AuthInfo.cbNonce = NonceBytes;
+ uint8_t Tag[TagBytes] = {};
+ AuthInfo.pbTag = Tag;
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG CipherLen = 0;
+ NTSTATUS Status =
+ BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ HasErrors = true;
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+ return 0;
+ }
+
+ // Write header: length + nonce
+ memcpy(Out, &InLength, 4);
+ memcpy(Out + 4, EncryptNonce, NonceBytes);
+ // Write tag after ciphertext
+ memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+
+ return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+#else
+ if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int32_t Offset = 0;
+ // Write length
+ memcpy(Out + Offset, &InLength, 4);
+ Offset += 4;
+ // Write nonce
+ memcpy(Out + Offset, EncryptNonce, NonceBytes);
+ Offset += NonceBytes;
+
+ // Encrypt
+ int OutLen = 0;
+ if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += OutLen;
+
+ // Finalize
+ int FinalLen = 0;
+ if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += FinalLen;
+
+ // Get tag
+ if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += TagBytes;
+
+ return Offset;
+#endif
+ }
+
+ // Decrypt a message. Returns decrypted data length, or 0 on failure.
+ // Input must be [ciphertext][tag], with nonce provided separately.
+ int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+
+ BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+ AuthInfo.cbNonce = NonceBytes;
+ AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG PlainLen = 0;
+ NTSTATUS Status = BCryptDecrypt(hKey,
+ (PUCHAR)CipherAndTag,
+ (ULONG)DataLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out,
+ (ULONG)DataLength,
+ &PlainLen,
+ 0);
+
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ return static_cast<int32_t>(PlainLen);
+#else
+ if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int OutLen = 0;
+ if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ // Set the tag for verification
+ if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int FinalLen = 0;
+ if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ return OutLen + FinalLen;
+#endif
+ }
+};
+
+AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+: m_Crypto(std::make_unique<CryptoContext>(Key))
+, m_Inner(std::move(InnerTransport))
+{
+}
+
+AesComputeTransport::~AesComputeTransport()
+{
+ Close();
+}
+
+bool
+AesComputeTransport::IsValid() const
+{
+ return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed;
+}
+
+size_t
+AesComputeTransport::Send(const void* Data, size_t Size)
+{
+ ZEN_TRACE_CPU("AesComputeTransport::Send");
+
+ if (!IsValid())
+ {
+ return 0;
+ }
+
+ std::lock_guard<std::mutex> Lock(m_Lock);
+
+ const int32_t DataLength = static_cast<int32_t>(Size);
+ const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
+
+ if (m_EncryptBuffer.size() < MessageLength)
+ {
+ m_EncryptBuffer.resize(MessageLength);
+ }
+
+ const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
+ if (EncryptedLen == 0)
+ {
+ return 0;
+ }
+
+ if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
+ {
+ return 0;
+ }
+
+ return Size;
+}
+
+size_t
+AesComputeTransport::Recv(void* Data, size_t Size)
+{
+ if (!IsValid())
+ {
+ return 0;
+ }
+
+ // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
+ // than the decrypted message contains. Excess bytes are buffered in m_RemainingData
+ // and returned on subsequent Recv calls without another decryption round-trip.
+ ZEN_TRACE_CPU("AesComputeTransport::Recv");
+
+ std::lock_guard<std::mutex> Lock(m_Lock);
+
+ if (!m_RemainingData.empty())
+ {
+ const size_t Available = m_RemainingData.size() - m_RemainingOffset;
+ const size_t ToCopy = std::min(Available, Size);
+
+ memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+ m_RemainingOffset += ToCopy;
+
+ if (m_RemainingOffset >= m_RemainingData.size())
+ {
+ m_RemainingData.clear();
+ m_RemainingOffset = 0;
+ }
+
+ return ToCopy;
+ }
+
+ // Receive packet header: [length(4B)][nonce(12B)]
+ struct PacketHeader
+ {
+ int32_t DataLength = 0;
+ uint8_t Nonce[NonceBytes] = {};
+ } Header;
+
+ if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
+ {
+ return 0;
+ }
+
+ // Validate DataLength to prevent OOM from malicious/corrupt peers
+ static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
+
+ if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
+ {
+ ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
+ return 0;
+ }
+
+ // Receive ciphertext + tag
+ const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
+
+ if (m_EncryptBuffer.size() < MessageLength)
+ {
+ m_EncryptBuffer.resize(MessageLength);
+ }
+
+ if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
+ {
+ return 0;
+ }
+
+ // Decrypt
+ const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size);
+
+ // We need a temporary buffer for decryption if we can't decrypt directly into output
+ std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
+
+ const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
+ if (Decrypted == 0)
+ {
+ return 0;
+ }
+
+ memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+
+ // Store remaining data if we couldn't return everything
+ if (static_cast<size_t>(Header.DataLength) > BytesToReturn)
+ {
+ m_RemainingOffset = 0;
+ m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength);
+ }
+
+ return BytesToReturn;
+}
+
+void
+AesComputeTransport::MarkComplete()
+{
+ if (IsValid())
+ {
+ m_Inner->MarkComplete();
+ }
+}
+
+void
+AesComputeTransport::Close()
+{
+ if (!m_IsClosed)
+ {
+ if (m_Inner && m_Inner->IsValid())
+ {
+ m_Inner->Close();
+ }
+ m_IsClosed = true;
+ }
+}
+
+} // namespace zen::horde