diff options
| author | Per Larsson <[email protected]> | 2022-02-03 10:10:32 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-02-03 10:10:32 +0100 |
| commit | 02046f9bb44b16b7ce457f5b564cbf9d9cee4a91 (patch) | |
| tree | 3b3a443d4a04308944f648406a87511fab8b5f8a /zenserver/auth/authmgr.cpp | |
| parent | Overwrite existing refresh token. (diff) | |
| download | zen-02046f9bb44b16b7ce457f5b564cbf9d9cee4a91.tar.xz zen-02046f9bb44b16b7ce457f5b564cbf9d9cee4a91.zip | |
Encrypt serialized auth state.
Diffstat (limited to 'zenserver/auth/authmgr.cpp')
| -rw-r--r-- | zenserver/auth/authmgr.cpp | 84 |
1 files changed, 77 insertions, 7 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp index 08cab975d..f61e4acd7 100644 --- a/zenserver/auth/authmgr.cpp +++ b/zenserver/auth/authmgr.cpp @@ -6,6 +6,7 @@ #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> #include <zencore/compactbinaryvalidation.h> +#include <zencore/crypto.h> #include <zencore/filesystem.h> #include <zencore/logging.h> @@ -21,6 +22,63 @@ namespace zen { using namespace std::literals; +namespace details { + const std::string_view DefaultPrivateKey = "HeyThisIsNotAGoodPrivateKeyToUse"sv; + const std::string_view DefaultIV = "DefaultInitVecto"sv; + + IoBuffer ReadEncryptedFile(std::filesystem::path Path, MemoryView EncryptionKey, MemoryView IV) + { + FileContents Result = ReadFile(Path); + + if (Result.ErrorCode) + { + return IoBuffer(); + } + + IoBuffer EncryptedBuffer = Result.Flatten(); + + if (EncryptedBuffer.GetSize() == 0) + { + return IoBuffer(); + } + + std::unique_ptr<SymmetricCipher> Cipher = MakeAesCipher(); + + if (Cipher->Initialize(EncryptionKey, IV) == false) + { + return IoBuffer(); + } + + IoBuffer DecryptionBuffer(EncryptedBuffer.GetSize() + Cipher->Settings().BlockSize); + MemoryView DecryptedView = Cipher->Decrypt(EncryptedBuffer, DecryptionBuffer.GetMutableView()); + + return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); + } + + uint64_t WriteEncryptedFile(std::filesystem::path Path, IoBuffer FileData, MemoryView EncryptionKey, MemoryView IV) + { + if (FileData.GetSize() == 0) + { + return 0; + } + + std::unique_ptr<SymmetricCipher> Cipher = MakeAesCipher(); + + if (Cipher->Initialize(EncryptionKey, IV) == false) + { + return 0; + } + + IoBuffer EncryptionBuffer(FileData.GetSize() + Cipher->Settings().BlockSize); + + MemoryView EncryptedView = Cipher->Encrypt(FileData, EncryptionBuffer.GetMutableView()); + + WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize())); + + return EncryptedView.GetSize(); + } +} // namespace details + class AuthMgrImpl final : public AuthMgr { using Clock = std::chrono::system_clock; @@ -36,7 +94,7 @@ public: m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this); } - virtual ~AuthMgrImpl() { SaveState(); } + virtual ~AuthMgrImpl() { Shutdown(); } virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { @@ -177,15 +235,15 @@ private: { try { - FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv); + IoBuffer Buffer = details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, + MakeMemoryView(details::DefaultPrivateKey), + MakeMemoryView(details::DefaultIV)); - if (Result.ErrorCode) + if (Buffer.GetSize() == 0) { return; } - IoBuffer Buffer = Result.Flatten(); - const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); if (ValidationError != CbValidateError::None) @@ -281,7 +339,19 @@ private: } std::filesystem::create_directories(m_Config.RootDirectory); - WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer()); + + MemoryView EncryptionKey = MakeMemoryView(details::DefaultPrivateKey); + MemoryView IV = MakeMemoryView(details::DefaultIV); + + const uint64_t ByteCount = details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, + AuthState.Save().GetBuffer().AsIoBuffer(), + EncryptionKey, + IV); + + if (ByteCount == 0) + { + ZEN_WARN("save auth state FAILED"); + } } catch (std::exception& Err) { @@ -366,7 +436,7 @@ private: { if (State.Running.load()) { - State.Running.store(true); + State.Running.store(false); State.Signal.notify_one(); } |