// Copyright Epic Games, Inc. All Rights Reserved. #include "hordetransportaes.h" #include #include #include #include #include #if ZEN_PLATFORM_WINDOWS # include # include # pragma comment(lib, "bcrypt.lib") #else ZEN_THIRD_PARTY_INCLUDES_START # include # include ZEN_THIRD_PARTY_INCLUDES_END #endif namespace zen::horde { struct AesComputeTransport::CryptoContext { uint8_t Key[KeySize] = {}; uint8_t EncryptNonce[NonceBytes] = {}; uint8_t DecryptNonce[NonceBytes] = {}; bool HasErrors = false; #if !ZEN_PLATFORM_WINDOWS EVP_CIPHER_CTX* EncCtx = nullptr; EVP_CIPHER_CTX* DecCtx = nullptr; #endif CryptoContext(const uint8_t (&InKey)[KeySize]) { memcpy(Key, InKey, KeySize); // The encrypt nonce is randomly initialized and then deterministically mutated // per message via UpdateNonce(). The decrypt nonce is not used — it comes from // the wire (each received message carries its own nonce in the header). std::random_device Rd; std::mt19937 Gen(Rd()); std::uniform_int_distribution Dist(0, 255); for (auto& Byte : EncryptNonce) { Byte = static_cast(Dist(Gen)); } #if !ZEN_PLATFORM_WINDOWS // Drain any stale OpenSSL errors while (ERR_get_error() != 0) { } EncCtx = EVP_CIPHER_CTX_new(); EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); DecCtx = EVP_CIPHER_CTX_new(); EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); #endif } ~CryptoContext() { #if ZEN_PLATFORM_WINDOWS SecureZeroMemory(Key, sizeof(Key)); SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); #else OPENSSL_cleanse(Key, sizeof(Key)); OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); if (EncCtx) { EVP_CIPHER_CTX_free(EncCtx); } if (DecCtx) { EVP_CIPHER_CTX_free(DecCtx); } #endif } void UpdateNonce() { uint32_t* N32 = reinterpret_cast(EncryptNonce); N32[0]++; N32[1]--; N32[2] = N32[0] ^ N32[1]; } // Returns total encrypted message size, or 0 on failure // Output 
format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) { UpdateNonce(); // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than // caching but has some overhead. For our use case (relatively large, infrequent messages) // this is acceptable. #if ZEN_PLATFORM_WINDOWS BCRYPT_ALG_HANDLE hAlg = nullptr; BCRYPT_KEY_HANDLE hKey = nullptr; BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); AuthInfo.pbNonce = EncryptNonce; AuthInfo.cbNonce = NonceBytes; uint8_t Tag[TagBytes] = {}; AuthInfo.pbTag = Tag; AuthInfo.cbTag = TagBytes; ULONG CipherLen = 0; NTSTATUS Status = BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); if (!BCRYPT_SUCCESS(Status)) { HasErrors = true; BCryptDestroyKey(hKey); BCryptCloseAlgorithmProvider(hAlg, 0); return 0; } // Write header: length + nonce memcpy(Out, &InLength, 4); memcpy(Out + 4, EncryptNonce, NonceBytes); // Write tag after ciphertext memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); BCryptDestroyKey(hKey); BCryptCloseAlgorithmProvider(hAlg, 0); return 4 + NonceBytes + static_cast(CipherLen) + TagBytes; #else if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) { HasErrors = true; return 0; } int32_t Offset = 0; // Write length memcpy(Out + Offset, &InLength, 4); Offset += 4; // Write nonce memcpy(Out + Offset, EncryptNonce, NonceBytes); Offset += NonceBytes; // Encrypt int OutLen = 0; if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast(In), InLength) != 1) { HasErrors = true; return 0; } Offset += OutLen; // Finalize int FinalLen = 0; 
if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) { HasErrors = true; return 0; } Offset += FinalLen; // Get tag if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) { HasErrors = true; return 0; } Offset += TagBytes; return Offset; #endif } // Decrypt a message. Returns decrypted data length, or 0 on failure. // Input must be [ciphertext][tag], with nonce provided separately. int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { #if ZEN_PLATFORM_WINDOWS BCRYPT_ALG_HANDLE hAlg = nullptr; BCRYPT_KEY_HANDLE hKey = nullptr; BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); AuthInfo.pbNonce = const_cast(Nonce); AuthInfo.cbNonce = NonceBytes; AuthInfo.pbTag = const_cast(CipherAndTag + DataLength); AuthInfo.cbTag = TagBytes; ULONG PlainLen = 0; NTSTATUS Status = BCryptDecrypt(hKey, (PUCHAR)CipherAndTag, (ULONG)DataLength, &AuthInfo, nullptr, 0, (PUCHAR)Out, (ULONG)DataLength, &PlainLen, 0); BCryptDestroyKey(hKey); BCryptCloseAlgorithmProvider(hAlg, 0); if (!BCRYPT_SUCCESS(Status)) { HasErrors = true; return 0; } return static_cast(PlainLen); #else if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) { HasErrors = true; return 0; } int OutLen = 0; if (EVP_DecryptUpdate(DecCtx, static_cast(Out), &OutLen, CipherAndTag, DataLength) != 1) { HasErrors = true; return 0; } // Set the tag for verification if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast(CipherAndTag + DataLength)) != 1) { HasErrors = true; return 0; } int FinalLen = 0; if (EVP_DecryptFinal_ex(DecCtx, static_cast(Out) + OutLen, &FinalLen) != 1) { HasErrors = true; return 0; } return OutLen 
+ FinalLen; #endif } }; AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr InnerTransport) : m_Crypto(std::make_unique(Key)) , m_Inner(std::move(InnerTransport)) { } AesComputeTransport::~AesComputeTransport() { Close(); } bool AesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } size_t AesComputeTransport::Send(const void* Data, size_t Size) { ZEN_TRACE_CPU("AesComputeTransport::Send"); if (!IsValid()) { return 0; } std::lock_guard Lock(m_Lock); const int32_t DataLength = static_cast(Size); const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; if (m_EncryptBuffer.size() < MessageLength) { m_EncryptBuffer.resize(MessageLength); } const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); if (EncryptedLen == 0) { return 0; } if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast(EncryptedLen))) { return 0; } return Size; } size_t AesComputeTransport::Recv(void* Data, size_t Size) { if (!IsValid()) { return 0; } // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes // than the decrypted message contains. Excess bytes are buffered in m_RemainingData // and returned on subsequent Recv calls without another decryption round-trip. 
ZEN_TRACE_CPU("AesComputeTransport::Recv"); std::lock_guard Lock(m_Lock); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) { m_RemainingData.clear(); m_RemainingOffset = 0; } return ToCopy; } // Receive packet header: [length(4B)][nonce(12B)] struct PacketHeader { int32_t DataLength = 0; uint8_t Nonce[NonceBytes] = {}; } Header; if (!m_Inner->RecvMessage(&Header, sizeof(Header))) { return 0; } // Validate DataLength to prevent OOM from malicious/corrupt peers static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) { ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); return 0; } // Receive ciphertext + tag const size_t MessageLength = static_cast(Header.DataLength) + TagBytes; if (m_EncryptBuffer.size() < MessageLength) { m_EncryptBuffer.resize(MessageLength); } if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) { return 0; } // Decrypt const size_t BytesToReturn = std::min(static_cast(Header.DataLength), Size); // We need a temporary buffer for decryption if we can't decrypt directly into output std::vector DecryptedBuf(static_cast(Header.DataLength)); const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); if (Decrypted == 0) { return 0; } memcpy(Data, DecryptedBuf.data(), BytesToReturn); // Store remaining data if we couldn't return everything if (static_cast(Header.DataLength) > BytesToReturn) { m_RemainingOffset = 0; m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); } return BytesToReturn; } void AesComputeTransport::MarkComplete() { if (IsValid()) { m_Inner->MarkComplete(); } } 
/**
 * Closes the transport. Only the first call has any effect: it closes the inner
 * transport if one is present and still valid, then records the closed state.
 */
void
AesComputeTransport::Close()
{
    if (m_IsClosed)
    {
        return;
    }

    const bool InnerIsOpen = m_Inner && m_Inner->IsValid();
    if (InnerIsOpen)
    {
        m_Inner->Close();
    }

    m_IsClosed = true;
}

}  // namespace zen::horde