aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordetransportaes.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordetransportaes.cpp')
-rw-r--r--src/zenhorde/hordetransportaes.cpp609
1 files changed, 318 insertions, 291 deletions
diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp
index 505b6bde7..c71866e8c 100644
--- a/src/zenhorde/hordetransportaes.cpp
+++ b/src/zenhorde/hordetransportaes.cpp
@@ -5,6 +5,10 @@
#include <zencore/logging.h>
#include <zencore/trace.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
#include <algorithm>
#include <cstring>
#include <random>
@@ -22,274 +26,281 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen::horde {
-struct AesComputeTransport::CryptoContext
-{
- uint8_t Key[KeySize] = {};
- uint8_t EncryptNonce[NonceBytes] = {};
- uint8_t DecryptNonce[NonceBytes] = {};
- bool HasErrors = false;
+namespace {
+
+ static constexpr size_t AesNonceBytes = 12;
+ static constexpr size_t AesTagBytes = 16;
+
+ /** AES-256-GCM crypto context. Not exposed outside this translation unit. */
+ struct AesCryptoContext
+ {
+ static constexpr size_t NonceBytes = AesNonceBytes;
+ static constexpr size_t TagBytes = AesTagBytes;
+
+ uint8_t Key[KeySize] = {};
+ uint8_t EncryptNonce[NonceBytes] = {};
+ uint8_t DecryptNonce[NonceBytes] = {};
+ bool HasErrors = false;
#if !ZEN_PLATFORM_WINDOWS
- EVP_CIPHER_CTX* EncCtx = nullptr;
- EVP_CIPHER_CTX* DecCtx = nullptr;
+ EVP_CIPHER_CTX* EncCtx = nullptr;
+ EVP_CIPHER_CTX* DecCtx = nullptr;
#endif
- CryptoContext(const uint8_t (&InKey)[KeySize])
- {
- memcpy(Key, InKey, KeySize);
-
- // The encrypt nonce is randomly initialized and then deterministically mutated
- // per message via UpdateNonce(). The decrypt nonce is not used — it comes from
- // the wire (each received message carries its own nonce in the header).
- std::random_device Rd;
- std::mt19937 Gen(Rd());
- std::uniform_int_distribution<int> Dist(0, 255);
- for (auto& Byte : EncryptNonce)
+ AesCryptoContext(const uint8_t (&InKey)[KeySize])
{
- Byte = static_cast<uint8_t>(Dist(Gen));
- }
+ memcpy(Key, InKey, KeySize);
+
+ std::random_device Rd;
+ std::mt19937 Gen(Rd());
+ std::uniform_int_distribution<int> Dist(0, 255);
+ for (auto& Byte : EncryptNonce)
+ {
+ Byte = static_cast<uint8_t>(Dist(Gen));
+ }
#if !ZEN_PLATFORM_WINDOWS
- // Drain any stale OpenSSL errors
- while (ERR_get_error() != 0)
- {
- }
+ while (ERR_get_error() != 0)
+ {
+ }
- EncCtx = EVP_CIPHER_CTX_new();
- EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
+ EncCtx = EVP_CIPHER_CTX_new();
+ EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
- DecCtx = EVP_CIPHER_CTX_new();
- EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
+ DecCtx = EVP_CIPHER_CTX_new();
+ EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr);
#endif
- }
+ }
- ~CryptoContext()
- {
+ ~AesCryptoContext()
+ {
#if ZEN_PLATFORM_WINDOWS
- SecureZeroMemory(Key, sizeof(Key));
- SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
- SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
+ SecureZeroMemory(Key, sizeof(Key));
+ SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce));
+ SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce));
#else
- OPENSSL_cleanse(Key, sizeof(Key));
- OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
- OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
-
- if (EncCtx)
- {
- EVP_CIPHER_CTX_free(EncCtx);
+ OPENSSL_cleanse(Key, sizeof(Key));
+ OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce));
+ OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce));
+
+ if (EncCtx)
+ {
+ EVP_CIPHER_CTX_free(EncCtx);
+ }
+ if (DecCtx)
+ {
+ EVP_CIPHER_CTX_free(DecCtx);
+ }
+#endif
}
- if (DecCtx)
+
+ void UpdateNonce()
{
- EVP_CIPHER_CTX_free(DecCtx);
+ uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
+ N32[0]++;
+ N32[1]--;
+ N32[2] = N32[0] ^ N32[1];
}
-#endif
- }
-
- void UpdateNonce()
- {
- uint32_t* N32 = reinterpret_cast<uint32_t*>(EncryptNonce);
- N32[0]++;
- N32[1]--;
- N32[2] = N32[0] ^ N32[1];
- }
- // Returns total encrypted message size, or 0 on failure
- // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)]
- int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
- {
- UpdateNonce();
+ int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength)
+ {
+ UpdateNonce();
- // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than
- // caching but has some overhead. For our use case (relatively large, infrequent messages)
- // this is acceptable.
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = EncryptNonce;
- AuthInfo.cbNonce = NonceBytes;
- uint8_t Tag[TagBytes] = {};
- AuthInfo.pbTag = Tag;
- AuthInfo.cbTag = TagBytes;
-
- ULONG CipherLen = 0;
- NTSTATUS Status =
- BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+
+ BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = EncryptNonce;
+ AuthInfo.cbNonce = NonceBytes;
+ uint8_t Tag[TagBytes] = {};
+ AuthInfo.pbTag = Tag;
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG CipherLen = 0;
+ NTSTATUS Status = BCryptEncrypt(hKey,
+ (PUCHAR)In,
+ (ULONG)InLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ Out + 4 + NonceBytes,
+ (ULONG)InLength,
+ &CipherLen,
+ 0);
+
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ HasErrors = true;
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
+ return 0;
+ }
+
+ memcpy(Out, &InLength, 4);
+ memcpy(Out + 4, EncryptNonce, NonceBytes);
+ memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
+
BCryptDestroyKey(hKey);
BCryptCloseAlgorithmProvider(hAlg, 0);
- return 0;
- }
-
- // Write header: length + nonce
- memcpy(Out, &InLength, 4);
- memcpy(Out + 4, EncryptNonce, NonceBytes);
- // Write tag after ciphertext
- memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
- return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
+ return 4 + NonceBytes + static_cast<int32_t>(CipherLen) + TagBytes;
#else
- if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
-
- int32_t Offset = 0;
- // Write length
- memcpy(Out + Offset, &InLength, 4);
- Offset += 4;
- // Write nonce
- memcpy(Out + Offset, EncryptNonce, NonceBytes);
- Offset += NonceBytes;
-
- // Encrypt
- int OutLen = 0;
- if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
- {
- HasErrors = true;
- return 0;
- }
- Offset += OutLen;
-
- // Finalize
- int FinalLen = 0;
- if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
+ if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int32_t Offset = 0;
+ memcpy(Out + Offset, &InLength, 4);
+ Offset += 4;
+ memcpy(Out + Offset, EncryptNonce, NonceBytes);
+ Offset += NonceBytes;
+
+ int OutLen = 0;
+ if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast<const uint8_t*>(In), InLength) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += OutLen;
+
+ int FinalLen = 0;
+ if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += FinalLen;
+
+ if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+ Offset += TagBytes;
+
+ return Offset;
+#endif
}
- Offset += FinalLen;
- // Get tag
- if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + Offset) != 1)
+ int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
{
- HasErrors = true;
- return 0;
- }
- Offset += TagBytes;
-
- return Offset;
-#endif
- }
-
- // Decrypt a message. Returns decrypted data length, or 0 on failure.
- // Input must be [ciphertext][tag], with nonce provided separately.
- int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength)
- {
#if ZEN_PLATFORM_WINDOWS
- BCRYPT_ALG_HANDLE hAlg = nullptr;
- BCRYPT_KEY_HANDLE hKey = nullptr;
-
- BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
- BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
- BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
-
- BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
- BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
- AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
- AuthInfo.cbNonce = NonceBytes;
- AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
- AuthInfo.cbTag = TagBytes;
-
- ULONG PlainLen = 0;
- NTSTATUS Status = BCryptDecrypt(hKey,
- (PUCHAR)CipherAndTag,
- (ULONG)DataLength,
- &AuthInfo,
- nullptr,
- 0,
- (PUCHAR)Out,
- (ULONG)DataLength,
- &PlainLen,
- 0);
-
- BCryptDestroyKey(hKey);
- BCryptCloseAlgorithmProvider(hAlg, 0);
-
- if (!BCRYPT_SUCCESS(Status))
- {
- HasErrors = true;
- return 0;
- }
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+
+ BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0);
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
+ BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0);
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = const_cast<uint8_t*>(Nonce);
+ AuthInfo.cbNonce = NonceBytes;
+ AuthInfo.pbTag = const_cast<uint8_t*>(CipherAndTag + DataLength);
+ AuthInfo.cbTag = TagBytes;
+
+ ULONG PlainLen = 0;
+ NTSTATUS Status = BCryptDecrypt(hKey,
+ (PUCHAR)CipherAndTag,
+ (ULONG)DataLength,
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out,
+ (ULONG)DataLength,
+ &PlainLen,
+ 0);
- return static_cast<int32_t>(PlainLen);
-#else
- if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
- {
- HasErrors = true;
- return 0;
- }
+ BCryptDestroyKey(hKey);
+ BCryptCloseAlgorithmProvider(hAlg, 0);
- int OutLen = 0;
- if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
- {
- HasErrors = true;
- return 0;
- }
+ if (!BCRYPT_SUCCESS(Status))
+ {
+ HasErrors = true;
+ return 0;
+ }
- // Set the tag for verification
- if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
- {
- HasErrors = true;
- return 0;
+ return static_cast<int32_t>(PlainLen);
+#else
+ if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int OutLen = 0;
+ if (EVP_DecryptUpdate(DecCtx, static_cast<uint8_t*>(Out), &OutLen, CipherAndTag, DataLength) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast<uint8_t*>(CipherAndTag + DataLength)) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ int FinalLen = 0;
+ if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
+ {
+ HasErrors = true;
+ return 0;
+ }
+
+ return OutLen + FinalLen;
+#endif
}
+ };
- int FinalLen = 0;
- if (EVP_DecryptFinal_ex(DecCtx, static_cast<uint8_t*>(Out) + OutLen, &FinalLen) != 1)
- {
- HasErrors = true;
- return 0;
- }
+} // anonymous namespace
- return OutLen + FinalLen;
-#endif
- }
+struct AsyncAesComputeTransport::CryptoContext : AesCryptoContext
+{
+ using AesCryptoContext::AesCryptoContext;
};
-AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr<ComputeTransport> InnerTransport)
+// --- AsyncAesComputeTransport ---
+
+AsyncAesComputeTransport::AsyncAesComputeTransport(const uint8_t (&Key)[KeySize],
+ std::unique_ptr<AsyncComputeTransport> InnerTransport,
+ asio::io_context& IoContext)
: m_Crypto(std::make_unique<CryptoContext>(Key))
, m_Inner(std::move(InnerTransport))
+, m_IoContext(IoContext)
{
}
-AesComputeTransport::~AesComputeTransport()
+AsyncAesComputeTransport::~AsyncAesComputeTransport()
{
Close();
}
bool
-AesComputeTransport::IsValid() const
+AsyncAesComputeTransport::IsValid() const
{
return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed;
}
-size_t
-AesComputeTransport::Send(const void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
- ZEN_TRACE_CPU("AesComputeTransport::Send");
-
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- std::lock_guard<std::mutex> Lock(m_Lock);
-
const int32_t DataLength = static_cast<int32_t>(Size);
- const size_t MessageLength = 4 + NonceBytes + Size + TagBytes;
+ const size_t MessageLength = 4 + CryptoContext::NonceBytes + Size + CryptoContext::TagBytes;
if (m_EncryptBuffer.size() < MessageLength)
{
@@ -299,38 +310,36 @@ AesComputeTransport::Send(const void* Data, size_t Size)
const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength);
if (EncryptedLen == 0)
{
- return 0;
+ asio::post(m_IoContext,
+ [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::connection_aborted), 0); });
+ return;
}
- if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast<size_t>(EncryptedLen)))
- {
- return 0;
- }
+ auto EncBuf = std::make_shared<std::vector<uint8_t>>(m_EncryptBuffer.begin(), m_EncryptBuffer.begin() + EncryptedLen);
- return Size;
+ m_Inner->AsyncWrite(
+ EncBuf->data(),
+ EncBuf->size(),
+ [Handler = std::move(Handler), EncBuf, Size](const std::error_code& Ec, size_t /*BytesWritten*/) { Handler(Ec, Ec ? 0 : Size); });
}
-size_t
-AesComputeTransport::Recv(void* Data, size_t Size)
+void
+AsyncAesComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
if (!IsValid())
{
- return 0;
+ asio::post(m_IoContext, [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
+ return;
}
- // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes
- // than the decrypted message contains. Excess bytes are buffered in m_RemainingData
- // and returned on subsequent Recv calls without another decryption round-trip.
- ZEN_TRACE_CPU("AesComputeTransport::Recv");
-
- std::lock_guard<std::mutex> Lock(m_Lock);
+ uint8_t* Dest = static_cast<uint8_t*>(Data);
if (!m_RemainingData.empty())
{
const size_t Available = m_RemainingData.size() - m_RemainingOffset;
const size_t ToCopy = std::min(Available, Size);
- memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy);
+ memcpy(Dest, m_RemainingData.data() + m_RemainingOffset, ToCopy);
m_RemainingOffset += ToCopy;
if (m_RemainingOffset >= m_RemainingData.size())
@@ -339,78 +348,96 @@ AesComputeTransport::Recv(void* Data, size_t Size)
m_RemainingOffset = 0;
}
- return ToCopy;
- }
-
- // Receive packet header: [length(4B)][nonce(12B)]
- struct PacketHeader
- {
- int32_t DataLength = 0;
- uint8_t Nonce[NonceBytes] = {};
- } Header;
-
- if (!m_Inner->RecvMessage(&Header, sizeof(Header)))
- {
- return 0;
- }
-
- // Validate DataLength to prevent OOM from malicious/corrupt peers
- static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB
-
- if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength)
- {
- ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength);
- return 0;
- }
-
- // Receive ciphertext + tag
- const size_t MessageLength = static_cast<size_t>(Header.DataLength) + TagBytes;
-
- if (m_EncryptBuffer.size() < MessageLength)
- {
- m_EncryptBuffer.resize(MessageLength);
- }
-
- if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength))
- {
- return 0;
- }
-
- // Decrypt
- const size_t BytesToReturn = std::min(static_cast<size_t>(Header.DataLength), Size);
-
- // We need a temporary buffer for decryption if we can't decrypt directly into output
- std::vector<uint8_t> DecryptedBuf(static_cast<size_t>(Header.DataLength));
-
- const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength);
- if (Decrypted == 0)
- {
- return 0;
- }
-
- memcpy(Data, DecryptedBuf.data(), BytesToReturn);
+ if (ToCopy == Size)
+ {
+ asio::post(m_IoContext, [Handler = std::move(Handler), Size] { Handler(std::error_code{}, Size); });
+ return;
+ }
- // Store remaining data if we couldn't return everything
- if (static_cast<size_t>(Header.DataLength) > BytesToReturn)
- {
- m_RemainingOffset = 0;
- m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, DecryptedBuf.begin() + Header.DataLength);
+ DoRecvMessage(Dest + ToCopy, Size - ToCopy, std::move(Handler));
+ return;
}
- return BytesToReturn;
+ DoRecvMessage(Dest, Size, std::move(Handler));
}
void
-AesComputeTransport::MarkComplete()
+AsyncAesComputeTransport::DoRecvMessage(uint8_t* Dest, size_t Size, AsyncIoHandler Handler)
{
- if (IsValid())
- {
- m_Inner->MarkComplete();
- }
+ static constexpr size_t HeaderSize = 4 + CryptoContext::NonceBytes;
+ auto HeaderBuf = std::make_shared<std::array<uint8_t, 4 + 12>>();
+
+ m_Inner->AsyncRead(HeaderBuf->data(),
+ HeaderSize,
+ [this, Dest, Size, Handler = std::move(Handler), HeaderBuf](const std::error_code& Ec, size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ int32_t DataLength = 0;
+ memcpy(&DataLength, HeaderBuf->data(), 4);
+
+ static constexpr int32_t MaxDataLength = 64 * 1024 * 1024;
+ if (DataLength <= 0 || DataLength > MaxDataLength)
+ {
+ Handler(asio::error::make_error_code(asio::error::invalid_argument), 0);
+ return;
+ }
+
+ const size_t MessageLength = static_cast<size_t>(DataLength) + CryptoContext::TagBytes;
+ if (m_DecryptBuffer.size() < MessageLength)
+ {
+ m_DecryptBuffer.resize(MessageLength);
+ }
+
+ auto NonceBuf = std::make_shared<std::array<uint8_t, CryptoContext::NonceBytes>>();
+ memcpy(NonceBuf->data(), HeaderBuf->data() + 4, CryptoContext::NonceBytes);
+
+ m_Inner->AsyncRead(
+ m_DecryptBuffer.data(),
+ MessageLength,
+ [this, Dest, Size, Handler = std::move(Handler), DataLength, NonceBuf](const std::error_code& Ec,
+ size_t /*Bytes*/) mutable {
+ if (Ec)
+ {
+ Handler(Ec, 0);
+ return;
+ }
+
+ std::vector<uint8_t> PlaintextBuf(static_cast<size_t>(DataLength));
+ const int32_t Decrypted =
+ m_Crypto->DecryptMessage(PlaintextBuf.data(), NonceBuf->data(), m_DecryptBuffer.data(), DataLength);
+ if (Decrypted == 0)
+ {
+ Handler(asio::error::make_error_code(asio::error::connection_aborted), 0);
+ return;
+ }
+
+ const size_t BytesToReturn = std::min(static_cast<size_t>(Decrypted), Size);
+ memcpy(Dest, PlaintextBuf.data(), BytesToReturn);
+
+ if (static_cast<size_t>(Decrypted) > BytesToReturn)
+ {
+ m_RemainingOffset = 0;
+ m_RemainingData.assign(PlaintextBuf.begin() + BytesToReturn, PlaintextBuf.begin() + Decrypted);
+ }
+
+ if (BytesToReturn < Size)
+ {
+ DoRecvMessage(Dest + BytesToReturn, Size - BytesToReturn, std::move(Handler));
+ }
+ else
+ {
+ Handler(std::error_code{}, Size);
+ }
+ });
+ });
}
void
-AesComputeTransport::Close()
+AsyncAesComputeTransport::Close()
{
if (!m_IsClosed)
{