diff options
Diffstat (limited to 'zenserver')
| -rw-r--r-- | zenserver/auth/authmgr.cpp | 165 | ||||
| -rw-r--r-- | zenserver/auth/authmgr.h | 7 | ||||
| -rw-r--r-- | zenserver/auth/authservice.cpp | 6 | ||||
| -rw-r--r-- | zenserver/auth/oidc.cpp | 14 | ||||
| -rw-r--r-- | zenserver/auth/oidc.h | 2 | ||||
| -rw-r--r-- | zenserver/upstream/jupiter.cpp | 19 | ||||
| -rw-r--r-- | zenserver/upstream/jupiter.h | 2 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.cpp | 7 | ||||
| -rw-r--r-- | zenserver/zenserver.cpp | 5 |
9 files changed, 134 insertions, 93 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp index 4d19316dd..9cdd5ed02 100644 --- a/zenserver/auth/authmgr.cpp +++ b/zenserver/auth/authmgr.cpp @@ -9,7 +9,6 @@ #include <zencore/filesystem.h> #include <zencore/logging.h> -#include <chrono> #include <condition_variable> #include <memory> #include <shared_mutex> @@ -24,6 +23,10 @@ using namespace std::literals; class AuthMgrImpl final : public AuthMgr { + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + using Seconds = std::chrono::seconds; + public: AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); } @@ -77,7 +80,7 @@ public: return false; } - if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty()) + if (Params.RefreshToken.empty()) { ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'"); return false; @@ -94,12 +97,14 @@ public: bool IsNew = false; { + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + std::unique_lock _(m_TokenMutex); - const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), - OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, - .RefreshToken = RefreshResult.RefreshToken, - .AccessToken = RefreshResult.AccessToken}); + const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), std::move(Token)); IsNew = InsertResult.second; } @@ -124,7 +129,7 @@ public: { const OpenIdToken& Token = It->second; - return {.AccessToken = fmt::format("Bearer {}", Token.AccessToken)}; + return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime}; } return {}; @@ -160,102 +165,118 @@ private: void LoadState() { - FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv); - - if (Result.ErrorCode) + try { - return; - } + FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv); - IoBuffer Buffer = Result.Flatten(); + if (Result.ErrorCode) + { + return; + } - const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); + IoBuffer Buffer = Result.Flatten(); - if (ValidationError != CbValidateError::None) - { - ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); - return; - } + const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); - if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) - { - for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + if (ValidationError != CbValidateError::None) { - CbObjectView ProviderObj = ProviderView.AsObjectView(); - - std::string_view ProviderName = ProviderObj["Name"].AsString(); - std::string_view Url = ProviderObj["Url"].AsString(); - std::string_view ClientId = ProviderObj["ClientId"].AsString(); - - AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + return; } - for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) + if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) { - CbObjectView TokenObj = TokenView.AsObjectView(); + for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + { + CbObjectView ProviderObj = ProviderView.AsObjectView(); - std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); - std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString(); - std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); - std::string_view AccessToken = TokenObj["AccessToken"sv].AsString(); + std::string_view ProviderName = ProviderObj["Name"].AsString(); + std::string_view Url = ProviderObj["Url"].AsString(); + std::string_view ClientId = ProviderObj["ClientId"].AsString(); - const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, - .IdentityToken = IdentityToken, - .RefreshToken = RefreshToken, - .AccessToken = AccessToken}); + AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + } - if (!Ok) + for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) { - ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + CbObjectView TokenObj = TokenView.AsObjectView(); + + std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); + std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); + + const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken}); + + if (!Ok) + { + ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + } } } } + catch (std::exception& Err) + { + ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); + + { + std::unique_lock _(m_ProviderMutex); + m_OpenIdProviders.clear(); + } + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.clear(); + } + } } void SaveState() { - CbObjectWriter AuthState; - + try { - std::unique_lock _(m_ProviderMutex); + CbObjectWriter AuthState; - if (m_OpenIdProviders.size() > 0) { - AuthState.BeginArray("OpenIdProviders"); - for (const auto& Kv : m_OpenIdProviders) + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.size() > 0) { - AuthState.BeginObject(); - AuthState << "Name"sv << Kv.second->Name; - AuthState << "Url"sv << Kv.second->Url; - AuthState << "ClientId"sv << Kv.second->ClientId; - AuthState.EndObject(); + AuthState.BeginArray("OpenIdProviders"); + for (const auto& Kv : m_OpenIdProviders) + { + AuthState.BeginObject(); + AuthState << "Name"sv << Kv.second->Name; + AuthState << "Url"sv << Kv.second->Url; + AuthState << "ClientId"sv << Kv.second->ClientId; + AuthState.EndObject(); + } + AuthState.EndArray(); } - AuthState.EndArray(); } - } - { - std::unique_lock _(m_TokenMutex); - - AuthState.BeginArray("OpenIdTokens"); - if (m_OpenIdTokens.size() > 0) { - for (const auto& Kv : m_OpenIdTokens) + std::unique_lock _(m_TokenMutex); + + AuthState.BeginArray("OpenIdTokens"); + if (m_OpenIdTokens.size() > 0) { - AuthState.BeginObject(); - AuthState << "ProviderName"sv << Kv.first; - AuthState << "IdentityToken"sv << Kv.second.IdentityToken; - AuthState << "RefreshToken"sv << Kv.second.RefreshToken; - AuthState << "AccessToken"sv << Kv.second.AccessToken; - AuthState << "ExpireTime"sv << Kv.second.ExpireTime; - AuthState.EndObject(); + for (const auto& Kv : m_OpenIdTokens) + { + AuthState.BeginObject(); + AuthState << "ProviderName"sv << Kv.first; + AuthState << "RefreshToken"sv << Kv.second.RefreshToken; + AuthState.EndObject(); + } } + AuthState.EndArray(); } - AuthState.EndArray(); - } - std::filesystem::create_directories(m_Config.RootDirectory); - WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer()); + std::filesystem::create_directories(m_Config.RootDirectory); + WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer()); + } + catch (std::exception& Err) + { + ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what()); + } } struct OpenIdProvider @@ -271,7 +292,7 @@ private: std::string IdentityToken; std::string RefreshToken; std::string AccessToken; - double ExpireTime{}; + TimePoint ExpireTime{}; }; using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; diff --git a/zenserver/auth/authmgr.h b/zenserver/auth/authmgr.h index 1138d9eff..59dc1725d 100644 --- a/zenserver/auth/authmgr.h +++ b/zenserver/auth/authmgr.h @@ -2,6 +2,7 @@ #include <zencore/string.h> +#include <chrono> #include <filesystem> #include <memory> @@ -24,17 +25,15 @@ public: struct AddOpenIdTokenParams { std::string_view ProviderName; - std::string_view IdentityToken; std::string_view RefreshToken; - std::string_view AccessToken; }; - virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0; struct OpenIdAccessToken { - std::string AccessToken; + std::string AccessToken; + std::chrono::system_clock::time_point ExpireTime{}; }; virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0; diff --git a/zenserver/auth/authservice.cpp b/zenserver/auth/authservice.cpp index 4e6f496a6..47a757001 100644 --- a/zenserver/auth/authservice.cpp +++ b/zenserver/auth/authservice.cpp @@ -45,12 +45,8 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) const std::string IdentityToken = TokenInfo["IdentityToken"].string_value(); const std::string RefreshToken = TokenInfo["RefreshToken"].string_value(); - const std::string AccessToken = TokenInfo["AccessToken"].string_value(); - const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv, - .IdentityToken = IdentityToken, - .RefreshToken = RefreshToken, - .AccessToken = AccessToken}); + const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv, .RefreshToken = RefreshToken}); if (Ok) { diff --git a/zenserver/auth/oidc.cpp b/zenserver/auth/oidc.cpp index 2f53f1bae..17b5bac08 100644 --- a/zenserver/auth/oidc.cpp +++ b/zenserver/auth/oidc.cpp @@ -115,13 +115,13 @@ OidcClient::RefreshToken(std::string_view RefreshToken) return {.Reason = std::move(JsonError)}; } - return {.TokenType = Json["token_type"].string_value(), - .AccessToken = Json["access_token"].string_value(), - .RefreshToken = Json["refresh_token"].string_value(), - .IdentityToken = Json["id_token"].string_value(), - .Scope = Json["scope"].string_value(), - .ExpiresIn = Json["scope"].number_value(), - .Ok = true}; + return {.TokenType = Json["token_type"].string_value(), + .AccessToken = Json["access_token"].string_value(), + .RefreshToken = Json["refresh_token"].string_value(), + .IdentityToken = Json["id_token"].string_value(), + .Scope = Json["scope"].string_value(), + .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()), + .Ok = true}; } } // namespace zen diff --git a/zenserver/auth/oidc.h b/zenserver/auth/oidc.h index b08181bfd..4ed06317b 100644 --- a/zenserver/auth/oidc.h +++ b/zenserver/auth/oidc.h @@ -39,7 +39,7 @@ public: std::string IdentityToken; std::string Scope; std::string Reason; - double ExpiresIn{}; + int64_t ExpiresInSeconds{}; bool Ok = false; }; diff --git a/zenserver/upstream/jupiter.cpp b/zenserver/upstream/jupiter.cpp index b377ac629..6fc952bab 100644 --- a/zenserver/upstream/jupiter.cpp +++ b/zenserver/upstream/jupiter.cpp @@ -854,6 +854,25 @@ CloudCacheTokenProvider::MakeFromOAuthClientCredentials(const OAuthClientCredent return std::make_unique<OAuthClientCredentialsTokenProvider>(Params); } +class CallbackTokenProvider final : public CloudCacheTokenProvider +{ +public: + CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {} + + virtual ~CallbackTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); } + +private: + std::function<CloudCacheAccessToken()> m_Callback; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::MakeFromCallback(std::function<CloudCacheAccessToken()>&& Callback) +{ + return std::make_unique<CallbackTokenProvider>(std::move(Callback)); +} + CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider) : m_Log(zen::logging::Get("jupiter")) , m_ServiceUrl(Options.ServiceUrl) diff --git a/zenserver/upstream/jupiter.h b/zenserver/upstream/jupiter.h index 31224500a..1b9650bdf 100644 --- a/zenserver/upstream/jupiter.h +++ b/zenserver/upstream/jupiter.h @@ -157,6 +157,8 @@ public: }; static std::unique_ptr<CloudCacheTokenProvider> MakeFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params); + + static std::unique_ptr<CloudCacheTokenProvider> MakeFromCallback(std::function<CloudCacheAccessToken()>&& Callback); }; struct CloudCacheClientOptions diff --git a/zenserver/upstream/upstreamcache.cpp b/zenserver/upstream/upstreamcache.cpp index 232ed3031..58c025b4f 100644 --- a/zenserver/upstream/upstreamcache.cpp +++ b/zenserver/upstream/upstreamcache.cpp @@ -111,8 +111,6 @@ namespace detail { return {.State = UpstreamEndpointState::kOk}; } - const AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken("Okta"); - CloudCacheSession Session(m_Client); const CloudCacheResult Result = Session.Authenticate(); @@ -1415,11 +1413,12 @@ private: if (Status.State == UpstreamEndpointState::kOk) { - ZEN_INFO("health check endpoint '{} - {}' OK", Info.Name, Info.Url); + ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url); } else { - ZEN_WARN("health check endpoint '{} - {}' FAILED, reason '{}'", Info.Name, Info.Url, Status.Reason); + const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason); + ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason); } } } diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp index 2c9610866..8d408eb90 100644 --- a/zenserver/zenserver.cpp +++ b/zenserver/zenserver.cpp @@ -793,6 +793,11 @@ ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions) if (!Options.ServiceUrl.empty()) { + TokenProvider = CloudCacheTokenProvider::MakeFromCallback([this]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr->GetOpenIdAccessToken("Okta"sv); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint = zen::MakeJupiterUpstreamEndpoint(Options, std::move(TokenProvider), *m_AuthMgr); m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint)); |