diff options
Diffstat (limited to 'src/zenhorde/hordetransportaes.cpp')
| -rw-r--r-- | src/zenhorde/hordetransportaes.cpp | 609 |
1 files changed, 318 insertions, 291 deletions
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 505b6bde7..c71866e8c 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -5,6 +5,10 @@ #include <zencore/logging.h> #include <zencore/trace.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + #include <algorithm> #include <cstring> #include <random> @@ -22,274 +26,281 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen::horde { -struct AesComputeTransport::CryptoContext -{ - uint8_t Key[KeySize] = {}; - uint8_t EncryptNonce[NonceBytes] = {}; - uint8_t DecryptNonce[NonceBytes] = {}; - bool HasErrors = false; +namespace { + + static constexpr size_t AesNonceBytes = 12; + static constexpr size_t AesTagBytes = 16; + + /** AES-256-GCM crypto context. Not exposed outside this translation unit. */ + struct AesCryptoContext + { + static constexpr size_t NonceBytes = AesNonceBytes; + static constexpr size_t TagBytes = AesTagBytes; + + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; #if !ZEN_PLATFORM_WINDOWS - EVP_CIPHER_CTX* EncCtx = nullptr; - EVP_CIPHER_CTX* DecCtx = nullptr; + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; #endif - CryptoContext(const uint8_t (&InKey)[KeySize]) - { - memcpy(Key, InKey, KeySize); - - // The encrypt nonce is randomly initialized and then deterministically mutated - // per message via UpdateNonce(). The decrypt nonce is not used — it comes from - // the wire (each received message carries its own nonce in the header). - std::random_device Rd; - std::mt19937 Gen(Rd()); - std::uniform_int_distribution<int> Dist(0, 255); - for (auto& Byte : EncryptNonce) + AesCryptoContext(const uint8_t (&InKey)[KeySize]) { - Byte = static_cast<uint8_t>(Dist(Gen)); - } + memcpy(Key, InKey, KeySize); + + std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution<int> Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast<uint8_t>(Dist(Gen)); + } #if !ZEN_PLATFORM_WINDOWS - // Drain any stale OpenSSL errors - while (ERR_get_error() != 0) - { - } + while (ERR_get_error() != 0) + { + } - EncCtx = EVP_CIPHER_CTX_new(); - EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); - DecCtx = EVP_CIPHER_CTX_new(); - EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); #endif - } + } - ~CryptoContext() - { + ~AesCryptoContext() + { #if ZEN_PLATFORM_WINDOWS - SecureZeroMemory(Key, sizeof(Key)); - SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); - SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); #else - OPENSSL_cleanse(Key, sizeof(Key)); - OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); - OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); - - if (EncCtx) - { - EVP_CIPHER_CTX_free(EncCtx); + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif } - if (DecCtx) + + void UpdateNonce() { - EVP_CIPHER_CTX_free(DecCtx); + uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; } -#endif - } - - void UpdateNonce() - { - uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce); - N32[0]++; - N32[1]--; - N32[2] = N32[0] ^ N32[1]; - } - // Returns total encrypted message size, or 0 on failure - // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] - int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) - { - UpdateNonce(); + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); - // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than - // caching but has some overhead. For our use case (relatively large, infrequent messages) - // this is acceptable. #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = EncryptNonce; - AuthInfo.cbNonce = NonceBytes; - uint8_t Tag[TagBytes] = {}; - AuthInfo.pbTag = Tag; - AuthInfo.cbTag = TagBytes; - - ULONG CipherLen = 0; - NTSTATUS Status = - BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = BCryptEncrypt(hKey, + (PUCHAR)In, + (ULONG)InLength, + &AuthInfo, + nullptr, + 0, + Out + 4 + NonceBytes, + (ULONG)InLength, + &CipherLen, + 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + BCryptDestroyKey(hKey); BCryptCloseAlgorithmProvider(hAlg, 0); - return 0; - } - - // Write header: length + nonce - memcpy(Out, &InLength, 4); - memcpy(Out + 4, EncryptNonce, NonceBytes); - // Write tag after ciphertext - memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; + return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes; #else - if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) - { - HasErrors = true; - return 0; - } - - int32_t Offset = 0; - // Write length - memcpy(Out + Offset, &InLength, 4); - Offset += 4; - // Write nonce - memcpy(Out + Offset, EncryptNonce, NonceBytes); - Offset += NonceBytes; - - // Encrypt - int OutLen = 0; - if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) - { - HasErrors = true; - return 0; - } - Offset += OutLen; - - // Finalize - int FinalLen = 0; - if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) - { - HasErrors = true; - return 0; + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif } - Offset += FinalLen; - // Get tag - if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1) + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) { - HasErrors = true; - return 0; - } - Offset += TagBytes; - - return Offset; -#endif - } - - // Decrypt a message. Returns decrypted data length, or 0 on failure. - // Input must be [ciphertext][tag], with nonce provided separately. - int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) - { #if ZEN_PLATFORM_WINDOWS - BCRYPT_ALG_HANDLE hAlg = nullptr; - BCRYPT_KEY_HANDLE hKey = nullptr; - - BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); - BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); - BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); - - BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; - BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); - AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); - AuthInfo.cbNonce = NonceBytes; - AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); - AuthInfo.cbTag = TagBytes; - - ULONG PlainLen = 0; - NTSTATUS Status = BCryptDecrypt(hKey, - (PUCHAR)CipherAndTag, - (ULONG)DataLength, - &AuthInfo, - nullptr, - 0, - (PUCHAR)Out, - (ULONG)DataLength, - &PlainLen, - 0); - - BCryptDestroyKey(hKey); - BCryptCloseAlgorithmProvider(hAlg, 0); - - if (!BCRYPT_SUCCESS(Status)) - { - HasErrors = true; - return 0; - } + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); - return static_cast<int32_t>(PlainLen); -#else - if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) - { - HasErrors = true; - return 0; - } + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); - int OutLen = 0; - if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) - { - HasErrors = true; - return 0; - } + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } - // Set the tag for verification - if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) - { - HasErrors = true; - return 0; + return static_cast<int32_t>(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif } + }; - int FinalLen = 0; - if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1) - { - HasErrors = true; - return 0; - } +} // anonymous namespace - return OutLen + FinalLen; -#endif - } +struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext +{ + using AesCryptoContext::AesCryptoContext; }; -AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport) +// --- AsyncAesComputeTransport --- + +AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize], + std::unique_ptr<AsyncComputeTransport> InnerTransport, + asio::io_context& IoContext) : m_Crypto(std::make_unique<CryptoContext>(Key)) , m_Inner(std::move(InnerTransport)) +, m_IoContext(IoContext) { } -AesComputeTransport::~AesComputeTransport() +AsyncAesComputeTransport::~AsyncAesComputeTransport() { Close(); } bool -AesComputeTransport::IsValid() const +AsyncAesComputeTransport::IsValid() const { return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; } -size_t -AesComputeTransport::Send(const void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { - ZEN_TRACE_CPU("AesComputeTransport::Send"); - if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - std::lock_guard<std::mutex> Lock(m_Lock); - const int32_t DataLength = static_cast<int32_t>(Size); - const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes; if (m_EncryptBuffer.size() < MessageLength) { @@ -299,38 +310,36 @@ AesComputeTransport::Send(const void* Data, size_t Size) const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); if (EncryptedLen == 0) { - return 0; + asio::post(m_IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); }); + return; } - if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen))) - { - return 0; - } + auto EncBuf = std::make_shared<std::vector<uint8_t>>(m_EncryptBuffer.begin(), m_EncryptBuffer.begin() + EncryptedLen); - return Size; + m_Inner->AsyncWrite( + EncBuf->data(), + EncBuf->size(), + [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); }); } -size_t -AesComputeTransport::Recv(void* Data, size_t Size) +void +AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; + asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes - // than the decrypted message contains. Excess bytes are buffered in m_RemainingData - // and returned on subsequent Recv calls without another decryption round-trip. - ZEN_TRACE_CPU("AesComputeTransport::Recv"); - - std::lock_guard<std::mutex> Lock(m_Lock); + uint8_t* Dest = static_cast<uint8_t*>(Data); if (!m_RemainingData.empty()) { const size_t Available = m_RemainingData.size() - m_RemainingOffset; const size_t ToCopy = std::min(Available, Size); - memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy); m_RemainingOffset += ToCopy; if (m_RemainingOffset >= m_RemainingData.size()) @@ -339,78 +348,96 @@ AesComputeTransport::Recv(void* Data, size_t Size) m_RemainingOffset = 0; } - return ToCopy; - } - - // Receive packet header: [length(4B)][nonce(12B)] - struct PacketHeader - { - int32_t DataLength = 0; - uint8_t Nonce[NonceBytes] = {}; - } Header; - - if (!m_Inner->RecvMessage(&Header, sizeof(Header))) - { - return 0; - } - - // Validate DataLength to prevent OOM from malicious/corrupt peers - static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB - - if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) - { - ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); - return 0; - } - - // Receive ciphertext + tag - const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes; - - if (m_EncryptBuffer.size() < MessageLength) - { - m_EncryptBuffer.resize(MessageLength); - } - - if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) - { - return 0; - } - - // Decrypt - const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size); - - // We need a temporary buffer for decryption if we can't decrypt directly into output - std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength)); - - const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); - if (Decrypted == 0) - { - return 0; - } - - memcpy(Data, DecryptedBuf.data(), BytesToReturn); + if (ToCopy == Size) + { + asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); }); + return; + } - // Store remaining data if we couldn't return everything - if (static_cast<size_t>(Header.DataLength) > BytesToReturn) - { - m_RemainingOffset = 0; - m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength); + DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler)); + return; } - return BytesToReturn; + DoRecvMessage(Dest, Size, std::move(Handler)); } void -AesComputeTransport::MarkComplete() +AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler) { - if (IsValid()) - { - m_Inner->MarkComplete(); - } + static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes; + auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>(); + + m_Inner->AsyncRead(HeaderBuf->data(), + HeaderSize, + [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + int32_t DataLength = 0; + memcpy(&DataLength, HeaderBuf->data(), 4); + + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; + if (DataLength <= 0 || DataLength > MaxDataLength) + { + Handler(asio::error::make_error_code(asio::error::invalid_argument), 0); + return; + } + + const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes; + if (m_DecryptBuffer.size() < MessageLength) + { + m_DecryptBuffer.resize(MessageLength); + } + + auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>(); + memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes); + + m_Inner->AsyncRead( + m_DecryptBuffer.data(), + MessageLength, + [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec, + size_t /*Bytes*/) mutable { + if (Ec) + { + Handler(Ec, 0); + return; + } + + std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength)); + const int32_t Decrypted = + m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength); + if (Decrypted == 0) + { + Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); + return; + } + + const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size); + memcpy(Dest, PlaintextBuf.data(), BytesToReturn); + + if (static_cast<size_t>(Decrypted) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted); + } + + if (BytesToReturn < Size) + { + DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler)); + } + else + { + Handler(std::error_code{}, Size); + } + }); + }); } void -AesComputeTransport::Close() +AsyncAesComputeTransport::Close() { if (!m_IsClosed) { |