diff options
Diffstat (limited to 'zencore/crypto.cpp')
| -rw-r--r-- | zencore/crypto.cpp | 286 |
1 files changed, 136 insertions, 150 deletions
diff --git a/zencore/crypto.cpp b/zencore/crypto.cpp index 57820c278..0ad368f3f 100644 --- a/zencore/crypto.cpp +++ b/zencore/crypto.cpp @@ -4,12 +4,15 @@ #include <zencore/intmath.h> #include <zencore/testing.h> +#include <string> +#include <string_view> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> #include <openssl/conf.h> #include <openssl/err.h> #include <openssl/evp.h> - -#include <string> -#include <string_view> +ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # pragma comment(lib, "crypt32.lib") @@ -18,212 +21,195 @@ namespace zen { -class NoOpSymmetricCipher final : public SymmetricCipher -{ -public: - NoOpSymmetricCipher() = default; - - virtual ~NoOpSymmetricCipher() = default; - - virtual bool Initialize(MemoryView, MemoryView) override final { return true; } - - virtual CipherSettings Settings() override final { return {}; } - - virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView) override final { return Data; } - - virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView) override final { return Data; } -}; +using namespace std::literals; -std::unique_ptr<SymmetricCipher> -SymmetricCipher::CreateNoOp() -{ - return std::make_unique<NoOpSymmetricCipher>(); -} +namespace crypto { -class Aes final : public SymmetricCipher -{ -public: - Aes(const EVP_CIPHER* Cipher = EVP_aes_256_cbc()) : m_Cipher(Cipher) + class EvpContext { - ZEN_ASSERT(Cipher); - m_KeySize = static_cast<size_t>(EVP_CIPHER_key_length(m_Cipher)); - m_InitVectorSize = static_cast<size_t>(EVP_CIPHER_iv_length(m_Cipher)); - m_BlockSize = static_cast<size_t>(EVP_CIPHER_block_size(m_Cipher)); - } + public: + EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {} + ~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); } - virtual ~Aes() - { - if (m_EncryptionCtx) - { - EVP_CIPHER_CTX_free(m_EncryptionCtx); - } + operator EVP_CIPHER_CTX*() { return m_Ctx; } - if (m_DecryptionCtx) - { - EVP_CIPHER_CTX_free(m_DecryptionCtx); - } - } + private: + EVP_CIPHER_CTX* m_Ctx; + }; - virtual bool Initialize(MemoryView Key, MemoryView InitVector) override final + enum class TransformMode : uint32_t { - ZEN_ASSERT(m_EncryptionCtx == nullptr && m_DecryptionCtx == nullptr); - ZEN_ASSERT(Key.GetSize() == m_KeySize); - ZEN_ASSERT(InitVector.GetSize() == m_InitVectorSize); - - m_EncryptionCtx = EVP_CIPHER_CTX_new(); - m_DecryptionCtx = EVP_CIPHER_CTX_new(); - - if (int ErrorCode = EVP_EncryptInit_ex(m_EncryptionCtx, - m_Cipher, - nullptr, - reinterpret_cast<const unsigned char*>(Key.GetData()), - reinterpret_cast<const unsigned char*>(InitVector.GetData())); - ErrorCode != 1) - { - return false; - } - - if (int ErrorCode = EVP_DecryptInit_ex(m_DecryptionCtx, - m_Cipher, - nullptr, - reinterpret_cast<const unsigned char*>(Key.GetData()), - reinterpret_cast<const unsigned char*>(InitVector.GetData())); - ErrorCode != 1) - { - return false; - } - - return true; - } - - virtual CipherSettings Settings() override final + Decrypt, + Encrypt + }; + + MemoryView Transform(const EVP_CIPHER* Cipher, + TransformMode Mode, + MemoryView Key, + MemoryView IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason) { - return {.KeySize = m_KeySize, .InitVectorSize = m_InitVectorSize, .BlockSize = m_BlockSize}; - } + ZEN_ASSERT(Cipher != nullptr); - virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) override - { - ZEN_ASSERT(m_EncryptionCtx); + EvpContext Ctx; - const uint64_t InputSize = Data.GetSize(); - const uint64_t NeededSize = RoundUp(InputSize, m_BlockSize); + int Err = EVP_CipherInit_ex(Ctx, + Cipher, + nullptr, + reinterpret_cast<const unsigned char*>(Key.GetData()), + reinterpret_cast<const unsigned char*>(IV.GetData()), + static_cast<int>(Mode)); - if (NeededSize > EncryptionBuffer.GetSize()) + if (Err != 1) { + if (Reason) + { + Reason = fmt::format("failed to initialize cipher, error code '{}'", Err); + } + return MemoryView(); } - int TotalSize = 0; - int EncryptedSize = 0; - int ErrorCode = EVP_EncryptUpdate(m_EncryptionCtx, - reinterpret_cast<unsigned char*>(EncryptionBuffer.GetData()), - &EncryptedSize, - reinterpret_cast<const unsigned char*>(Data.GetData()), - static_cast<int>(Data.GetSize())); + int EncryptedBytes = 0; + int TotalEncryptedBytes = 0; - if (ErrorCode != 1) + Err = EVP_CipherUpdate(Ctx, + reinterpret_cast<unsigned char*>(Out.GetData()), + &EncryptedBytes, + reinterpret_cast<const unsigned char*>(In.GetData()), + static_cast<int>(In.GetSize())); + + if (Err != 1) { + if (Reason) + { + Reason = fmt::format("update crypto transform failed, error code '{}'", Err); + } + return MemoryView(); } - TotalSize = EncryptedSize; - MutableMemoryView Remaining = EncryptionBuffer.RightChop(uint64_t(EncryptedSize)); + TotalEncryptedBytes = EncryptedBytes; + MutableMemoryView Remaining = Out.RightChop(EncryptedBytes); + + EncryptedBytes = static_cast<int>(Remaining.GetSize()); - ErrorCode = EVP_EncryptFinal_ex(m_EncryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedSize); + Err = EVP_CipherFinal(Ctx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedBytes); - if (ErrorCode != 1) + if (Err != 1) { + if (Reason) + { + Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err); + } + return MemoryView(); } - TotalSize += EncryptedSize; + TotalEncryptedBytes += EncryptedBytes; - return EncryptionBuffer.Left(uint64_t(TotalSize)); + return Out.Left(TotalEncryptedBytes); } - virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) override final + bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional<std::string>& Reason) { - ZEN_ASSERT(m_DecryptionCtx); - - int TotalSize = 0; - int DecryptedSize = 0; + if (Key.IsValid() == false) + { + if (Reason) + { + Reason = "Invalid key"sv; + } - int ErrorCode = EVP_DecryptUpdate(m_DecryptionCtx, - reinterpret_cast<unsigned char*>(DecryptionBuffer.GetData()), - &DecryptedSize, - reinterpret_cast<const unsigned char*>(Data.GetData()), - static_cast<int>(Data.GetSize())); + return false; + } - if (ErrorCode != 1) + if (IV.IsValid() == false) { - return MemoryView(); + if (Reason) + { + Reason = "Invalid initialization vector"sv; + } + + return false; } - TotalSize = DecryptedSize; - MutableMemoryView Remaining = DecryptionBuffer.RightChop(uint64_t(DecryptedSize)); + return true; + } - ErrorCode = EVP_DecryptFinal_ex(m_DecryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &DecryptedSize); +} // namespace crypto - if (ErrorCode != 1) - { - return MemoryView(); - } +MemoryView +Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason) +{ + if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) + { + return MemoryView(); + } - TotalSize += DecryptedSize; + return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason); +} - return DecryptionBuffer.Left(uint64_t(TotalSize)); +MemoryView +Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason) +{ + if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) + { + return MemoryView(); } -private: - const EVP_CIPHER* m_Cipher = nullptr; - EVP_CIPHER_CTX* m_EncryptionCtx = nullptr; - EVP_CIPHER_CTX* m_DecryptionCtx = nullptr; - size_t m_BlockSize = 0; - size_t m_KeySize = 0; - size_t m_InitVectorSize = 0; -}; - -std::unique_ptr<SymmetricCipher> -SymmetricCipher::CreateAes() -{ - return std::make_unique<Aes>(); + return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } #if ZEN_WITH_TESTS -using namespace std::literals; - void crypto_forcelink() { } +TEST_CASE("crypto.bits") +{ + using CryptoBits256Bit = CryptoBits<256>; + + CryptoBits256Bit Bits; + + CHECK(Bits.IsNull()); + CHECK(Bits.IsValid() == false); + + CHECK(Bits.GetBitCount() == 256); + CHECK(Bits.GetSize() == 32); + + Bits = CryptoBits256Bit::FromString("Addff"sv); + CHECK(Bits.IsValid() == false); + + Bits = CryptoBits256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + CHECK(Bits.IsValid()); + + auto SmallerBits = CryptoBits<128>::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + CHECK(SmallerBits.IsValid() == false); +} + TEST_CASE("crypto.aes") { SUBCASE("basic") { - auto Cipher = std::make_unique<Aes>(); + const uint8_t InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const AesIV128Bit IV = AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector)); std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; - std::vector<uint8_t> Key = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; - std::vector<uint8_t> Seed = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - - std::vector<uint8_t> EncryptionBuffer; - std::vector<uint8_t> DecryptionBuffer; - - bool Ok = Cipher->Initialize(MakeMemoryView(Key), MakeMemoryView(Seed)); - CHECK(Ok); - - EncryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); - DecryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); + std::vector<uint8_t> EncryptionBuffer; + std::vector<uint8_t> DecryptionBuffer; + std::optional<std::string> Reason; - MemoryView EncryptedView = Cipher->Encrypt(MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer)); - CHECK(EncryptedView.IsEmpty() == false); + EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize); + DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize); - MemoryView DecryptedView = Cipher->Decrypt(EncryptedView, MakeMutableMemoryView(DecryptionBuffer)); - CHECK(DecryptedView.IsEmpty() == false); + MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason); + MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason); std::string_view EncryptedDecryptedText = std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize()); |