diff options
| author | Per Larsson <[email protected]> | 2022-01-28 10:16:34 +0100 |
|---|---|---|
| committer | Per Larsson <[email protected]> | 2022-01-28 10:16:34 +0100 |
| commit | b34fcd781777c522b155be69239967b2dcfd1c36 (patch) | |
| tree | 4de81ad72d94ce29857439171c76bd7bab551745 /zenserver/auth/authmgr.cpp | |
| parent | Add OpenID auth to auth mgr. (diff) | |
| download | zen-b34fcd781777c522b155be69239967b2dcfd1c36.tar.xz zen-b34fcd781777c522b155be69239967b2dcfd1c36.zip | |
Extended auth mgr to restore OpenID provider(s) and token(s).
Diffstat (limited to 'zenserver/auth/authmgr.cpp')
| -rw-r--r-- | zenserver/auth/authmgr.cpp | 195 |
1 files changed, 179 insertions, 16 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp index 4c97693f9..20bc7c988 100644 --- a/zenserver/auth/authmgr.cpp +++ b/zenserver/auth/authmgr.cpp @@ -1,25 +1,50 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include <auth/authmgr.h> +#include <auth/oidc.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/filesystem.h> #include <zencore/logging.h> #include <chrono> #include <condition_variable> +#include <memory> #include <shared_mutex> #include <thread> #include <unordered_map> +#include <fmt/format.h> + namespace zen { +using namespace std::literals; + class AuthMgrImpl final : public AuthMgr { public: - AuthMgrImpl(const AuthConfig& Config) : m_Log(logging::Get("auth")) { ZEN_UNUSED(Config); } + AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); } - virtual ~AuthMgrImpl() {} + virtual ~AuthMgrImpl() { SaveState(); } virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { + if (OpenIdProviderExist(Params.Name)) + { + return; + } + + std::unique_ptr<OidcClient> Client = + std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + + if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) + { + ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); + return; + } + std::string NewProviderName = std::string(Params.Name); OpenIdProvider* NewProvider = nullptr; @@ -36,9 +61,10 @@ public: NewProvider = InsertResult.first->second.get(); } - NewProvider->Name = std::string(Params.Name); - NewProvider->Url = std::string(Params.Url); - NewProvider->ClientId = std::string(Params.ClientId); + NewProvider->Name = std::string(Params.Name); + NewProvider->Url = std::string(Params.Url); + NewProvider->ClientId = std::string(Params.ClientId); + NewProvider->HttpClient = std::move(Client); ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); } @@ -53,7 +79,15 @@ public: if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty()) { - ZEN_WARN("trying add invalid OpenID token"); + ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'"); + return false; + } + + auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); + + if (RefreshResult.Ok == false) + { + ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); return false; } @@ -63,9 +97,9 @@ public: std::unique_lock _(m_TokenMutex); const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), - OpenIdToken{.IdentityToken = std::string(Params.IdentityToken), - .RefreshToken = Params.RefreshToken, - .AccessToken = Params.AccessToken}); + OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = RefreshResult.AccessToken}); IsNew = InsertResult.second; } @@ -83,19 +117,147 @@ public: } private: + bool OpenIdProviderExist(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + + return m_OpenIdProviders.contains(std::string(ProviderName)); + } + + OidcClient& GetOpenIdClient(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + } + + OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) + { + if (OpenIdProviderExist(ProviderName) == false) + { + return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; + } + + OidcClient& Client = GetOpenIdClient(ProviderName); + + return Client.RefreshToken(RefreshToken); + } + + void Shutdown() { SaveState(); } + + void LoadState() + { + FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv); + + if (Result.ErrorCode) + { + return; + } + + IoBuffer Buffer = Result.Flatten(); + + const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + return; + } + + if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) + { + for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + { + CbObjectView ProviderObj = ProviderView.AsObjectView(); + + std::string_view ProviderName = ProviderObj["Name"].AsString(); + std::string_view Url = ProviderObj["Url"].AsString(); + std::string_view ClientId = ProviderObj["ClientId"].AsString(); + + AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + } + + for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) + { + CbObjectView TokenObj = TokenView.AsObjectView(); + + std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); + std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString(); + std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); + std::string_view AccessToken = TokenObj["AccessToken"sv].AsString(); + + const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, + .IdentityToken = IdentityToken, + .RefreshToken = RefreshToken, + .AccessToken = AccessToken}); + + if (!Ok) + { + ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + } + } + } + } + + void SaveState() + { + CbObjectWriter AuthState; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.size() > 0) + { + AuthState.BeginArray("OpenIdProviders"); + for (const auto& Kv : m_OpenIdProviders) + { + AuthState.BeginObject(); + AuthState << "Name"sv << Kv.second->Name; + AuthState << "Url"sv << Kv.second->Url; + AuthState << "ClientId"sv << Kv.second->ClientId; + AuthState.EndObject(); + } + AuthState.EndArray(); + } + } + + { + std::unique_lock _(m_TokenMutex); + + AuthState.BeginArray("OpenIdTokens"); + if (m_OpenIdTokens.size() > 0) + { + for (const auto& Kv : m_OpenIdTokens) + { + AuthState.BeginObject(); + AuthState << "ProviderName"sv << Kv.first; + AuthState << "IdentityToken"sv << Kv.second.IdentityToken; + AuthState << "RefreshToken"sv << Kv.second.RefreshToken; + AuthState << "AccessToken"sv << Kv.second.AccessToken; + AuthState << "ExpireTime"sv << Kv.second.ExpireTime; + AuthState.EndObject(); + } + } + AuthState.EndArray(); + } + + std::filesystem::create_directories(m_Config.RootDirectory); + WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer()); + } + struct OpenIdProvider { - std::string Name; - std::string Url; - std::string ClientId; + std::string Name; + std::string Url; + std::string ClientId; + std::unique_ptr<OidcClient> HttpClient; }; struct OpenIdToken { - std::string_view IdentityToken; - std::string_view RefreshToken; - std::string_view AccessToken; - double ExpireTime{}; + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + double ExpireTime{}; }; using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; @@ -103,6 +265,7 @@ private: spdlog::logger& Log() { return m_Log; } + AuthConfig m_Config; spdlog::logger& m_Log; OpenIdProviderMap m_OpenIdProviders; OpenIdTokenMap m_OpenIdTokens; |