aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordetransportaes.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordetransportaes.cpp')
-rw-r--r--src/zenhorde/hordetransportaes.cpp718
1 files changed, 419 insertions, 299 deletions
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
index 505b6bde7..0b94a4397 100644
--- a/src/zenhorde/hordetransportaes.cpp
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -5,9 +5,12 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <algorithm>
#include <cstring>
-#include <random>
#if ZEN_PLATFORM_WINDOWS
# include <zencore/windows.h>
@@ -22,315 +25,410 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-struct AesComputeTransport::CryptoContext
-{
- uint8_t Key[KeySize] = {};
- uint8_t EncryptNonce[NonceBytes] = {};
- uint8_t DecryptNonce[NonceBytes] = {};
- bool HasErrors = false;
-
-#if !ZEN_PLATFORM_WINDOWS
- EVP_CIPHER_CTX* EncCtx = nullptr;
- EVP_CIPHER_CTX* DecCtx = nullptr;
-#endif
-
- CryptoContext(const uint8_t (&InKey)[KeySize])
- {
- memcpy(Key, InKey, KeySize);
-
- // The encrypt nonce is randomly initialized and then deterministically mutated
- // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
- // the wire (each received message carries its own nonce in the header).
- std::random_device Rd;
- std::mt19937 Gen(Rd());
- std::uniform_int_distribution<int> Dist(0, 255);
- for (auto& Byte : EncryptNonce)
- {
- Byte = static_cast<uint8_t>(Dist(Gen));
- }
-
-#if !ZEN_PLATFORM_WINDOWS
- // Drain any stale OpenSSL errors
- while (ERR_get_error() != 0)
- {
- }
-
- EncCtx = EVP_CIPHER_CTX_new();
- EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
-
- DecCtx = EVP_CIPHER_CTX_new();
- EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
-#endif
- }
+namespace {
- ~CryptoContext()
- {
-#if ZEN_PLATFORM_WINDOWS
- SecureZeroMemory(Key, sizeof(Key));
- SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
- SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
-#else
- OPENSSL_cleanse(Key, sizeof(Key));
- OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
- OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
-
- if (EncCtx)
- {
- EVP_CIPHER_CTX_free(EncCtx);
- }
- if (DecCtx)
- {
- EVP_CIPHER_CTX_free(DecCtx);
- }
-#endif
- }
+ static constexpr size_t AesNonceBytes = 12;
+ static constexpr size_t AesTagBytes = 16;
- void UpdateNonce()
+ /** AES-256-GCM crypto context. Not exposed outside this translation unit. */
+ struct AesCryptoContext
{
- uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
- N32[0]++;
- N32[1]--;
- N32[2] = N32[0] ^ N32[1];
- }
+ static constexpr size_t NonceBytes = AesNonceBytes;
+ static constexpr size_t TagBytes = AesTagBytes;
- // Returns total encrypted message size, or 0 on failure
- // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
- int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
- {
- UpdateNonce();
+ uint8_t Key[KeySize] = {};
+ uint8_t EncryptNonce[NonceBytes] = {};
+ uint8_t DecryptNonce[NonceBytes] = {};
+ uint64_t DecryptCounter = 0; ///< Sequence number of the next message to be decrypted (for diagnostics)
+ bool HasErrors = false;
- // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
- // caching but has some overhead. For our use case (relatively large, infrequent messages)
- // this is acceptable.
#if ZEN_PLATFORM_WINDOWS
BCRYPT_ALG_HANDLE hAlg = nullptr;
BCRYPT_KEY_HANDLE hKey = nullptr;
+#else
+ EVP_CIPHER_CTX* EncCtx = nullptr;
+ EVP_CIPHER_CTX* DecCtx = nullptr;
+#endif
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = EncryptNonce;
- AuthInfo.cbNonce = NonceBytes;
- uint8_t Tag[TagBytes] = {};
- AuthInfo.pbTag = Tag;
- AuthInfo.cbTag = TagBytes;
-
- ULONG CipherLen = 0;
- NTSTATUS Status =
- BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
-
- if (!BCRYPT_SUCCESS(Status))
+ AesCryptoContext(const uint8_t (&InKey)[KeySize])
{
- HasErrors = true;
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
- return 0;
- }
+ memcpy(Key, InKey, KeySize);
- // Write header: length + nonce
- memcpy(Out, &InLength, 4);
- memcpy(Out + 4, EncryptNonce, NonceBytes);
- // Write tag after ciphertext
- memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+ // EncryptNonce is zero-initialized (NIST SP 800-38D §8.2.1 deterministic
+ // construction): fixed_field = 0, counter starts at 0 and is incremented
+ // before each encryption by UpdateNonce(). No RNG is used here because
+ // std::random_device is not guaranteed to be a CSPRNG (historic MinGW,
+ // some WASI targets), and the deterministic construction does not need
+ // one as long as each session uses a unique key.
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+#if ZEN_PLATFORM_WINDOWS
+ NTSTATUS Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptOpenAlgorithmProvider failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ hAlg = nullptr;
+ HasErrors = true;
+ return;
+ }
+
+ Status = BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptSetProperty(BCRYPT_CHAIN_MODE_GCM) failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ HasErrors = true;
+ return;
+ }
+
+ Status = BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptGenerateSymmetricKey failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ hKey = nullptr;
+ HasErrors = true;
+ return;
+ }
#else
- if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
- {
- HasErrors = true;
- return 0;
+ while (ERR_get_error() != 0)
+ {
+ }
+
+ EncCtx = EVP_CIPHER_CTX_new();
+ DecCtx = EVP_CIPHER_CTX_new();
+ if (!EncCtx || !DecCtx)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_new failed");
+ HasErrors = true;
+ return;
+ }
+
+ if (EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return;
+ }
+
+ if (EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptInit_ex(aes-256-gcm) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return;
+ }
+#endif
}
- int32_t Offset = 0;
- // Write length
- memcpy(Out + Offset, &InLength, 4);
- Offset += 4;
- // Write nonce
- memcpy(Out + Offset, EncryptNonce, NonceBytes);
- Offset += NonceBytes;
-
- // Encrypt
- int OutLen = 0;
- if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ ~AesCryptoContext()
{
- HasErrors = true;
- return 0;
+#if ZEN_PLATFORM_WINDOWS
+ if (hKey)
+ {
+ BCryptDestroyKey(hKey);
+ }
+ if (hAlg)
+ {
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+ }
+ SecureZeroMemory(Key, sizeof(Key));
+ SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+ SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+#else
+ OPENSSL_cleanse(Key, sizeof(Key));
+ OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+ OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+ if (EncCtx)
+ {
+ EVP_CIPHER_CTX_free(EncCtx);
+ }
+ if (DecCtx)
+ {
+ EVP_CIPHER_CTX_free(DecCtx);
+ }
+#endif
}
- Offset += OutLen;
- // Finalize
- int FinalLen = 0;
- if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ void UpdateNonce()
{
+ // NIST SP 800-38D §8.2.1 deterministic construction:
+ // nonce = [fixed_field (4 bytes) || invocation_counter (8 bytes, big-endian)]
+ // The low 8 bytes are a strict monotonic counter starting at zero. On 2^64
+ // exhaustion the session is torn down (HasErrors) - never wrap, since a repeated
+ // (key, nonce) pair catastrophically breaks AES-GCM confidentiality and integrity.
+ for (int i = 11; i >= 4; --i)
+ {
+ if (++EncryptNonce[i] != 0)
+ {
+ return;
+ }
+ }
HasErrors = true;
- return 0;
}
- Offset += FinalLen;
- // Get tag
- if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
{
- HasErrors = true;
- return 0;
- }
- Offset += TagBytes;
-
- return Offset;
-#endif
- }
+ UpdateNonce();
+ if (HasErrors)
+ {
+ return 0;
+ }
- // Decrypt a message. Returns decrypted data length, or 0 on failure.
- // Input must be [ciphertext][tag], with nonce provided separately.
- int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
- {
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
- AuthInfo.cbNonce = NonceBytes;
- AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
- AuthInfo.cbTag = TagBytes;
-
- ULONG PlainLen = 0;
- NTSTATUS Status = BCryptDecrypt(hKey,
- (PUCHAR)CipherAndTag,
- (ULONG)DataLength,
- &AuthInfo,
- nullptr,
- 0,
- (PUCHAR)Out,
- (ULONG)DataLength,
- &PlainLen,
- 0);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
- return 0;
- }
-
- return static_cast<int32_t>(PlainLen);
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = EncryptNonce;
+ AuthInfo.cbNonce = NonceBytes;
+ // Tag is output-only on encrypt; BCryptEncrypt writes TagBytes bytes into it, so skip zero-init.
+ uint8_t Tag[TagBytes];
+ AuthInfo.pbTag = Tag;
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG CipherLen = 0;
+ const NTSTATUS Status = BCryptEncrypt(hKey,
+ (PUCHAR)In,
+ (ULONG)InLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ Out + 4 + NonceBytes,
+ (ULONG)InLength,
+ &CipherLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ ZEN_ERROR("BCryptEncrypt failed: 0x{:08x}", static_cast<uint32_t>(Status));
+ HasErrors = true;
+ return 0;
+ }
+
+ memcpy(Out, &InLength, 4);
+ memcpy(Out + 4, EncryptNonce, NonceBytes);
+ memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
+ return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
#else
- if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
-
- int OutLen = 0;
- if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
- {
- HasErrors = true;
- return 0;
+ // Reset per message so any stale state from a previous encrypt (e.g. partial
+ // completion after a prior error) cannot bleed into this operation. Re-bind
+ // the cipher/key; the IV is then set via the normal init call below.
+ if (EVP_CIPHER_CTX_reset(EncCtx) != 1 || EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_reset/EncryptInit failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptInit_ex(key+iv) failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int32_t Offset = 0;
+ memcpy(Out + Offset, &InLength, 4);
+ Offset += 4;
+ memcpy(Out + Offset, EncryptNonce, NonceBytes);
+ Offset += NonceBytes;
+
+ int OutLen = 0;
+ if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptUpdate failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += OutLen;
+
+ int FinalLen = 0;
+ if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ {
+ ZEN_ERROR("EVP_EncryptFinal_ex failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += FinalLen;
+
+ if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ {
+ ZEN_ERROR("EVP_CTRL_GCM_GET_TAG failed: {}", ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ Offset += TagBytes;
+
+ return Offset;
+#endif
}
- // Set the tag for verification
- if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
{
- HasErrors = true;
- return 0;
+#if ZEN_PLATFORM_WINDOWS
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+ AuthInfo.cbNonce = NonceBytes;
+ AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG PlainLen = 0;
+ const NTSTATUS Status = BCryptDecrypt(hKey,
+ (PUCHAR)CipherAndTag,
+ (ULONG)DataLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out,
+ (ULONG)DataLength,
+ &PlainLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ // STATUS_AUTH_TAG_MISMATCH (0xC000A002) indicates GCM integrity failure -
+ // either in-flight corruption or active tampering. Log distinctly from
+ // other BCryptDecrypt failures so that tamper attempts are auditable.
+ static constexpr NTSTATUS STATUS_AUTH_TAG_MISMATCH_VAL = static_cast<NTSTATUS>(0xC000A002L);
+ if (Status == STATUS_AUTH_TAG_MISMATCH_VAL)
+ {
+ ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter);
+ }
+ else
+ {
+ ZEN_ERROR("BCryptDecrypt failed: 0x{:08x} (seq={})", static_cast<uint32_t>(Status), DecryptCounter);
+ }
+ HasErrors = true;
+ return 0;
+ }
+
+ ++DecryptCounter;
+ return static_cast<int32_t>(PlainLen);
+#else
+ // Same rationale as EncryptMessage: reset the context and re-bind the cipher
+ // before each decrypt to avoid stale state from a previous operation.
+ if (EVP_CIPHER_CTX_reset(DecCtx) != 1 || EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1)
+ {
+ ZEN_ERROR("EVP_CIPHER_CTX_reset/DecryptInit failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+ if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptInit_ex (seq={}) failed: {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int OutLen = 0;
+ if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+ {
+ ZEN_ERROR("EVP_DecryptUpdate failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ {
+ ZEN_ERROR("EVP_CTRL_GCM_SET_TAG failed (seq={}): {}", DecryptCounter, ERR_get_error());
+ HasErrors = true;
+ return 0;
+ }
+
+ int FinalLen = 0;
+ if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+ {
+ // EVP_DecryptFinal_ex returns 0 specifically on GCM tag verification failure
+ // once the tag has been set. Log distinctly so tamper attempts are auditable.
+ ZEN_ERROR("AES-GCM tag verification failed (seq={}): possible tampering or in-flight corruption", DecryptCounter);
+ HasErrors = true;
+ return 0;
+ }
+
+ ++DecryptCounter;
+ return OutLen + FinalLen;
+#endif
}
+ };
- int FinalLen = 0;
- if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
- }
+} // anonymous namespace
- return OutLen + FinalLen;
-#endif
- }
+struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext
+{
+ using AesCryptoContext::AesCryptoContext;
};
-AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+// --- AsyncAesComputeTransport ---
+
+AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext)
: m_Crypto(std::make_unique<CryptoContext>(Key))
, m_Inner(std::move(InnerTransport))
+, m_IoContext(IoContext)
{
}
-AesComputeTransport::~AesComputeTransport()
+AsyncAesComputeTransport::~AsyncAesComputeTransport()
{
Close();
}
bool
-AesComputeTransport::IsValid() const
+AsyncAesComputeTransport::IsValid() const
{
return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed;
}
-size_t
-AesComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
- ZEN_TRACE_CPU("AesComputeTransport::Send");
-
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- std::lock_guard<std::mutex> Lock(m_Lock);
-
const int32_t DataLength = static_cast<int32_t>(Size);
- const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
+ const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes;
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
+ // Encrypt directly into the per-write buffer rather than a long-lived member. Using a
+ // member (plaintext + ciphertext share that buffer during encryption on the OpenSSL
+ // path) would leave plaintext on the heap indefinitely and would also make the
+ // transport unsafe if AsyncWrite were ever invoked concurrently. Size the shared_ptr
+ // exactly to EncryptedLen afterwards.
+ auto EncBuf = std::make_shared<std::vector<uint8_t>>(MessageLength);
- const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
+ const int32_t EncryptedLen = m_Crypto->EncryptMessage(EncBuf->data(), Data, DataLength);
if (EncryptedLen == 0)
{
- return 0;
+ asio::post(m_IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); });
+ return;
}
- if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
- {
- return 0;
- }
+ EncBuf->resize(static_cast<size_t>(EncryptedLen));
- return Size;
+ m_Inner->AsyncWrite(
+ EncBuf->data(),
+ EncBuf->size(),
+ [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); });
}
-size_t
-AesComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
- // than the decrypted message contains. Excess bytes are buffered in m_RemainingData
- // and returned on subsequent Recv calls without another decryption round-trip.
- ZEN_TRACE_CPU("AesComputeTransport::Recv");
-
- std::lock_guard<std::mutex> Lock(m_Lock);
+ uint8_t* Dest = static_cast<uint8_t*>(Data);
if (!m_RemainingData.empty())
{
const size_t Available = m_RemainingData.size() - m_RemainingOffset;
const size_t ToCopy = std::min(Available, Size);
- memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+ memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy);
m_RemainingOffset += ToCopy;
if (m_RemainingOffset >= m_RemainingData.size())
@@ -339,82 +437,104 @@ AesComputeTransport::Recv(void* Data, size_t Size)
m_RemainingOffset = 0;
}
- return ToCopy;
- }
-
- // Receive packet header: [length(4B)][nonce(12B)]
- struct PacketHeader
- {
- int32_t DataLength = 0;
- uint8_t Nonce[NonceBytes] = {};
- } Header;
-
- if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
- {
- return 0;
- }
-
- // Validate DataLength to prevent OOM from malicious/corrupt peers
- static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
-
- if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
- {
- ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
- return 0;
- }
-
- // Receive ciphertext + tag
- const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
-
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
-
- if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
- {
- return 0;
- }
-
- // Decrypt
- const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size);
-
- // We need a temporary buffer for decryption if we can't decrypt directly into output
- std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
-
- const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
- if (Decrypted == 0)
- {
- return 0;
- }
-
- memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+ if (ToCopy == Size)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); });
+ return;
+ }
- // Store remaining data if we couldn't return everything
- if (static_cast<size_t>(Header.DataLength) > BytesToReturn)
- {
- m_RemainingOffset = 0;
- m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength);
+ DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler));
+ return;
}
- return BytesToReturn;
+ DoRecvMessage(Dest, Size, std::move(Handler));
}
void
-AesComputeTransport::MarkComplete()
+AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler)
{
- if (IsValid())
- {
- m_Inner->MarkComplete();
- }
+ static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes;
+ auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>();
+
+ m_Inner->AsyncRead(HeaderBuf->data(),
+ HeaderSize,
+ [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ int32_t DataLength = 0;
+ memcpy(&DataLength, HeaderBuf->data(), 4);
+
+ static constexpr int32_t MaxDataLength = 64 * 1024 * 1024;
+ if (DataLength <= 0 || DataLength > MaxDataLength)
+ {
+ Handler(asio::error::make_error_code(asio::error::invalid_argument), 0);
+ return;
+ }
+
+ const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes;
+ if (m_DecryptBuffer.size() < MessageLength)
+ {
+ m_DecryptBuffer.resize(MessageLength);
+ }
+
+ auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>();
+ memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes);
+
+ m_Inner->AsyncRead(
+ m_DecryptBuffer.data(),
+ MessageLength,
+ [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec,
+ size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength));
+ const int32_t Decrypted =
+ m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength);
+ if (Decrypted == 0)
+ {
+ Handler(asio::error::make_error_code(asio::error::connection_aborted), 0);
+ return;
+ }
+
+ const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size);
+ memcpy(Dest, PlaintextBuf.data(), BytesToReturn);
+
+ if (static_cast<size_t>(Decrypted) > BytesToReturn)
+ {
+ m_RemainingOffset = 0;
+ m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted);
+ }
+
+ if (BytesToReturn < Size)
+ {
+ DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler));
+ }
+ else
+ {
+ Handler(std::error_code{}, Size);
+ }
+ });
+ });
}
void
-AesComputeTransport::Close()
+AsyncAesComputeTransport::Close()
{
if (!m_IsClosed)
{
- if (m_Inner && m_Inner->IsValid())
+ // Always forward Close() to the inner transport if we have one. Gating on
+ // IsValid() skipped cleanup when the inner transport was partially torn down
+ // (e.g. after a read/write error marked it non-valid but left its socket open),
+ // leaking OS handles. Close implementations are expected to be idempotent.
+ if (m_Inner)
{
m_Inner->Close();
}