// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #if ZEN_PLATFORM_WINDOWS # pragma comment(lib, "crypt32.lib") # pragma comment(lib, "ws2_32.lib") #endif namespace zen { class NullCipher final : public SymmetricCipher { public: NullCipher() = default; virtual ~NullCipher() = default; virtual bool Initialize(MemoryView, MemoryView) override final { return true; } virtual CipherSettings Settings() override final { return {}; } virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView) override final { return Data; } virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView) override final { return Data; } }; std::unique_ptr MakeNullCipher() { return std::make_unique(); } #if ZEN_PLATFORM_WINDOWS class Aes final : public SymmetricCipher { public: Aes(const EVP_CIPHER* Cipher = EVP_aes_256_cbc()) : m_Cipher(Cipher) { ZEN_ASSERT(Cipher); m_KeySize = static_cast(EVP_CIPHER_key_length(m_Cipher)); m_InitVectorSize = static_cast(EVP_CIPHER_iv_length(m_Cipher)); m_BlockSize = static_cast(EVP_CIPHER_block_size(m_Cipher)); } virtual ~Aes() { if (m_EncryptionCtx) { EVP_CIPHER_CTX_free(m_EncryptionCtx); } if (m_DecryptionCtx) { EVP_CIPHER_CTX_free(m_DecryptionCtx); } } virtual bool Initialize(MemoryView Key, MemoryView InitVector) override final { ZEN_ASSERT(m_EncryptionCtx == nullptr && m_DecryptionCtx == nullptr); ZEN_ASSERT(Key.GetSize() == m_KeySize); ZEN_ASSERT(InitVector.GetSize() == m_InitVectorSize); m_EncryptionCtx = EVP_CIPHER_CTX_new(); m_DecryptionCtx = EVP_CIPHER_CTX_new(); if (int ErrorCode = EVP_EncryptInit_ex(m_EncryptionCtx, m_Cipher, nullptr, reinterpret_cast(Key.GetData()), reinterpret_cast(InitVector.GetData())); ErrorCode != 1) { return false; } if (int ErrorCode = EVP_DecryptInit_ex(m_DecryptionCtx, m_Cipher, nullptr, reinterpret_cast(Key.GetData()), reinterpret_cast(InitVector.GetData())); ErrorCode != 1) { return false; } return true; } virtual CipherSettings Settings() override final { return {.KeySize = m_KeySize, .InitVectorSize = m_InitVectorSize, .BlockSize = m_BlockSize}; } virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) { ZEN_ASSERT(m_EncryptionCtx); const uint64_t InputSize = Data.GetSize(); const uint64_t NeededSize = RoundUp(InputSize, m_BlockSize); if (NeededSize > EncryptionBuffer.GetSize()) { return MemoryView(); } int TotalSize = 0; int EncryptedSize = 0; int ErrorCode = EVP_EncryptUpdate(m_EncryptionCtx, reinterpret_cast(EncryptionBuffer.GetData()), &EncryptedSize, reinterpret_cast(Data.GetData()), static_cast(Data.GetSize())); if (ErrorCode != 1) { return MemoryView(); } TotalSize = EncryptedSize; MutableMemoryView Remaining = EncryptionBuffer.RightChop(uint64_t(EncryptedSize)); ErrorCode = EVP_EncryptFinal_ex(m_EncryptionCtx, reinterpret_cast(Remaining.GetData()), &EncryptedSize); if (ErrorCode != 1) { return MemoryView(); } TotalSize += EncryptedSize; return EncryptionBuffer.Left(uint64_t(TotalSize)); } virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) override final { ZEN_ASSERT(m_DecryptionCtx); int TotalSize = 0; int DecryptedSize = 0; int ErrorCode = EVP_DecryptUpdate(m_DecryptionCtx, reinterpret_cast(DecryptionBuffer.GetData()), &DecryptedSize, reinterpret_cast(Data.GetData()), static_cast(Data.GetSize())); if (ErrorCode != 1) { return MemoryView(); } TotalSize = DecryptedSize; MutableMemoryView Remaining = DecryptionBuffer.RightChop(uint64_t(DecryptedSize)); ErrorCode = EVP_DecryptFinal_ex(m_DecryptionCtx, reinterpret_cast(Remaining.GetData()), &DecryptedSize); TotalSize += DecryptedSize; return DecryptionBuffer.Left(uint64_t(TotalSize)); } private: const EVP_CIPHER* m_Cipher = nullptr; EVP_CIPHER_CTX* m_EncryptionCtx = nullptr; EVP_CIPHER_CTX* m_DecryptionCtx = nullptr; size_t m_BlockSize = 0; size_t m_KeySize = 0; size_t m_InitVectorSize = 0; }; std::unique_ptr MakeAesCipher() { return std::make_unique(); } #endif // ZEN_PLATFORM_WINDOWS #if ZEN_WITH_TESTS using namespace std::literals; void crypto_forcelink() { } TEST_CASE("crypto.aes") { SUBCASE("basic") { # if ZEN_PLATFORM_WINDOWS auto Cipher = std::make_unique(); std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; std::vector Key = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; std::vector Seed = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; std::vector EncryptionBuffer; std::vector DecryptionBuffer; bool Ok = Cipher->Initialize(MakeMemoryView(Key), MakeMemoryView(Seed)); CHECK(Ok); EncryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); DecryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); MemoryView EncryptedView = Cipher->Encrypt(MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer)); CHECK(EncryptedView.IsEmpty() == false); MemoryView DecryptedView = Cipher->Decrypt(EncryptedView, MakeMutableMemoryView(DecryptionBuffer)); CHECK(DecryptedView.IsEmpty() == false); std::string_view EncryptedDecryptedText = std::string_view(reinterpret_cast(DecryptedView.GetData()), DecryptedView.GetSize()); CHECK(EncryptedDecryptedText == PlainText); } # endif } #endif } // namespace zen