// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #ifndef ZEN_USE_OPENSSL # if ZEN_PLATFORM_WINDOWS # define ZEN_USE_OPENSSL 0 # else # define ZEN_USE_OPENSSL 1 # endif #endif ZEN_THIRD_PARTY_INCLUDES_START #include #if ZEN_USE_OPENSSL # include # include # include #else # include # include # define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0) # define STATUS_UNSUCCESSFUL ((NTSTATUS)0xC0000001L) #endif ZEN_THIRD_PARTY_INCLUDES_END namespace zen { using namespace std::literals; namespace crypto { enum class TransformMode : uint32_t { Decrypt, Encrypt }; #if ZEN_USE_OPENSSL class EvpContext { public: EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {} ~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); } operator EVP_CIPHER_CTX*() { return m_Ctx; } private: EVP_CIPHER_CTX* m_Ctx; }; MemoryView Transform(TransformMode Mode, MemoryView Key, MemoryView IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { const EVP_CIPHER* Cipher = EVP_aes_256_cbc(); ZEN_ASSERT(Cipher != nullptr); EvpContext Ctx; int Err = EVP_CipherInit_ex(Ctx, Cipher, nullptr, reinterpret_cast(Key.GetData()), reinterpret_cast(IV.GetData()), static_cast(Mode)); if (Err != 1) { Reason = fmt::format("failed to initialize cipher, error code '{}'", Err); return MemoryView(); } int EncryptedBytes = 0; int TotalEncryptedBytes = 0; Err = EVP_CipherUpdate(Ctx, reinterpret_cast(Out.GetData()), &EncryptedBytes, reinterpret_cast(In.GetData()), static_cast(In.GetSize())); if (Err != 1) { Reason = fmt::format("update crypto transform failed, error code '{}'", Err); return MemoryView(); } TotalEncryptedBytes = EncryptedBytes; MutableMemoryView Remaining = Out.RightChop(EncryptedBytes); EncryptedBytes = static_cast(Remaining.GetSize()); Err = EVP_CipherFinal(Ctx, reinterpret_cast(Remaining.GetData()), &EncryptedBytes); if (Err != 1) { Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err); return MemoryView(); } TotalEncryptedBytes += EncryptedBytes; return Out.Left(TotalEncryptedBytes); } #else MemoryView Transform(TransformMode Mode, MemoryView Key, MemoryView IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { BCRYPT_ALG_HANDLE hAesAlg = NULL; NTSTATUS Status = STATUS_UNSUCCESSFUL; // Open an algorithm handle. if (!NT_SUCCESS(Status = BCryptOpenAlgorithmProvider(&hAesAlg, BCRYPT_AES_ALGORITHM, NULL, 0))) { Reason = fmt::format("Error 0x{:08x} returned by BCryptGetProperty"sv, Status); return {}; } auto _ = MakeGuard([hAesAlg] { BCryptCloseAlgorithmProvider(hAesAlg, 0); }); DWORD cbData = 0; DWORD cbBlockLen = 0; if (!NT_SUCCESS(Status = BCryptGetProperty(hAesAlg, BCRYPT_BLOCK_LENGTH, (PBYTE)&cbBlockLen, sizeof(DWORD), &cbData, 0))) { Reason = fmt::format("Error 0x{:08x} returned by BCryptGetProperty"sv, Status); return {}; } if (cbBlockLen > IV.GetSize()) { Reason = "block length is longer than the provided IV length"sv; return {}; } AesIV128Bit MutableIV = AesIV128Bit::FromMemoryView(IV); if (!NT_SUCCESS( Status = BCryptSetProperty(hAesAlg, BCRYPT_CHAINING_MODE, (PBYTE)BCRYPT_CHAIN_MODE_CBC, sizeof(BCRYPT_CHAIN_MODE_CBC), 0))) { Reason = fmt::format("Error 0x{:08x} returned by BCryptSetProperty"sv, Status); return {}; } DWORD cbKeyObject = 0; if (!NT_SUCCESS(Status = BCryptGetProperty(hAesAlg, BCRYPT_OBJECT_LENGTH, (PBYTE)&cbKeyObject, sizeof(DWORD), &cbData, 0))) { Reason = fmt::format("Error 0x{:08x} returned by BCryptGetProperty"sv, Status); return {}; } PBYTE pbKeyObject = (PBYTE)Memory::Alloc(cbKeyObject); if (NULL == pbKeyObject) { Reason = fmt::format("memory allocation failed"); return {}; } auto __ = MakeGuard([pbKeyObject] { Memory::Free(pbKeyObject); }); BCRYPT_KEY_HANDLE hKey = NULL; if (!NT_SUCCESS(Status = BCryptGenerateSymmetricKey(hAesAlg, &hKey, pbKeyObject, cbKeyObject, (PBYTE)Key.GetData(), (ULONG)Key.GetSize(), /* flags */ 0))) { Reason = fmt::format("Error 0x{:08x} returned by BCryptGenerateSymmetricKey"sv, Status); return {}; } auto ___ = MakeGuard([hKey] { BCryptDestroyKey(hKey); }); if (Mode == TransformMode::Encrypt) { DWORD CipherTextByteCount = 0; if (NT_SUCCESS(Status = BCryptEncrypt(hKey, (PUCHAR)In.GetData(), (ULONG)In.GetSize(), NULL, (PUCHAR)MutableIV.GetView().GetData(), cbBlockLen, NULL, 0, &CipherTextByteCount, BCRYPT_BLOCK_PADDING))) { if (Out.GetSize() < CipherTextByteCount) { Reason = "invalid output buffer size"; return {}; } if (NT_SUCCESS(Status = BCryptEncrypt(hKey, (PUCHAR)In.GetData(), (ULONG)In.GetSize(), NULL, (PUCHAR)MutableIV.GetView().GetData(), cbBlockLen, (PUCHAR)Out.GetData(), (ULONG)Out.GetSize(), &CipherTextByteCount, BCRYPT_BLOCK_PADDING))) { return Out.Left(CipherTextByteCount); } } Reason = fmt::format("Error 0x{:08x} returned by BCryptEncrypt", Status); return {}; } else { DWORD PlainTextByteCount = 0; // // Get the output buffer size. // if (NT_SUCCESS(Status = BCryptDecrypt(hKey, (PUCHAR)In.GetData(), (ULONG)In.GetSize(), NULL, (PUCHAR)MutableIV.GetView().GetData(), cbBlockLen, NULL, 0, &PlainTextByteCount, BCRYPT_BLOCK_PADDING))) { if (Out.GetSize() < PlainTextByteCount) { Reason = "invalid output buffer size"sv; return {}; } if (NT_SUCCESS(Status = BCryptDecrypt(hKey, (PUCHAR)In.GetData(), (ULONG)In.GetSize(), NULL, (PUCHAR)MutableIV.GetView().GetData(), cbBlockLen, (PUCHAR)Out.GetData(), (ULONG)Out.GetSize(), &PlainTextByteCount, BCRYPT_BLOCK_PADDING))) { return Out.Left(PlainTextByteCount); } } Reason = fmt::format("Error 0x{:08x} returned by BCryptDecrypt"sv, Status); return {}; } } #endif bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional& Reason) { if (Key.IsValid() == false) { Reason = "invalid key"sv; return false; } if (IV.IsValid() == false) { Reason = "invalid initialization vector"sv; return false; } return true; } } // namespace crypto MemoryView Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) { return MemoryView(); } return crypto::Transform(crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } MemoryView Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional& Reason) { if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) { return MemoryView(); } return crypto::Transform(crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } #if ZEN_WITH_TESTS void crypto_forcelink() { } TEST_CASE("crypto.bits") { using CryptoBits256Bit = CryptoBits<256>; CryptoBits256Bit Bits; CHECK(Bits.IsNull()); CHECK(Bits.IsValid() == false); CHECK(Bits.GetBitCount() == 256); CHECK(Bits.GetSize() == 32); Bits = CryptoBits256Bit::FromString("Addff"sv); CHECK(Bits.IsValid() == false); Bits = CryptoBits256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); CHECK(Bits.IsValid()); auto SmallerBits = CryptoBits<128>::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); CHECK(SmallerBits.IsValid() == false); } TEST_CASE("crypto.aes") { SUBCASE("basic") { const uint8_t InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); const AesIV128Bit IV = AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector)); std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; std::vector EncryptionBuffer; std::vector DecryptionBuffer; std::optional Reason; EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize); DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize); MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason); MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason); std::string_view EncryptedDecryptedText = std::string_view(reinterpret_cast(DecryptedView.GetData()), DecryptedView.GetSize()); CHECK(EncryptedDecryptedText == PlainText); } } #endif } // namespace zen