diff options
| author | Per Larsson <[email protected]> | 2022-02-07 13:00:29 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-07 13:00:29 +0100 |
| commit | aba862f45cbde3debe99e31bc8ec3e338c5bbf4a (patch) | |
| tree | d47580f8027d90e53a084bbb19192e141f497637 | |
| parent | Missing override suffix compile fix (diff) | |
| download | zen-aba862f45cbde3debe99e31bc8ec3e338c5bbf4a.tar.xz zen-aba862f45cbde3debe99e31bc8ec3e338c5bbf4a.zip | |
Replaced crypto transform abstraction with a concrete API.
| -rw-r--r-- | zencore/crypto.cpp | 286 | ||||
| -rw-r--r-- | zencore/include/zencore/crypto.h | 67 |
2 files changed, 186 insertions, 167 deletions
diff --git a/zencore/crypto.cpp b/zencore/crypto.cpp index 57820c278..0ad368f3f 100644 --- a/zencore/crypto.cpp +++ b/zencore/crypto.cpp @@ -4,12 +4,15 @@ #include <zencore/intmath.h> #include <zencore/testing.h> +#include <string> +#include <string_view> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> #include <openssl/conf.h> #include <openssl/err.h> #include <openssl/evp.h> - -#include <string> -#include <string_view> +ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_WINDOWS # pragma comment(lib, "crypt32.lib") @@ -18,212 +21,195 @@ namespace zen { -class NoOpSymmetricCipher final : public SymmetricCipher -{ -public: - NoOpSymmetricCipher() = default; - - virtual ~NoOpSymmetricCipher() = default; - - virtual bool Initialize(MemoryView, MemoryView) override final { return true; } - - virtual CipherSettings Settings() override final { return {}; } - - virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView) override final { return Data; } - - virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView) override final { return Data; } -}; +using namespace std::literals; -std::unique_ptr<SymmetricCipher> -SymmetricCipher::CreateNoOp() -{ - return std::make_unique<NoOpSymmetricCipher>(); -} +namespace crypto { -class Aes final : public SymmetricCipher -{ -public: - Aes(const EVP_CIPHER* Cipher = EVP_aes_256_cbc()) : m_Cipher(Cipher) + class EvpContext { - ZEN_ASSERT(Cipher); - m_KeySize = static_cast<size_t>(EVP_CIPHER_key_length(m_Cipher)); - m_InitVectorSize = static_cast<size_t>(EVP_CIPHER_iv_length(m_Cipher)); - m_BlockSize = static_cast<size_t>(EVP_CIPHER_block_size(m_Cipher)); - } + public: + EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {} + ~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); } - virtual ~Aes() - { - if (m_EncryptionCtx) - { - EVP_CIPHER_CTX_free(m_EncryptionCtx); - } + operator EVP_CIPHER_CTX*() { return m_Ctx; } - if (m_DecryptionCtx) - { - EVP_CIPHER_CTX_free(m_DecryptionCtx); - } - } + private: + EVP_CIPHER_CTX* m_Ctx; + }; - virtual bool Initialize(MemoryView Key, MemoryView InitVector) override final + enum class TransformMode : uint32_t { - ZEN_ASSERT(m_EncryptionCtx == nullptr && m_DecryptionCtx == nullptr); - ZEN_ASSERT(Key.GetSize() == m_KeySize); - ZEN_ASSERT(InitVector.GetSize() == m_InitVectorSize); - - m_EncryptionCtx = EVP_CIPHER_CTX_new(); - m_DecryptionCtx = EVP_CIPHER_CTX_new(); - - if (int ErrorCode = EVP_EncryptInit_ex(m_EncryptionCtx, - m_Cipher, - nullptr, - reinterpret_cast<const unsigned char*>(Key.GetData()), - reinterpret_cast<const unsigned char*>(InitVector.GetData())); - ErrorCode != 1) - { - return false; - } - - if (int ErrorCode = EVP_DecryptInit_ex(m_DecryptionCtx, - m_Cipher, - nullptr, - reinterpret_cast<const unsigned char*>(Key.GetData()), - reinterpret_cast<const unsigned char*>(InitVector.GetData())); - ErrorCode != 1) - { - return false; - } - - return true; - } - - virtual CipherSettings Settings() override final + Decrypt, + Encrypt + }; + + MemoryView Transform(const EVP_CIPHER* Cipher, + TransformMode Mode, + MemoryView Key, + MemoryView IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason) { - return {.KeySize = m_KeySize, .InitVectorSize = m_InitVectorSize, .BlockSize = m_BlockSize}; - } + ZEN_ASSERT(Cipher != nullptr); - virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) override - { - ZEN_ASSERT(m_EncryptionCtx); + EvpContext Ctx; - const uint64_t InputSize = Data.GetSize(); - const uint64_t NeededSize = RoundUp(InputSize, m_BlockSize); + int Err = EVP_CipherInit_ex(Ctx, + Cipher, + nullptr, + reinterpret_cast<const unsigned char*>(Key.GetData()), + reinterpret_cast<const unsigned char*>(IV.GetData()), + static_cast<int>(Mode)); - if (NeededSize > EncryptionBuffer.GetSize()) + if (Err != 1) { + if (Reason) + { + Reason = fmt::format("failed to initialize cipher, error code '{}'", Err); + } + return MemoryView(); } - int TotalSize = 0; - int EncryptedSize = 0; - int ErrorCode = EVP_EncryptUpdate(m_EncryptionCtx, - reinterpret_cast<unsigned char*>(EncryptionBuffer.GetData()), - &EncryptedSize, - reinterpret_cast<const unsigned char*>(Data.GetData()), - static_cast<int>(Data.GetSize())); + int EncryptedBytes = 0; + int TotalEncryptedBytes = 0; - if (ErrorCode != 1) + Err = EVP_CipherUpdate(Ctx, + reinterpret_cast<unsigned char*>(Out.GetData()), + &EncryptedBytes, + reinterpret_cast<const unsigned char*>(In.GetData()), + static_cast<int>(In.GetSize())); + + if (Err != 1) { + if (Reason) + { + Reason = fmt::format("update crypto transform failed, error code '{}'", Err); + } + return MemoryView(); } - TotalSize = EncryptedSize; - MutableMemoryView Remaining = EncryptionBuffer.RightChop(uint64_t(EncryptedSize)); + TotalEncryptedBytes = EncryptedBytes; + MutableMemoryView Remaining = Out.RightChop(EncryptedBytes); + + EncryptedBytes = static_cast<int>(Remaining.GetSize()); - ErrorCode = EVP_EncryptFinal_ex(m_EncryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedSize); + Err = EVP_CipherFinal(Ctx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedBytes); - if (ErrorCode != 1) + if (Err != 1) { + if (Reason) + { + Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err); + } + return MemoryView(); } - TotalSize += EncryptedSize; + TotalEncryptedBytes += EncryptedBytes; - return EncryptionBuffer.Left(uint64_t(TotalSize)); + return Out.Left(TotalEncryptedBytes); } - virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) override final + bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional<std::string>& Reason) { - ZEN_ASSERT(m_DecryptionCtx); - - int TotalSize = 0; - int DecryptedSize = 0; + if (Key.IsValid() == false) + { + if (Reason) + { + Reason = "Invalid key"sv; + } - int ErrorCode = EVP_DecryptUpdate(m_DecryptionCtx, - reinterpret_cast<unsigned char*>(DecryptionBuffer.GetData()), - &DecryptedSize, - reinterpret_cast<const unsigned char*>(Data.GetData()), - static_cast<int>(Data.GetSize())); + return false; + } - if (ErrorCode != 1) + if (IV.IsValid() == false) { - return MemoryView(); + if (Reason) + { + Reason = "Invalid initialization vector"sv; + } + + return false; } - TotalSize = DecryptedSize; - MutableMemoryView Remaining = DecryptionBuffer.RightChop(uint64_t(DecryptedSize)); + return true; + } - ErrorCode = EVP_DecryptFinal_ex(m_DecryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &DecryptedSize); +} // namespace crypto - if (ErrorCode != 1) - { - return MemoryView(); - } +MemoryView +Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason) +{ + if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) + { + return MemoryView(); + } - TotalSize += DecryptedSize; + return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason); +} - return DecryptionBuffer.Left(uint64_t(TotalSize)); +MemoryView +Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason) +{ + if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) + { + return MemoryView(); } -private: - const EVP_CIPHER* m_Cipher = nullptr; - EVP_CIPHER_CTX* m_EncryptionCtx = nullptr; - EVP_CIPHER_CTX* m_DecryptionCtx = nullptr; - size_t m_BlockSize = 0; - size_t m_KeySize = 0; - size_t m_InitVectorSize = 0; -}; - -std::unique_ptr<SymmetricCipher> -SymmetricCipher::CreateAes() -{ - return std::make_unique<Aes>(); + return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } #if ZEN_WITH_TESTS -using namespace std::literals; - void crypto_forcelink() { } +TEST_CASE("crypto.bits") +{ + using CryptoBits256Bit = CryptoBits<256>; + + CryptoBits256Bit Bits; + + CHECK(Bits.IsNull()); + CHECK(Bits.IsValid() == false); + + CHECK(Bits.GetBitCount() == 256); + CHECK(Bits.GetSize() == 32); + + Bits = CryptoBits256Bit::FromString("Addff"sv); + CHECK(Bits.IsValid() == false); + + Bits = CryptoBits256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + CHECK(Bits.IsValid()); + + auto SmallerBits = CryptoBits<128>::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + CHECK(SmallerBits.IsValid() == false); +} + TEST_CASE("crypto.aes") { SUBCASE("basic") { - auto Cipher = std::make_unique<Aes>(); + const uint8_t InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const AesIV128Bit IV = AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector)); std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; - std::vector<uint8_t> Key = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; - std::vector<uint8_t> Seed = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - - std::vector<uint8_t> EncryptionBuffer; - std::vector<uint8_t> DecryptionBuffer; - - bool Ok = Cipher->Initialize(MakeMemoryView(Key), MakeMemoryView(Seed)); - CHECK(Ok); - - EncryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); - DecryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); + std::vector<uint8_t> EncryptionBuffer; + std::vector<uint8_t> DecryptionBuffer; + std::optional<std::string> Reason; - MemoryView EncryptedView = Cipher->Encrypt(MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer)); - CHECK(EncryptedView.IsEmpty() == false); + EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize); + DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize); - MemoryView DecryptedView = Cipher->Decrypt(EncryptedView, MakeMutableMemoryView(DecryptionBuffer)); - CHECK(DecryptedView.IsEmpty() == false); + MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason); + MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason); std::string_view EncryptedDecryptedText = std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize()); diff --git a/zencore/include/zencore/crypto.h b/zencore/include/zencore/crypto.h index 44783cdeb..83d416b0f 100644 --- a/zencore/include/zencore/crypto.h +++ b/zencore/include/zencore/crypto.h @@ -7,36 +7,69 @@ #include <zencore/zencore.h> #include <memory> +#include <optional> namespace zen { -/** - * Experimental interface for a symmetric encryption/decryption algorithm. - * Currenlty only AES 256 bit CBC is supported using OpenSSL. - */ -class SymmetricCipher +template<size_t BitCount> +struct CryptoBits { public: - virtual ~SymmetricCipher() = default; + static constexpr size_t ByteCount = BitCount / 8; - virtual bool Initialize(MemoryView Key, MemoryView InitVector) = 0; + CryptoBits() = default; - struct CipherSettings + bool IsNull() const { return memcmp(&m_Bits, &Zero, ByteCount) == 0; } + bool IsValid() const { return IsNull() == false; } + + size_t GetSize() const { return ByteCount; } + size_t GetBitCount() const { return BitCount; } + + MemoryView GetView() const { return MemoryView(m_Bits, ByteCount); } + + static CryptoBits FromMemoryView(MemoryView Bits) { - size_t KeySize = 0; - size_t InitVectorSize = 0; - size_t BlockSize = 0; - }; + if (Bits.GetSize() != ByteCount) + { + return CryptoBits(); + } - virtual CipherSettings Settings() = 0; + return CryptoBits(Bits); + } + + static CryptoBits FromString(std::string_view Str) { return FromMemoryView(MakeMemoryView(Str)); } + +private: + CryptoBits(MemoryView Bits) + { + ZEN_ASSERT(Bits.GetSize() == GetSize()); + memcpy(&m_Bits, Bits.GetData(), GetSize()); + } - virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) = 0; + static constexpr uint8_t Zero[ByteCount] = {0}; - virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) = 0; + uint8_t m_Bits[ByteCount] = {0}; +}; + +using AesKey256Bit = CryptoBits<256>; +using AesIV128Bit = CryptoBits<128>; + +class Aes +{ +public: + static constexpr size_t BlockSize = 16; - static std::unique_ptr<SymmetricCipher> CreateNoOp(); + static MemoryView Encrypt(const AesKey256Bit& Key, + const AesIV128Bit& IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason); - static std::unique_ptr<SymmetricCipher> CreateAes(); + static MemoryView Decrypt(const AesKey256Bit& Key, + const AesIV128Bit& IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason); }; void crypto_forcelink(); |