diff options
| -rw-r--r-- | zenserver/auth/authmgr.cpp | 57 | ||||
| -rw-r--r-- | zenserver/auth/authmgr.h | 6 | ||||
| -rw-r--r-- | zenserver/auth/authservice.cpp | 26 |
3 files changed, 80 insertions, 9 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp index af579d950..4c97693f9 100644 --- a/zenserver/auth/authmgr.cpp +++ b/zenserver/auth/authmgr.cpp @@ -14,10 +14,7 @@ namespace zen { class AuthMgrImpl final : public AuthMgr { public: - AuthMgrImpl(const AuthConfig& Config) : m_Log(logging::Get("auth")) - { - ZEN_UNUSED(Config); - } + AuthMgrImpl(const AuthConfig& Config) : m_Log(logging::Get("auth")) { ZEN_UNUSED(Config); } virtual ~AuthMgrImpl() {} @@ -46,6 +43,45 @@ public: ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); } + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final + { + if (Params.ProviderName.empty()) + { + ZEN_WARN("trying add OpenID token with invalid provider name"); + return false; + } + + if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty()) + { + ZEN_WARN("trying add invalid OpenID token"); + return false; + } + + bool IsNew = false; + + { + std::unique_lock _(m_TokenMutex); + + const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), + OpenIdToken{.IdentityToken = std::string(Params.IdentityToken), + .RefreshToken = Params.RefreshToken, + .AccessToken = Params.AccessToken}); + + IsNew = InsertResult.second; + } + + if (IsNew) + { + ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); + } + else + { + ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); + } + + return true; + } + private: struct OpenIdProvider { @@ -54,13 +90,24 @@ private: std::string ClientId; }; + struct OpenIdToken + { + std::string_view IdentityToken; + std::string_view RefreshToken; + std::string_view AccessToken; + double ExpireTime{}; + }; + using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; + using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; spdlog::logger& Log() { return m_Log; } spdlog::logger& m_Log; - std::mutex m_ProviderMutex; OpenIdProviderMap m_OpenIdProviders; + OpenIdTokenMap m_OpenIdTokens; + std::mutex m_ProviderMutex; + std::shared_mutex m_TokenMutex; }; std::unique_ptr<AuthMgr> diff --git a/zenserver/auth/authmgr.h b/zenserver/auth/authmgr.h index 33bd15ee9..16d4071bf 100644 --- a/zenserver/auth/authmgr.h +++ b/zenserver/auth/authmgr.h @@ -23,7 +23,13 @@ public: struct AddOpenIdTokenParams { + std::string_view ProviderName; + std::string_view IdentityToken; + std::string_view RefreshToken; + std::string_view AccessToken; }; + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0; }; struct AuthConfig diff --git a/zenserver/auth/authservice.cpp b/zenserver/auth/authservice.cpp index 8200b9c9b..20ea252fa 100644 --- a/zenserver/auth/authservice.cpp +++ b/zenserver/auth/authservice.cpp @@ -1,7 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <auth/authservice.h> #include <auth/authmgr.h> +#include <auth/authservice.h> #include <zencore/compactbinarybuilder.h> #include <zencore/string.h> @@ -41,10 +41,28 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); } - const std::string RefreshToken = TokenInfo["RefreshToken"].string_value(); - const std::string AccessToken = TokenInfo["AccessToken"].string_value(); + const std::string IdentityToken = TokenInfo["IdentityToken"].string_value(); + const std::string RefreshToken = TokenInfo["RefreshToken"].string_value(); + const std::string AccessToken = TokenInfo["AccessToken"].string_value(); + + const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv, + .IdentityToken = IdentityToken, + .RefreshToken = RefreshToken, + .AccessToken = AccessToken}); + + if (Ok) + { + ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest); + } + else + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv + << "Invalid token"sv; - ServerRequest.WriteResponse(HttpResponseCode::OK); + ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } }, HttpVerb::kPost); } |