// Copyright Epic Games, Inc. All Rights Reserved. #include "zenhttp/auth/authmgr.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace zen { using namespace std::literals; namespace details { IoBuffer ReadEncryptedFile(std::filesystem::path Path, const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional& Reason) { ZEN_TRACE_CPU("AuthMgr::ReadEncryptedFile"); FileContents Result = ReadFile(Path); if (Result.ErrorCode) { return IoBuffer(); } IoBuffer EncryptedBuffer = Result.Flatten(); if (EncryptedBuffer.GetSize() == 0) { return IoBuffer(); } std::vector DecryptionBuffer; DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize); MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); if (DecryptedView.IsEmpty()) { return IoBuffer(); } return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); } void WriteEncryptedFile(std::filesystem::path Path, IoBuffer FileData, const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional& Reason) { ZEN_TRACE_CPU("AuthMgr::WriteEncryptedFile"); if (FileData.GetSize() == 0) { return; } std::vector EncryptionBuffer; EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize); MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason); if (EncryptedView.IsEmpty()) { return; } TemporaryFile::SafeWriteFile(Path, EncryptedView); } } // namespace details class AuthMgrImpl final : public AuthMgr { using Clock = std::chrono::system_clock; using TimePoint = Clock::time_point; using Seconds = std::chrono::seconds; public: AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); m_BackgroundThread.Interval = Config.UpdateInterval; m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this); } virtual ~AuthMgrImpl() { Shutdown(); } virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { ZEN_TRACE_CPU("AuthMgr::AddOpenIdProvider"); if (Params.Name.empty()) { ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); return; } { std::unique_lock _(m_ProviderMutex); if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end()) { OpenIdProvider& ExistingProvider = *It->second; if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url) { ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name); return; } else { m_OpenIdProviders.erase(It); m_OpenIdTokens.erase(std::string(Params.Name)); ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name); } } } RefPtr Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); return; } std::string NewProviderName = std::string(Params.Name); OpenIdProvider* NewProvider = nullptr; { std::unique_lock _(m_ProviderMutex); if (m_OpenIdProviders.contains(NewProviderName)) { return; } auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique()); NewProvider = InsertResult.first->second.get(); } NewProvider->Name = std::string(Params.Name); NewProvider->Url = std::string(Params.Url); NewProvider->ClientId = std::string(Params.ClientId); NewProvider->HttpClient = std::move(Client); ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); } virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final { ZEN_TRACE_CPU("AuthMgr::AddOpenIdToken"); if (Params.ProviderName.empty()) { ZEN_WARN("trying add OpenID token with invalid provider name"); return false; } if (Params.RefreshToken.empty()) { ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'"); return false; } auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); if (RefreshResult.Ok == false) { ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); return false; } bool IsNew = false; { auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, .RefreshToken = RefreshResult.RefreshToken, .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; std::unique_lock _(m_TokenMutex); const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token)); IsNew = InsertResult.second; } if (IsNew) { ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); } else { ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); } return true; } virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final { std::unique_lock _(m_TokenMutex); if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) { const OpenIdToken& Token = It->second; return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime}; } return {}; } private: struct OpenIdProvider { std::string Name; std::string Url; std::string ClientId; RefPtr HttpClient; }; struct OpenIdToken { std::string IdentityToken; std::string RefreshToken; std::string AccessToken; TimePoint ExpireTime{}; }; bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); return m_OpenIdProviders.contains(std::string(ProviderName)); } OpenIdProvider GetOpenIdProvider(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); return *m_OpenIdProviders[std::string(ProviderName)]; } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken"); RefPtr Client = GetOpenIdProvider(ProviderName).HttpClient; if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } return Client->RefreshToken(RefreshToken); } void Shutdown() { BackgroundThread::Stop(m_BackgroundThread); SaveState(); } void LoadState() { ZEN_TRACE_CPU("AuthMgrImpl::LoadState"); try { std::optional Reason; IoBuffer Buffer = details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason); if (!Buffer) { if (Reason) { ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value()); } return; } const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); if (ValidationError != CbValidateError::None) { ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); return; } if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) { for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) { CbObjectView ProviderObj = ProviderView.AsObjectView(); std::string_view ProviderName = ProviderObj["Name"].AsString(); std::string_view Url = ProviderObj["Url"].AsString(); std::string_view ClientId = ProviderObj["ClientId"].AsString(); AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); } for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) { CbObjectView TokenObj = TokenView.AsObjectView(); std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken}); if (!Ok) { ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); } } } } catch (const std::exception& Err) { ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); { std::unique_lock _(m_ProviderMutex); m_OpenIdProviders.clear(); } { std::unique_lock _(m_TokenMutex); m_OpenIdTokens.clear(); } } } void SaveState() { ZEN_TRACE_CPU("AuthMgr::SaveState"); try { CbObjectWriter AuthState; { std::unique_lock _(m_ProviderMutex); if (m_OpenIdProviders.size() > 0) { AuthState.BeginArray("OpenIdProviders"); for (const auto& Kv : m_OpenIdProviders) { AuthState.BeginObject(); AuthState << "Name"sv << Kv.second->Name; AuthState << "Url"sv << Kv.second->Url; AuthState << "ClientId"sv << Kv.second->ClientId; AuthState.EndObject(); } AuthState.EndArray(); } } { std::unique_lock _(m_TokenMutex); AuthState.BeginArray("OpenIdTokens"); if (m_OpenIdTokens.size() > 0) { for (const auto& Kv : m_OpenIdTokens) { AuthState.BeginObject(); AuthState << "ProviderName"sv << Kv.first; AuthState << "RefreshToken"sv << Kv.second.RefreshToken; AuthState.EndObject(); } } AuthState.EndArray(); } CreateDirectories(m_Config.RootDirectory); std::optional Reason; details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer(), m_Config.EncryptionKey, m_Config.EncryptionIV, Reason); if (Reason) { ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value()); } } catch (const std::exception& Err) { ZEN_WARN("serialize state FAILED, reason '{}'", Err.what()); } } void BackgroundThreadEntry() { SetCurrentThreadName("auth"); for (;;) { std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread); if (m_BackgroundThread.Running.load() == false) { break; } if (SignalStatus != std::cv_status::timeout) { continue; } { // Refresh Open ID token(s) std::vector ExpiredTokens; { std::unique_lock _(m_TokenMutex); for (const auto& Kv : m_OpenIdTokens) { const Seconds ExpiresIn = std::chrono::duration_cast(Kv.second.ExpireTime - Clock::now()); const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2); if (Expired) { ExpiredTokens.push_back(Kv); } } } if (ExpiredTokens.empty()) { continue; } ZEN_DEBUG("refreshing {} OpenID token(s)", ExpiredTokens.size()); for (const auto& Kv : ExpiredTokens) { OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken); if (RefreshResult.Ok) { ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first); auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, .RefreshToken = RefreshResult.RefreshToken, .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; { std::unique_lock _(m_TokenMutex); m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token)); } } else { ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason); } } } } } struct BackgroundThread { std::chrono::seconds Interval{10}; std::mutex Mutex; std::condition_variable Signal; std::atomic_bool Running{true}; std::thread Thread; static void Stop(BackgroundThread& State) { if (State.Running.load()) { State.Running.store(false); State.Signal.notify_one(); } if (State.Thread.joinable()) { State.Thread.join(); } } static std::cv_status WaitForSignal(BackgroundThread& State) { std::unique_lock Lock(State.Mutex); return State.Signal.wait_for(Lock, State.Interval); } }; using OpenIdProviderMap = std::unordered_map>; using OpenIdTokenMap = std::unordered_map; LoggerRef Log() { return m_Log; } AuthConfig m_Config; LoggerRef m_Log; BackgroundThread m_BackgroundThread; OpenIdProviderMap m_OpenIdProviders; OpenIdTokenMap m_OpenIdTokens; std::mutex m_ProviderMutex; std::shared_mutex m_TokenMutex; }; std::unique_ptr AuthMgr::Create(const AuthConfig& Config) { return std::make_unique(Config); } } // namespace zen