diff options
| -rw-r--r-- | zenserver/auth/authmgr.cpp | 296 | ||||
| -rw-r--r-- | zenserver/auth/authmgr.h | 50 | ||||
| -rw-r--r-- | zenserver/auth/authservice.cpp | 57 | ||||
| -rw-r--r-- | zenserver/auth/authservice.h | 5 | ||||
| -rw-r--r-- | zenserver/auth/oidc.cpp | 127 | ||||
| -rw-r--r-- | zenserver/auth/oidc.h | 74 | ||||
| -rw-r--r-- | zenserver/config.cpp | 1 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.cpp | 11 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.h | 3 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamservice.cpp | 6 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamservice.h | 6 | ||||
| -rw-r--r-- | zenserver/zenserver.cpp | 12 |
12 files changed, 632 insertions, 16 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp new file mode 100644 index 000000000..28e128fc0 --- /dev/null +++ b/zenserver/auth/authmgr.cpp @@ -0,0 +1,296 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <auth/authmgr.h> +#include <auth/oidc.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> + +#include <chrono> +#include <condition_variable> +#include <memory> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +#include <fmt/format.h> + +namespace zen { + +using namespace std::literals; + +class AuthMgrImpl final : public AuthMgr +{ +public: + AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); } + + virtual ~AuthMgrImpl() { SaveState(); } + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final + { + if (OpenIdProviderExist(Params.Name)) + { + return; + } + + std::unique_ptr<OidcClient> Client = + std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + + if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) + { + ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); + return; + } + + std::string NewProviderName = std::string(Params.Name); + + OpenIdProvider* NewProvider = nullptr; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.contains(NewProviderName)) + { + return; + } + + auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>()); + NewProvider = InsertResult.first->second.get(); + } + + NewProvider->Name = std::string(Params.Name); + NewProvider->Url = std::string(Params.Url); + NewProvider->ClientId = std::string(Params.ClientId); + NewProvider->HttpClient = std::move(Client); + + ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); + } + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final + { + if (Params.ProviderName.empty()) + { + ZEN_WARN("trying add OpenID token with invalid provider name"); + return false; + } + + if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty()) + { + ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'"); + return false; + } + + auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); + + if (RefreshResult.Ok == false) + { + ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); + return false; + } + + bool IsNew = false; + + { + std::unique_lock _(m_TokenMutex); + + const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), + OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = RefreshResult.AccessToken}); + + IsNew = InsertResult.second; + } + + if (IsNew) + { + ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); + } + else + { + ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); + } + + return true; + } + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final + { + std::unique_lock _(m_TokenMutex); + + if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) + { + const OpenIdToken& Token = It->second; + + return {.AccessToken = Token.AccessToken}; + } + + return {}; + } + +private: + bool OpenIdProviderExist(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + + return m_OpenIdProviders.contains(std::string(ProviderName)); + } + + OidcClient& GetOpenIdClient(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + } + + OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) + { + if (OpenIdProviderExist(ProviderName) == false) + { + return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; + } + + OidcClient& Client = GetOpenIdClient(ProviderName); + + return Client.RefreshToken(RefreshToken); + } + + void Shutdown() { SaveState(); } + + void LoadState() + { + FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv); + + if (Result.ErrorCode) + { + return; + } + + IoBuffer Buffer = Result.Flatten(); + + const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + return; + } + + if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) + { + for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + { + CbObjectView ProviderObj = ProviderView.AsObjectView(); + + std::string_view ProviderName = ProviderObj["Name"].AsString(); + std::string_view Url = ProviderObj["Url"].AsString(); + std::string_view ClientId = ProviderObj["ClientId"].AsString(); + + AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + } + + for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) + { + CbObjectView TokenObj = TokenView.AsObjectView(); + + std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); + std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString(); + std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); + std::string_view AccessToken = TokenObj["AccessToken"sv].AsString(); + + const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, + .IdentityToken = IdentityToken, + .RefreshToken = RefreshToken, + .AccessToken = AccessToken}); + + if (!Ok) + { + ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + } + } + } + } + + void SaveState() + { + CbObjectWriter AuthState; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.size() > 0) + { + AuthState.BeginArray("OpenIdProviders"); + for (const auto& Kv : m_OpenIdProviders) + { + AuthState.BeginObject(); + AuthState << "Name"sv << Kv.second->Name; + AuthState << "Url"sv << Kv.second->Url; + AuthState << "ClientId"sv << Kv.second->ClientId; + AuthState.EndObject(); + } + AuthState.EndArray(); + } + } + + { + std::unique_lock _(m_TokenMutex); + + AuthState.BeginArray("OpenIdTokens"); + if (m_OpenIdTokens.size() > 0) + { + for (const auto& Kv : m_OpenIdTokens) + { + AuthState.BeginObject(); + AuthState << "ProviderName"sv << Kv.first; + AuthState << "IdentityToken"sv << Kv.second.IdentityToken; + AuthState << "RefreshToken"sv << Kv.second.RefreshToken; + AuthState << "AccessToken"sv << Kv.second.AccessToken; + AuthState << "ExpireTime"sv << Kv.second.ExpireTime; + AuthState.EndObject(); + } + } + AuthState.EndArray(); + } + + std::filesystem::create_directories(m_Config.RootDirectory); + WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer()); + } + + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + std::unique_ptr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + double ExpireTime{}; + }; + + using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; + using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; + + spdlog::logger& Log() { return m_Log; } + + AuthConfig m_Config; + spdlog::logger& m_Log; + OpenIdProviderMap m_OpenIdProviders; + OpenIdTokenMap m_OpenIdTokens; + std::mutex m_ProviderMutex; + std::shared_mutex m_TokenMutex; +}; + +std::unique_ptr<AuthMgr> +MakeAuthMgr(const AuthConfig& Config) +{ + return std::make_unique<AuthMgrImpl>(Config); +} + +} // namespace zen diff --git a/zenserver/auth/authmgr.h b/zenserver/auth/authmgr.h new file mode 100644 index 000000000..1138d9eff --- /dev/null +++ b/zenserver/auth/authmgr.h @@ -0,0 +1,50 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/string.h> + +#include <filesystem> +#include <memory> + +namespace zen { + +class AuthMgr +{ +public: + virtual ~AuthMgr() = default; + + struct AddOpenIdProviderParams + { + std::string_view Name; + std::string_view Url; + std::string_view ClientId; + }; + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) = 0; + + struct AddOpenIdTokenParams + { + std::string_view ProviderName; + std::string_view IdentityToken; + std::string_view RefreshToken; + std::string_view AccessToken; + }; + + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0; + + struct OpenIdAccessToken + { + std::string AccessToken; + }; + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0; +}; + +struct AuthConfig +{ + std::filesystem::path RootDirectory; +}; + +std::unique_ptr<AuthMgr> MakeAuthMgr(const AuthConfig& Config); + +} // namespace zen diff --git a/zenserver/auth/authservice.cpp b/zenserver/auth/authservice.cpp index eecad45bf..4e6f496a6 100644 --- a/zenserver/auth/authservice.cpp +++ b/zenserver/auth/authservice.cpp @@ -1,19 +1,70 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <auth/authmgr.h> #include <auth/authservice.h> + +#include <zencore/compactbinarybuilder.h> #include <zencore/string.h> +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + namespace zen { using namespace std::literals; -HttpAuthService::HttpAuthService() +HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) { m_Router.RegisterRoute( "token", - [](HttpRouterRequest& RouterRequest) { + [this](HttpRouterRequest& RouterRequest) { HttpServerRequest& ServerRequest = RouterRequest.ServerRequest(); - ServerRequest.WriteResponse(HttpResponseCode::OK); + + const HttpContentType ContentType = ServerRequest.RequestContentType(); + + if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); + } + + const IoBuffer Body = ServerRequest.ReadPayload(); + + std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()); + std::string JsonError; + json11::Json TokenInfo = json11::Json::parse(JsonText, JsonError); + + if (!JsonError.empty()) + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv << JsonError; + + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } + + const std::string IdentityToken = TokenInfo["IdentityToken"].string_value(); + const std::string RefreshToken = TokenInfo["RefreshToken"].string_value(); + const std::string AccessToken = TokenInfo["AccessToken"].string_value(); + + const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv, + .IdentityToken = IdentityToken, + .RefreshToken = RefreshToken, + .AccessToken = AccessToken}); + + if (Ok) + { + ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest); + } + else + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv + << "Invalid token"sv; + + ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } }, HttpVerb::kPost); } diff --git a/zenserver/auth/authservice.h b/zenserver/auth/authservice.h index 30b2b5864..64b86e21f 100644 --- a/zenserver/auth/authservice.h +++ b/zenserver/auth/authservice.h @@ -6,16 +6,19 @@ namespace zen { +class AuthMgr; + class HttpAuthService final : public zen::HttpService { public: - HttpAuthService(); + HttpAuthService(AuthMgr& AuthMgr); virtual ~HttpAuthService(); virtual const char* BaseUri() const override; virtual void HandleRequest(zen::HttpServerRequest& Request) override; private: + AuthMgr& m_AuthMgr; HttpRequestRouter m_Router; }; diff --git a/zenserver/auth/oidc.cpp b/zenserver/auth/oidc.cpp new file mode 100644 index 000000000..2f53f1bae --- /dev/null +++ b/zenserver/auth/oidc.cpp @@ -0,0 +1,127 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <auth/oidc.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace details { + + using StringArray = std::vector<std::string>; + + StringArray ToStringArray(const json11::Json JsonArray) + { + StringArray Result; + + const auto& Items = JsonArray.array_items(); + + for (const auto& Item : Items) + { + Result.push_back(Item.string_value()); + } + + return Result; + } + +} // namespace details + +using namespace std::literals; + +OidcClient::OidcClient(const OidcClient::Options& Options) +{ + m_BaseUrl = std::string(Options.BaseUrl); + m_ClientId = std::string(Options.ClientId); +} + +OidcClient::InitResult +OidcClient::Initialize() +{ + ExtendableStringBuilder<256> Uri; + Uri << m_BaseUrl << "/.well-known/openid-configuration"sv; + + cpr::Session Session; + + Session.SetOption(cpr::Url{Uri.c_str()}); + + cpr::Response Response = Session.Get(); + + if (Response.error) + { + return {.Reason = std::move(Response.error.message)}; + } + + if (Response.status_code != 200) + { + return {.Reason = std::move(Response.reason)}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {.Reason = std::move(JsonError)}; + } + + m_Config = {.Issuer = Json["issuer"].string_value(), + .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(), + .TokenEndpoint = Json["token_endpoint"].string_value(), + .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), + .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .JwksUri = Json["jwks_uri"].string_value(), + .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), + .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), + .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]), + .SupportedScopes = details::ToStringArray(Json["scopes_supported"]), + .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]), + .SupportedClaims = details::ToStringArray(Json["claims_supported"])}; + + return {.Ok = true}; +} + +OidcClient::RefreshTokenResult +OidcClient::RefreshToken(std::string_view RefreshToken) +{ + const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + + cpr::Session Session; + + Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()}); + Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}); + Session.SetBody(cpr::Body{Body.data(), Body.size()}); + + cpr::Response Response = Session.Post(); + + if (Response.error) + { + return {.Reason = std::move(Response.error.message)}; + } + + if (Response.status_code != 200) + { + return {.Reason = std::move(Response.reason)}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {.Reason = std::move(JsonError)}; + } + + return {.TokenType = Json["token_type"].string_value(), + .AccessToken = Json["access_token"].string_value(), + .RefreshToken = Json["refresh_token"].string_value(), + .IdentityToken = Json["id_token"].string_value(), + .Scope = Json["scope"].string_value(), + .ExpiresIn = Json["scope"].number_value(), + .Ok = true}; +} + +} // namespace zen diff --git a/zenserver/auth/oidc.h b/zenserver/auth/oidc.h new file mode 100644 index 000000000..b08181bfd --- /dev/null +++ b/zenserver/auth/oidc.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/string.h> + +#include <vector> + +namespace zen { + +class OidcClient +{ +public: + struct Options + { + std::string_view BaseUrl; + std::string_view ClientId; + }; + + OidcClient(const Options& Options); + ~OidcClient() = default; + + OidcClient(const OidcClient&) = delete; + OidcClient& operator=(const OidcClient&) = delete; + + struct Result + { + std::string Reason; + bool Ok = false; + }; + + using InitResult = Result; + + InitResult Initialize(); + + struct RefreshTokenResult + { + std::string TokenType; + std::string AccessToken; + std::string RefreshToken; + std::string IdentityToken; + std::string Scope; + std::string Reason; + double ExpiresIn{}; + bool Ok = false; + }; + + RefreshTokenResult RefreshToken(std::string_view RefreshToken); + +private: + using StringArray = std::vector<std::string>; + + struct OpenIdConfiguration + { + std::string Issuer; + std::string AuthorizationEndpoint; + std::string TokenEndpoint; + std::string UserInfoEndpoint; + std::string RegistrationEndpoint; + std::string EndSessionEndpoint; + std::string DeviceAuthorizationEndpoint; + std::string JwksUri; + StringArray SupportedResponseTypes; + StringArray SupportedResponseModes; + StringArray SupportedGrantTypes; + StringArray SupportedScopes; + StringArray SupportedTokenEndpointAuthMethods; + StringArray SupportedClaims; + }; + + std::string m_BaseUrl; + std::string m_ClientId; + OpenIdConfiguration m_Config; +}; + +} // namespace zen diff --git a/zenserver/config.cpp b/zenserver/config.cpp index a36ce5f33..6fd1c3bea 100644 --- a/zenserver/config.cpp +++ b/zenserver/config.cpp @@ -323,6 +323,7 @@ ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions) try { auto result = options.parse(argc, argv); + ServerOptions.DataDir = DataDir; if (result.count("help")) { diff --git a/zenserver/upstream/upstreamcache.cpp b/zenserver/upstream/upstreamcache.cpp index 657cfb729..206787bf7 100644 --- a/zenserver/upstream/upstreamcache.cpp +++ b/zenserver/upstream/upstreamcache.cpp @@ -18,6 +18,7 @@ #include <zenstore/cas.h> #include <zenstore/cidstore.h> +#include <auth/authmgr.h> #include "cache/structuredcachestore.h" #include "diag/logging.h" @@ -84,8 +85,9 @@ namespace detail { class JupiterUpstreamEndpoint final : public UpstreamEndpoint { public: - JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options) - : m_Log(zen::logging::Get("upstream")) + JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, AuthMgr& Mgr) + : m_AuthMgr(Mgr) + , m_Log(zen::logging::Get("upstream")) , m_UseLegacyDdc(Options.UseLegacyDdc) { ZEN_ASSERT(!Options.Name.empty()); @@ -513,6 +515,7 @@ namespace detail { spdlog::logger& Log() { return m_Log; } + AuthMgr& m_AuthMgr; spdlog::logger& m_Log; UpstreamEndpointInfo m_Info; UpstreamStatus m_Status; @@ -1485,9 +1488,9 @@ MakeUpstreamCache(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore } std::unique_ptr<UpstreamEndpoint> -MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options) +MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, AuthMgr& Mgr) { - return std::make_unique<detail::JupiterUpstreamEndpoint>(Options); + return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, Mgr); } std::unique_ptr<UpstreamEndpoint> diff --git a/zenserver/upstream/upstreamcache.h b/zenserver/upstream/upstreamcache.h index 2087b1fba..5bc9f58d7 100644 --- a/zenserver/upstream/upstreamcache.h +++ b/zenserver/upstream/upstreamcache.h @@ -15,6 +15,7 @@ namespace zen { +class AuthMgr; class CbObjectView; class CbPackage; class CbObjectWriter; @@ -203,7 +204,7 @@ public: std::unique_ptr<UpstreamCache> MakeUpstreamCache(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore); -std::unique_ptr<UpstreamEndpoint> MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options); +std::unique_ptr<UpstreamEndpoint> MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, AuthMgr& Mgr); std::unique_ptr<UpstreamEndpoint> MakeZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options); diff --git a/zenserver/upstream/upstreamservice.cpp b/zenserver/upstream/upstreamservice.cpp index 1cfd1df85..c8176779e 100644 --- a/zenserver/upstream/upstreamservice.cpp +++ b/zenserver/upstream/upstreamservice.cpp @@ -1,9 +1,11 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <auth/authmgr.h> #include <upstream/jupiter.h> #include <upstream/upstreamcache.h> #include <upstream/upstreamservice.h> #include <upstream/zen.h> + #include <zencore/compactbinarybuilder.h> #include <zencore/string.h> @@ -64,7 +66,7 @@ namespace { } } // namespace -HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream) : m_Upstream(Upstream) +HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr) { m_Router.RegisterRoute( "endpoints", @@ -179,7 +181,7 @@ HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream) : m_Upstream(U .OAuthSecret = OAuthSecret, .AccessToken = OAuthToken}; - Endpoint = zen::MakeJupiterUpstreamEndpoint(Options); + Endpoint = zen::MakeJupiterUpstreamEndpoint(Options, m_AuthMgr); } m_Upstream.RegisterEndpoint(std::move(Endpoint)); diff --git a/zenserver/upstream/upstreamservice.h b/zenserver/upstream/upstreamservice.h index 0a42198c2..f1da03c8c 100644 --- a/zenserver/upstream/upstreamservice.h +++ b/zenserver/upstream/upstreamservice.h @@ -6,20 +6,22 @@ namespace zen { +class AuthMgr; class UpstreamCache; class HttpUpstreamService final : public zen::HttpService { public: - HttpUpstreamService(UpstreamCache& Upstream); + HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr); virtual ~HttpUpstreamService(); virtual const char* BaseUri() const override; virtual void HandleRequest(zen::HttpServerRequest& Request) override; private: - HttpRequestRouter m_Router; UpstreamCache& m_Upstream; + AuthMgr& m_AuthMgr; + HttpRequestRouter m_Router; }; } // namespace zen diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp index c6a27ec44..bd4a3c1cf 100644 --- a/zenserver/zenserver.cpp +++ b/zenserver/zenserver.cpp @@ -99,6 +99,7 @@ ZEN_THIRD_PARTY_INCLUDES_END // #include "admin/admin.h" +#include "auth/authmgr.h" #include "auth/authservice.h" #include "cache/structuredcache.h" #include "cache/structuredcachestore.h" @@ -203,9 +204,13 @@ public: m_Http = zen::CreateHttpServer(ServerOptions.HttpServerClass); int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort); - m_AuthService = std::make_unique<zen::HttpAuthService>(); + m_AuthMgr = MakeAuthMgr({.RootDirectory = m_DataRoot / "auth"}); + m_AuthService = std::make_unique<zen::HttpAuthService>(*m_AuthMgr); m_Http->RegisterService(*m_AuthService); + m_AuthMgr->AddOpenIdProvider( + {.Name = "Okta"sv, .Url = "https://epicgames.okta.com/oauth2/auso645ojjWVdRI3d0x7"sv, .ClientId = "0oapq1knoglGFqQvr0x7"sv}); + m_Http->RegisterService(m_HealthService); m_Http->RegisterService(m_StatsService); m_Http->RegisterService(m_StatusService); @@ -529,6 +534,7 @@ private: } zen::Ref<zen::HttpServer> m_Http; + std::unique_ptr<zen::AuthMgr> m_AuthMgr; std::unique_ptr<zen::HttpAuthService> m_AuthService; zen::HttpStatusService m_StatusService; zen::HttpStatsService m_StatsService; @@ -696,7 +702,7 @@ ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions) } m_UpstreamCache = zen::MakeUpstreamCache(UpstreamOptions, *m_CacheStore, *m_CidStore); - m_UpstreamService = std::make_unique<HttpUpstreamService>(*m_UpstreamCache); + m_UpstreamService = std::make_unique<HttpUpstreamService>(*m_UpstreamCache, *m_AuthMgr); m_UpstreamCache->Initialize(); if (ServerOptions.UpstreamCacheConfig.CachePolicy != UpstreamCachePolicy::Disabled) @@ -785,7 +791,7 @@ ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions) if (!Options.ServiceUrl.empty()) { - std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint = zen::MakeJupiterUpstreamEndpoint(Options); + std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint = zen::MakeJupiterUpstreamEndpoint(Options, *m_AuthMgr); m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint)); } } |