diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 12:31:53 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 12:31:53 +0200 |
| commit | e3086573d2244def22ecbe1e6b4b3da8b47e0f14 (patch) | |
| tree | 627066debdddf7474783893f6b9b6631bb9a4833 /src/zenhttp/auth/authmgr.cpp | |
| parent | moved source directories into `/src` (#264) (diff) | |
| download | zen-e3086573d2244def22ecbe1e6b4b3da8b47e0f14.tar.xz zen-e3086573d2244def22ecbe1e6b4b3da8b47e0f14.zip | |
move auth code from zenserver into zenhttp (#265)
this code should be usable outside of zenserver, so this moves it out into zenhttp where it can be used from lower level components
Diffstat (limited to 'src/zenhttp/auth/authmgr.cpp')
| -rw-r--r-- | src/zenhttp/auth/authmgr.cpp | 506 |
1 files changed, 506 insertions, 0 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp new file mode 100644 index 000000000..d535d07a4 --- /dev/null +++ b/src/zenhttp/auth/authmgr.cpp @@ -0,0 +1,506 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/auth/authmgr.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/crypto.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zenhttp/auth/oidc.h> + +#include <condition_variable> +#include <memory> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +#include <fmt/format.h> + +namespace zen { + +using namespace std::literals; + +namespace details { + IoBuffer ReadEncryptedFile(std::filesystem::path Path, + const AesKey256Bit& Key, + const AesIV128Bit& IV, + std::optional<std::string>& Reason) + { + FileContents Result = ReadFile(Path); + + if (Result.ErrorCode) + { + return IoBuffer(); + } + + IoBuffer EncryptedBuffer = Result.Flatten(); + + if (EncryptedBuffer.GetSize() == 0) + { + return IoBuffer(); + } + + std::vector<uint8_t> DecryptionBuffer; + DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize); + + MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); + + if (DecryptedView.IsEmpty()) + { + return IoBuffer(); + } + + return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); + } + + void WriteEncryptedFile(std::filesystem::path Path, + IoBuffer FileData, + const AesKey256Bit& Key, + const AesIV128Bit& IV, + std::optional<std::string>& Reason) + { + if (FileData.GetSize() == 0) + { + return; + } + + std::vector<uint8_t> EncryptionBuffer; + EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize); + + MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason); + + if (EncryptedView.IsEmpty()) + { + return; + } + + WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize())); + } +} // namespace details + +class AuthMgrImpl final : public AuthMgr +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + using Seconds = std::chrono::seconds; + +public: + AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) + { + LoadState(); + + m_BackgroundThread.Interval = Config.UpdateInterval; + m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this); + } + + virtual ~AuthMgrImpl() { Shutdown(); } + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final + { + if (OpenIdProviderExist(Params.Name)) + { + ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); + return; + } + + if (Params.Name.empty()) + { + ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); + return; + } + + std::unique_ptr<OidcClient> Client = + std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + + if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) + { + ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); + return; + } + + std::string NewProviderName = std::string(Params.Name); + + OpenIdProvider* NewProvider = nullptr; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.contains(NewProviderName)) + { + return; + } + + auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>()); + NewProvider = InsertResult.first->second.get(); + } + + NewProvider->Name = std::string(Params.Name); + NewProvider->Url = std::string(Params.Url); + NewProvider->ClientId = std::string(Params.ClientId); + NewProvider->HttpClient = std::move(Client); + + ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); + } + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final + { + if (Params.ProviderName.empty()) + { + ZEN_WARN("trying add OpenID token with invalid provider name"); + return false; + } + + if (Params.RefreshToken.empty()) + { + ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'"); + return false; + } + + auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); + + if (RefreshResult.Ok == false) + { + ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); + return false; + } + + bool IsNew = false; + + { + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + + std::unique_lock _(m_TokenMutex); + + const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token)); + + IsNew = InsertResult.second; + } + + if (IsNew) + { + ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); + } + else + { + ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); + } + + return true; + } + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final + { + std::unique_lock _(m_TokenMutex); + + if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) + { + const OpenIdToken& Token = It->second; + + return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + } + + return {}; + } + +private: + bool OpenIdProviderExist(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + + return m_OpenIdProviders.contains(std::string(ProviderName)); + } + + OidcClient& GetOpenIdClient(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + } + + OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) + { + if (OpenIdProviderExist(ProviderName) == false) + { + return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; + } + + OidcClient& Client = GetOpenIdClient(ProviderName); + + return Client.RefreshToken(RefreshToken); + } + + void Shutdown() + { + BackgroundThread::Stop(m_BackgroundThread); + SaveState(); + } + + void LoadState() + { + try + { + std::optional<std::string> Reason; + + IoBuffer Buffer = + details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason); + + if (!Buffer) + { + if (Reason) + { + ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value()); + } + + return; + } + + const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + return; + } + + if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) + { + for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + { + CbObjectView ProviderObj = ProviderView.AsObjectView(); + + std::string_view ProviderName = ProviderObj["Name"].AsString(); + std::string_view Url = ProviderObj["Url"].AsString(); + std::string_view ClientId = ProviderObj["ClientId"].AsString(); + + AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + } + + for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) + { + CbObjectView TokenObj = TokenView.AsObjectView(); + + std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); + std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); + + const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken}); + + if (!Ok) + { + ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + } + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); + + { + std::unique_lock _(m_ProviderMutex); + m_OpenIdProviders.clear(); + } + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.clear(); + } + } + } + + void SaveState() + { + try + { + CbObjectWriter AuthState; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.size() > 0) + { + AuthState.BeginArray("OpenIdProviders"); + for (const auto& Kv : m_OpenIdProviders) + { + AuthState.BeginObject(); + AuthState << "Name"sv << Kv.second->Name; + AuthState << "Url"sv << Kv.second->Url; + AuthState << "ClientId"sv << Kv.second->ClientId; + AuthState.EndObject(); + } + AuthState.EndArray(); + } + } + + { + std::unique_lock _(m_TokenMutex); + + AuthState.BeginArray("OpenIdTokens"); + if (m_OpenIdTokens.size() > 0) + { + for (const auto& Kv : m_OpenIdTokens) + { + AuthState.BeginObject(); + AuthState << "ProviderName"sv << Kv.first; + AuthState << "RefreshToken"sv << Kv.second.RefreshToken; + AuthState.EndObject(); + } + } + AuthState.EndArray(); + } + + std::filesystem::create_directories(m_Config.RootDirectory); + + std::optional<std::string> Reason; + + details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, + AuthState.Save().GetBuffer().AsIoBuffer(), + m_Config.EncryptionKey, + m_Config.EncryptionIV, + Reason); + + if (Reason) + { + ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value()); + } + } + catch (std::exception& Err) + { + ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what()); + } + } + + void BackgroundThreadEntry() + { + for (;;) + { + std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread); + + if (m_BackgroundThread.Running.load() == false) + { + break; + } + + if (SignalStatus != std::cv_status::timeout) + { + continue; + } + + { + // Refresh Open ID token(s) + + std::vector<OpenIdTokenMap::value_type> ExpiredTokens; + + { + std::unique_lock _(m_TokenMutex); + + for (const auto& Kv : m_OpenIdTokens) + { + const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now()); + const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2); + + if (Expired) + { + ExpiredTokens.push_back(Kv); + } + } + } + + ZEN_DEBUG("refreshing '{}' OpenID token(s)", ExpiredTokens.size()); + + for (const auto& Kv : ExpiredTokens) + { + OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken); + + if (RefreshResult.Ok) + { + ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first); + + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token)); + } + } + else + { + ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason); + } + } + } + } + } + + struct BackgroundThread + { + std::chrono::seconds Interval{10}; + std::mutex Mutex; + std::condition_variable Signal; + std::atomic_bool Running{true}; + std::thread Thread; + + static void Stop(BackgroundThread& State) + { + if (State.Running.load()) + { + State.Running.store(false); + State.Signal.notify_one(); + } + + if (State.Thread.joinable()) + { + State.Thread.join(); + } + } + + static std::cv_status WaitForSignal(BackgroundThread& State) + { + std::unique_lock Lock(State.Mutex); + return State.Signal.wait_for(Lock, State.Interval); + } + }; + + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + std::unique_ptr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + + using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; + using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; + + spdlog::logger& Log() { return m_Log; } + + AuthConfig m_Config; + spdlog::logger& m_Log; + BackgroundThread m_BackgroundThread; + OpenIdProviderMap m_OpenIdProviders; + OpenIdTokenMap m_OpenIdTokens; + std::mutex m_ProviderMutex; + std::shared_mutex m_TokenMutex; +}; + +std::unique_ptr<AuthMgr> +AuthMgr::Create(const AuthConfig& Config) +{ + return std::make_unique<AuthMgrImpl>(Config); +} + +} // namespace zen |