diff options
| author | Stefan Boberg <[email protected]> | 2026-03-04 14:13:46 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-04 14:13:46 +0100 |
| commit | 0763d09a81e5a1d3df11763a7ec75e7860c9510a (patch) | |
| tree | 074575ba6ea259044a179eab0bb396d37268fb09 /src/zenhorde/hordetransportaes.cpp | |
| parent | native xmake toolchain definition for UE-clang (#805) (diff) | |
| download | zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.tar.xz zen-0763d09a81e5a1d3df11763a7ec75e7860c9510a.zip | |
compute orchestration (#763)
- Added local process runners for Linux/Wine, Mac with some sandboxing support
- Horde & Nomad provisioning for development and testing
- Client session queues with lifecycle management (active/draining/cancelled), automatic retry with configurable limits, and manual reschedule API
- Improved web UI for orchestrator, compute, and hub dashboards with WebSocket push updates
- Some security hardening
- Improved scalability and `zen exec` command
Still experimental - compute support is disabled by default
Diffstat (limited to 'src/zenhorde/hordetransportaes.cpp')
| -rw-r--r-- | src/zenhorde/hordetransportaes.cpp | 425 |
1 files changed, 425 insertions, 0 deletions
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp new file mode 100644 index 000000000..986dd3705 --- /dev/null +++ b/src/zenhorde/hordetransportaes.cpp @@ -0,0 +1,425 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransportaes.h" + +#include <zencore/logging.h> +#include <zencore/trace.h> + +#include <algorithm> +#include <cstring> +#include <random> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +# include <bcrypt.h> +# pragma comment(lib, "Bcrypt.lib") +#else +ZEN_THIRD_PARTY_INCLUDES_START +# include <openssl/evp.h> +# include <openssl/err.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +namespace zen::horde { + +struct AesComputeTransport::CryptoContext +{ + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; + +#if !ZEN_PLATFORM_WINDOWS + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif + + CryptoContext(const uint8_t (&InKey)[KeySize]) + { + memcpy(Key, InKey, KeySize); + + // The encrypt nonce is randomly initialized and then deterministically mutated + // per message via UpdateNonce(). The decrypt nonce is not used — it comes from + // the wire (each received message carries its own nonce in the header). + std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution<int> Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast<uint8_t>(Dist(Gen)); + } + +#if !ZEN_PLATFORM_WINDOWS + // Drain any stale OpenSSL errors + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); +#endif + } + + ~CryptoContext() + { +#if ZEN_PLATFORM_WINDOWS + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif + } + + void UpdateNonce() + { + uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; + } + + // Returns total encrypted message size, or 0 on failure + // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); + + // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than + // caching but has some overhead. For our use case (relatively large, infrequent messages) + // this is acceptable. +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = + BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + // Write header: length + nonce + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + // Write tag after ciphertext + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; +#else + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + // Write length + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + // Write nonce + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + // Encrypt + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + // Finalize + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + // Get tag + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif + } + + // Decrypt a message. Returns decrypted data length, or 0 on failure. + // Input must be [ciphertext][tag], with nonce provided separately. + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) + { +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } + + return static_cast<int32_t>(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + // Set the tag for verification + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif + } +}; + +AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +: m_Crypto(std::make_unique<CryptoContext>(Key)) +, m_Inner(std::move(InnerTransport)) +{ +} + +AesComputeTransport::~AesComputeTransport() +{ + Close(); +} + +bool +AesComputeTransport::IsValid() const +{ + return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; +} + +size_t +AesComputeTransport::Send(const void* Data, size_t Size) +{ + ZEN_TRACE_CPU("AesComputeTransport::Send"); + + if (!IsValid()) + { + return 0; + } + + std::lock_guard<std::mutex> Lock(m_Lock); + + const int32_t DataLength = static_cast<int32_t>(Size); + const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + if (EncryptedLen == 0) + { + return 0; + } + + if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) + { + return 0; + } + + return Size; +} + +size_t +AesComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes + // than the decrypted message contains. Excess bytes are buffered in m_RemainingData + // and returned on subsequent Recv calls without another decryption round-trip. + ZEN_TRACE_CPU("AesComputeTransport::Recv"); + + std::lock_guard<std::mutex> Lock(m_Lock); + + if (!m_RemainingData.empty()) + { + const size_t Available = m_RemainingData.size() - m_RemainingOffset; + const size_t ToCopy = std::min(Available, Size); + + memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + m_RemainingOffset += ToCopy; + + if (m_RemainingOffset >= m_RemainingData.size()) + { + m_RemainingData.clear(); + m_RemainingOffset = 0; + } + + return ToCopy; + } + + // Receive packet header: [length(4B)][nonce(12B)] + struct PacketHeader + { + int32_t DataLength = 0; + uint8_t Nonce[NonceBytes] = {}; + } Header; + + if (!m_Inner->RecvMessage(&Header, sizeof(Header))) + { + return 0; + } + + // Validate DataLength to prevent OOM from malicious/corrupt peers + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB + + if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) + { + ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); + return 0; + } + + // Receive ciphertext + tag + const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + { + return 0; + } + + // Decrypt + const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); + + // We need a temporary buffer for decryption if we can't decrypt directly into output + std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); + + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + if (Decrypted == 0) + { + return 0; + } + + memcpy(Data, DecryptedBuf.data(), BytesToReturn); + + // Store remaining data if we couldn't return everything + if (static_cast<size_t>(Header.DataLength) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + } + + return BytesToReturn; +} + +void +AesComputeTransport::MarkComplete() +{ + if (IsValid()) + { + m_Inner->MarkComplete(); + } +} + +void +AesComputeTransport::Close() +{ + if (!m_IsClosed) + { + if (m_Inner && m_Inner->IsValid()) + { + m_Inner->Close(); + } + m_IsClosed = true; + } +} + +} // namespace zen::horde |