diff options
Diffstat (limited to 'src/zenhorde/hordetransportaes.cpp')
| -rw-r--r-- | src/zenhorde/hordetransportaes.cpp | 718 |
1 files changed, 419 insertions, 299 deletions
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..0b94a4397 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -5,9 +5,12 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <algorithm> #include <cstring> -#include <random> #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> @@ -22,315 +25,410 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -struct AesComputeTransport::CryptoContext -{ - uint8_t Key[KeySize] = {}; - uint8_t EncryptNonce[NonceBytes] = {}; - uint8_t DecryptNonce[NonceBytes] = {}; - bool HasErrors = false; - -#if !ZEN_PLATFORM_WINDOWS - EVP_CIPHER_CTX* EncCtx = nullptr; - EVP_CIPHER_CTX* DecCtx = nullptr; -#endif - - CryptoContext(const uint8_t (&InKey)[KeySize]) - { - memcpy(Key, InKey, KeySize); - - // The encrypt nonce is randomly initialized and then deterministically mutated - // per message via UpdateNonce(). The decrypt nonce is not used — it comes from - // the wire (each received message carries its own nonce in the header). - std::random_device Rd; - std::mt19937 Gen(Rd()); - std::uniform_int_distribution<int> Dist(0, 255); - for (auto& Byte : EncryptNonce) - { - Byte = static_cast<uint8_t>(Dist(Gen)); - } - -#if !ZEN_PLATFORM_WINDOWS - // Drain any stale OpenSSL errors - while (ERR_get_error() != 0) - { - } - - EncCtx = EVP_CIPHER_CTX_new(); - EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); - - DecCtx = EVP_CIPHER_CTX_new(); - EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); -#endif - } +namespace { - ~CryptoContext() - { -#if ZEN_PLATFORM_WINDOWS - SecureZeroMemory(Key, sizeof(Key)); - SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); - SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); -#else - OPENSSL_cleanse(Key, sizeof(Key)); - OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); - OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); - - if (EncCtx) - { - EVP_CIPHER_CTX_free(EncCtx); - } - if (DecCtx) - { - EVP_CIPHER_CTX_free(DecCtx); - } -#endif - } + static constexpr size_t AesNonceBytes = 12; + static constexpr size_t AesTagBytes = 16; - void UpdateNonce() + /** AES-256-GCM crypto context. Not exposed outside this translation unit. */ + struct AesCryptoContext { - uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); - N32[0]++; - N32[1]--; - N32[2] = N32[0] ^ N32[1]; - } + static constexpr size_t NonceBytes = AesNonceBytes; + static constexpr size_t TagBytes = AesTagBytes; - // Returns total encrypted message size, or 0 on failure - // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] - int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) - { - UpdateNonce(); + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + uint64_t DecryptCounter = 0; ///< Sequence number of the next message to be decrypted (for diagnostics) + bool HasErrors = false; - // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than - // caching but has some overhead. For our use case (relatively large, infrequent messages) - // this is acceptable. #if ZEN_PLATFORM_WINDOWS BCRYPT_ALG_HANDLE hAlg = nullptr; BCRYPT_KEY_HANDLE hKey = nullptr; +#else + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = EncryptNonce; - AuthInfo.cbNonce = NonceBytes; - uint8_t Tag[TagBytes] = {}; - AuthInfo.pbTag = Tag; - AuthInfo.cbTag = TagBytes; - - ULONG CipherLen = 0; - NTSTATUS Status = - BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); - - if (!BCRYPT_SUCCESS(Status)) + AesCryptoContext(const uint8_t (&InKey)[KeySize]) { - HasErrors = true; - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - return 0; - } + memcpy(Key, InKey, KeySize); - // Write header: length + nonce - memcpy(Out, &InLength, 4); - memcpy(Out + 4, EncryptNonce, NonceBytes); - // Write tag after ciphertext - memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + // EncryptNonce is zero-initialized (NIST SP 800-38D §8.2.1 deterministic + // construction): fixed_field = 0, counter starts at 0 and is incremented + // before each encryption by UpdateNonce(). No RNG is used here because + // std::random_device is not guaranteed to be a CSPRNG (historic MinGW, + // some WASI targets), and the deterministic construction does not need + // one as long as each session uses a unique key. - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; +#if ZEN_PLATFORM_WINDOWS + NTSTATUS Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptOpenAlgorithmProvider failed: 0x{:08x}", static_cast<uint32_t>(Status)); + hAlg = nullptr; + HasErrors = true; + return; + } + + Status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptSetProperty(BCRYPT_CHAIN_MODE_GCM) failed: 0x{:08x}", static_cast<uint32_t>(Status)); + HasErrors = true; + return; + } + + Status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptGenerateSymmetricKey failed: 0x{:08x}", static_cast<uint32_t>(Status)); + hKey = nullptr; + HasErrors = true; + return; + } #else - if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) - { - HasErrors = true; - return 0; + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + DecCtx = EVP_CIPHER_CTX_new(); + if (!EncCtx || !DecCtx) + { + ZEN_ERROR("EVP_CIPHER_CTX_new failed"); + HasErrors = true; + return; + } + + if (EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_EncryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error()); + HasErrors = true; + return; + } + + if (EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_DecryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error()); + HasErrors = true; + return; + } +#endif } - int32_t Offset = 0; - // Write length - memcpy(Out + Offset, &InLength, 4); - Offset += 4; - // Write nonce - memcpy(Out + Offset, EncryptNonce, NonceBytes); - Offset += NonceBytes; - - // Encrypt - int OutLen = 0; - if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + ~AesCryptoContext() { - HasErrors = true; - return 0; +#if ZEN_PLATFORM_WINDOWS + if (hKey) + { + BCryptDestroyKey(hKey); + } + if (hAlg) + { + BCryptCloseAlgorithmProvider(hAlg, 0); + } + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif } - Offset += OutLen; - // Finalize - int FinalLen = 0; - if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + void UpdateNonce() { + // NIST SP 800-38D §8.2.1 deterministic construction: + // nonce = [fixed_field (4 bytes) || invocation_counter (8 bytes, big-endian)] + // The low 8 bytes are a strict monotonic counter starting at zero. On 2^64 + // exhaustion the session is torn down (HasErrors) - never wrap, since a repeated + // (key, nonce) pair catastrophically breaks AES-GCM confidentiality and integrity. + for (int i = 11; i >= 4; --i) + { + if (++EncryptNonce[i] != 0) + { + return; + } + } HasErrors = true; - return 0; } - Offset += FinalLen; - // Get tag - if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) { - HasErrors = true; - return 0; - } - Offset += TagBytes; - - return Offset; -#endif - } + UpdateNonce(); + if (HasErrors) + { + return 0; + } - // Decrypt a message. Returns decrypted data length, or 0 on failure. - // Input must be [ciphertext][tag], with nonce provided separately. - int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) - { #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); - AuthInfo.cbNonce = NonceBytes; - AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); - AuthInfo.cbTag = TagBytes; - - ULONG PlainLen = 0; - NTSTATUS Status = BCryptDecrypt(hKey, - (PUCHAR)CipherAndTag, - (ULONG)DataLength, - &AuthInfo, - nullptr, - 0, - (PUCHAR)Out, - (ULONG)DataLength, - &PlainLen, - 0); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; - return 0; - } - - return static_cast<int32_t>(PlainLen); + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + // Tag is output-only on encrypt; BCryptEncrypt writes TagBytes bytes into it, so skip zero-init. + uint8_t Tag[TagBytes]; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + const NTSTATUS Status = BCryptEncrypt(hKey, + (PUCHAR)In, + (ULONG)InLength, + &AuthInfo, + nullptr, + 0, + Out + 4 + NonceBytes, + (ULONG)InLength, + &CipherLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + ZEN_ERROR("BCryptEncrypt failed: 0x{:08x}", static_cast<uint32_t>(Status)); + HasErrors = true; + return 0; + } + + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; #else - if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) - { - HasErrors = true; - return 0; - } - - int OutLen = 0; - if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) - { - HasErrors = true; - return 0; + // Reset per message so any stale state from a previous encrypt (e.g. partial + // completion after a prior error) cannot bleed into this operation. Re-bind + // the cipher/key; the IV is then set via the normal init call below. + if (EVP_CIPHER_CTX_reset(EncCtx) != 1 || EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_CIPHER_CTX_reset/EncryptInit failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + ZEN_ERROR("EVP_EncryptInit_ex(key+iv) failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + ZEN_ERROR("EVP_EncryptUpdate failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += OutLen; + + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + ZEN_ERROR("EVP_EncryptFinal_ex failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += FinalLen; + + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + ZEN_ERROR("EVP_CTRL_GCM_GET_TAG failed: {}", ERR_get_error()); + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif } - // Set the tag for verification - if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { - HasErrors = true; - return 0; +#if ZEN_PLATFORM_WINDOWS + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + const NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + // STATUS_AUTH_TAG_MISMATCH (0xC000A002) indicates GCM integrity failure - + // either in-flight corruption or active tampering. Log distinctly from + // other BCryptDecrypt failures so that tamper attempts are auditable. + static constexpr NTSTATUS STATUS_AUTH_TAG_MISMATCH_VAL = static_cast<NTSTATUS>(0xC000A002L); + if (Status == STATUS_AUTH_TAG_MISMATCH_VAL) + { + ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter); + } + else + { + ZEN_ERROR("BCryptDecrypt failed: 0x{:08x} (seq={})", static_cast<uint32_t>(Status), DecryptCounter); + } + HasErrors = true; + return 0; + } + + ++DecryptCounter; + return static_cast<int32_t>(PlainLen); +#else + // Same rationale as EncryptMessage: reset the context and re-bind the cipher + // before each decrypt to avoid stale state from a previous operation. + if (EVP_CIPHER_CTX_reset(DecCtx) != 1 || EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) + { + ZEN_ERROR("EVP_CIPHER_CTX_reset/DecryptInit failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + ZEN_ERROR("EVP_DecryptInit_ex (seq={}) failed: {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + ZEN_ERROR("EVP_DecryptUpdate failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + ZEN_ERROR("EVP_CTRL_GCM_SET_TAG failed (seq={}): {}", DecryptCounter, ERR_get_error()); + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + // EVP_DecryptFinal_ex returns 0 specifically on GCM tag verification failure + // once the tag has been set. Log distinctly so tamper attempts are auditable. + ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter); + HasErrors = true; + return 0; + } + + ++DecryptCounter; + return OutLen + FinalLen; +#endif } + }; - int FinalLen = 0; - if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) - { - HasErrors = true; - return 0; - } +} // anonymous namespace - return OutLen + FinalLen; -#endif - } +struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext +{ + using AesCryptoContext::AesCryptoContext; }; -AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +// --- AsyncAesComputeTransport --- + +AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext) : m_Crypto(std::make_unique<CryptoContext>(Key)) , m_Inner(std::move(InnerTransport)) +, m_IoContext(IoContext) { } -AesComputeTransport::~AesComputeTransport() +AsyncAesComputeTransport::~AsyncAesComputeTransport() { Close(); } bool -AesComputeTransport::IsValid() const +AsyncAesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } -size_t -AesComputeTransport::Send(const void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { - ZEN_TRACE_CPU("AesComputeTransport::Send"); - if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - std::lock_guard<std::mutex> Lock(m_Lock); - const int32_t DataLength = static_cast<int32_t>(Size); - const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes; - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } + // Encrypt directly into the per-write buffer rather than a long-lived member. Using a + // member (plaintext + ciphertext share that buffer during encryption on the OpenSSL + // path) would leave plaintext on the heap indefinitely and would also make the + // transport unsafe if AsyncWrite were ever invoked concurrently. Size the shared_ptr + // exactly to EncryptedLen afterwards. + auto EncBuf = std::make_shared<std::vector<uint8_t>>(MessageLength); - const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + const int32_t EncryptedLen = m_Crypto->EncryptMessage(EncBuf->data(), Data, DataLength); if (EncryptedLen == 0) { - return 0; + asio::post(m_IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); }); + return; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) - { - return 0; - } + EncBuf->resize(static_cast<size_t>(EncryptedLen)); - return Size; + m_Inner->AsyncWrite( + EncBuf->data(), + EncBuf->size(), + [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); }); } -size_t -AesComputeTransport::Recv(void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes - // than the decrypted message contains. Excess bytes are buffered in m_RemainingData - // and returned on subsequent Recv calls without another decryption round-trip. - ZEN_TRACE_CPU("AesComputeTransport::Recv"); - - std::lock_guard<std::mutex> Lock(m_Lock); + uint8_t* Dest = static_cast<uint8_t*>(Data); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); - memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) @@ -339,82 +437,104 @@ AesComputeTransport::Recv(void* Data, size_t Size) m_RemainingOffset = 0; } - return ToCopy; - } - - // Receive packet header: [length(4B)][nonce(12B)] - struct PacketHeader - { - int32_t DataLength = 0; - uint8_t Nonce[NonceBytes] = {}; - } Header; - - if (!m_Inner->RecvMessage(&Header, sizeof(Header))) - { - return 0; - } - - // Validate DataLength to prevent OOM from malicious/corrupt peers - static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB - - if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) - { - ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); - return 0; - } - - // Receive ciphertext + tag - const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; - - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } - - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) - { - return 0; - } - - // Decrypt - const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); - - // We need a temporary buffer for decryption if we can't decrypt directly into output - std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); - - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); - if (Decrypted == 0) - { - return 0; - } - - memcpy(Data, DecryptedBuf.data(), BytesToReturn); + if (ToCopy == Size) + { + asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); }); + return; + } - // Store remaining data if we couldn't return everything - if (static_cast<size_t>(Header.DataLength) > BytesToReturn) - { - m_RemainingOffset = 0; - m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler)); + return; } - return BytesToReturn; + DoRecvMessage(Dest, Size, std::move(Handler)); } void -AesComputeTransport::MarkComplete() +AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler) { - if (IsValid()) - { - m_Inner->MarkComplete(); - } + static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes; + auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>(); + + m_Inner->AsyncRead(HeaderBuf->data(), + HeaderSize, + [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + int32_t DataLength = 0; + memcpy(&DataLength, HeaderBuf->data(), 4); + + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; + if (DataLength <= 0 || DataLength > MaxDataLength) + { + Handler(asio::error::make_error_code(asio::error::invalid_argument), 0); + return; + } + + const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes; + if (m_DecryptBuffer.size() < MessageLength) + { + m_DecryptBuffer.resize(MessageLength); + } + + auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>(); + memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes); + + m_Inner->AsyncRead( + m_DecryptBuffer.data(), + MessageLength, + [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec, + size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength)); + const int32_t Decrypted = + m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength); + if (Decrypted == 0) + { + Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); + return; + } + + const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size); + memcpy(Dest, PlaintextBuf.data(), BytesToReturn); + + if (static_cast<size_t>(Decrypted) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted); + } + + if (BytesToReturn < Size) + { + DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler)); + } + else + { + Handler(std::error_code{}, Size); + } + }); + }); } void -AesComputeTransport::Close() +AsyncAesComputeTransport::Close() { if (!m_IsClosed) { - if (m_Inner && m_Inner->IsValid()) + // Always forward Close() to the inner transport if we have one. Gating on + // IsValid() skipped cleanup when the inner transport was partially torn down + // (e.g. after a read/write error marked it non-valid but left its socket open), + // leaking OS handles. Close implementations are expected to be idempotent. + if (m_Inner) { m_Inner->Close(); } |