// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # pragma comment(lib, "crypt32.lib") # pragma comment(lib, "ws2_32.lib") #endif namespace zen { using namespace std::literals; namespace crypto { class EvpContext { public: EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {} ~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); } operator EVP_CIPHER_CTX*() { return m_Ctx; } private: EVP_CIPHER_CTX* m_Ctx; }; enum class TransformMode : uint32_t { Decrypt, Encrypt }; MemoryView Transform(const EVP_CIPHER* Cipher, TransformMode Mode, MemoryView Key, MemoryView IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { ZEN_ASSERT(Cipher != nullptr); EvpContext Ctx; int Err = EVP_CipherInit_ex(Ctx, Cipher, nullptr, reinterpret_cast(Key.GetData()), reinterpret_cast(IV.GetData()), static_cast(Mode)); if (Err != 1) { Reason = fmt::format("failed to initialize cipher, error code '{}'", Err); return MemoryView(); } int EncryptedBytes = 0; int TotalEncryptedBytes = 0; Err = EVP_CipherUpdate(Ctx, reinterpret_cast(Out.GetData()), &EncryptedBytes, reinterpret_cast(In.GetData()), static_cast(In.GetSize())); if (Err != 1) { Reason = fmt::format("update crypto transform failed, error code '{}'", Err); return MemoryView(); } TotalEncryptedBytes = EncryptedBytes; MutableMemoryView Remaining = Out.RightChop(EncryptedBytes); EncryptedBytes = static_cast(Remaining.GetSize()); Err = EVP_CipherFinal(Ctx, reinterpret_cast(Remaining.GetData()), &EncryptedBytes); if (Err != 1) { Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err); return MemoryView(); } TotalEncryptedBytes += EncryptedBytes; return Out.Left(TotalEncryptedBytes); } bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional& Reason) { if (Key.IsValid() == false) { Reason = "invalid key"sv; return false; } if (IV.IsValid() == false) { Reason = "invalid initialization vector"sv; return false; } return true; } } // namespace crypto MemoryView Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) { return MemoryView(); } return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } MemoryView Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) { return MemoryView(); } return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } #if ZEN_WITH_TESTS void crypto_forcelink() { } TEST_CASE("crypto.bits") { using CryptoBits256Bit = CryptoBits<256>; CryptoBits256Bit Bits; CHECK(Bits.IsNull()); CHECK(Bits.IsValid() == false); CHECK(Bits.GetBitCount() == 256); CHECK(Bits.GetSize() == 32); Bits = CryptoBits256Bit::FromString("Addff"sv); CHECK(Bits.IsValid() == false); Bits = CryptoBits256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); CHECK(Bits.IsValid()); auto SmallerBits = CryptoBits<128>::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); CHECK(SmallerBits.IsValid() == false); } TEST_CASE("crypto.aes") { SUBCASE("basic") { const uint8_t InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); const AesIV128Bit IV = AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector)); std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; std::vector EncryptionBuffer; std::vector DecryptionBuffer; std::optional Reason; EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize); DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize); MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason); MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason); std::string_view EncryptedDecryptedText = std::string_view(reinterpret_cast(DecryptedView.GetData()), DecryptedView.GetSize()); CHECK(EncryptedDecryptedText == PlainText); } } #endif } // namespace zen