// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace zen { using namespace std::literals; class AuthMgrImpl final : public AuthMgr { public: AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); } virtual ~AuthMgrImpl() { SaveState(); } virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { if (OpenIdProviderExist(Params.Name)) { return; } std::unique_ptr Client = std::make_unique(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); return; } std::string NewProviderName = std::string(Params.Name); OpenIdProvider* NewProvider = nullptr; { std::unique_lock _(m_ProviderMutex); if (m_OpenIdProviders.contains(NewProviderName)) { return; } auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique()); NewProvider = InsertResult.first->second.get(); } NewProvider->Name = std::string(Params.Name); NewProvider->Url = std::string(Params.Url); NewProvider->ClientId = std::string(Params.ClientId); NewProvider->HttpClient = std::move(Client); ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); } virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final { if (Params.ProviderName.empty()) { ZEN_WARN("trying add OpenID token with invalid provider name"); return false; } if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty()) { ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'"); return false; } auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); if (RefreshResult.Ok == false) { ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); return false; } bool IsNew = false; { std::unique_lock _(m_TokenMutex); const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, .RefreshToken = RefreshResult.RefreshToken, .AccessToken = RefreshResult.AccessToken}); IsNew = InsertResult.second; } if (IsNew) { ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); } else { ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); } return true; } virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final { std::unique_lock _(m_TokenMutex); if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) { const OpenIdToken& Token = It->second; return {.AccessToken = Token.AccessToken}; } return {}; } private: bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); return m_OpenIdProviders.contains(std::string(ProviderName)); } OidcClient& GetOpenIdClient(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { if (OpenIdProviderExist(ProviderName) == false) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } OidcClient& Client = GetOpenIdClient(ProviderName); return Client.RefreshToken(RefreshToken); } void Shutdown() { SaveState(); } void LoadState() { FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv); if (Result.ErrorCode) { return; } IoBuffer Buffer = Result.Flatten(); const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); if (ValidationError != CbValidateError::None) { ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); return; } if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) { for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) { CbObjectView ProviderObj = ProviderView.AsObjectView(); std::string_view ProviderName = ProviderObj["Name"].AsString(); std::string_view Url = ProviderObj["Url"].AsString(); std::string_view ClientId = ProviderObj["ClientId"].AsString(); AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); } for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) { CbObjectView TokenObj = TokenView.AsObjectView(); std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString(); std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); std::string_view AccessToken = TokenObj["AccessToken"sv].AsString(); const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .IdentityToken = IdentityToken, .RefreshToken = RefreshToken, .AccessToken = AccessToken}); if (!Ok) { ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); } } } } void SaveState() { CbObjectWriter AuthState; { std::unique_lock _(m_ProviderMutex); if (m_OpenIdProviders.size() > 0) { AuthState.BeginArray("OpenIdProviders"); for (const auto& Kv : m_OpenIdProviders) { AuthState.BeginObject(); AuthState << "Name"sv << Kv.second->Name; AuthState << "Url"sv << Kv.second->Url; AuthState << "ClientId"sv << Kv.second->ClientId; AuthState.EndObject(); } AuthState.EndArray(); } } { std::unique_lock _(m_TokenMutex); AuthState.BeginArray("OpenIdTokens"); if (m_OpenIdTokens.size() > 0) { for (const auto& Kv : m_OpenIdTokens) { AuthState.BeginObject(); AuthState << "ProviderName"sv << Kv.first; AuthState << "IdentityToken"sv << Kv.second.IdentityToken; AuthState << "RefreshToken"sv << Kv.second.RefreshToken; AuthState << "AccessToken"sv << Kv.second.AccessToken; AuthState << "ExpireTime"sv << Kv.second.ExpireTime; AuthState.EndObject(); } } AuthState.EndArray(); } std::filesystem::create_directories(m_Config.RootDirectory); WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer()); } struct OpenIdProvider { std::string Name; std::string Url; std::string ClientId; std::unique_ptr HttpClient; }; struct OpenIdToken { std::string IdentityToken; std::string RefreshToken; std::string AccessToken; double ExpireTime{}; }; using OpenIdProviderMap = std::unordered_map>; using OpenIdTokenMap = std::unordered_map; spdlog::logger& Log() { return m_Log; } AuthConfig m_Config; spdlog::logger& m_Log; OpenIdProviderMap m_OpenIdProviders; OpenIdTokenMap m_OpenIdTokens; std::mutex m_ProviderMutex; std::shared_mutex m_TokenMutex; }; std::unique_ptr MakeAuthMgr(const AuthConfig& Config) { return std::make_unique(Config); } } // namespace zen