aboutsummaryrefslogtreecommitdiff
path: root/zenserver
diff options
context:
space:
mode:
Diffstat (limited to 'zenserver')
-rw-r--r--zenserver/auth/authmgr.cpp165
-rw-r--r--zenserver/auth/authmgr.h7
-rw-r--r--zenserver/auth/authservice.cpp6
-rw-r--r--zenserver/auth/oidc.cpp14
-rw-r--r--zenserver/auth/oidc.h2
-rw-r--r--zenserver/upstream/jupiter.cpp19
-rw-r--r--zenserver/upstream/jupiter.h2
-rw-r--r--zenserver/upstream/upstreamcache.cpp7
-rw-r--r--zenserver/zenserver.cpp5
9 files changed, 134 insertions, 93 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp
index 4d19316dd..9cdd5ed02 100644
--- a/zenserver/auth/authmgr.cpp
+++ b/zenserver/auth/authmgr.cpp
@@ -9,7 +9,6 @@
#include <zencore/filesystem.h>
#include <zencore/logging.h>
-#include <chrono>
#include <condition_variable>
#include <memory>
#include <shared_mutex>
@@ -24,6 +23,10 @@ using namespace std::literals;
class AuthMgrImpl final : public AuthMgr
{
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+ using Seconds = std::chrono::seconds;
+
public:
AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); }
@@ -77,7 +80,7 @@ public:
return false;
}
- if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty())
+ if (Params.RefreshToken.empty())
{
ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'");
return false;
@@ -94,12 +97,14 @@ public:
bool IsNew = false;
{
+ auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+
std::unique_lock _(m_TokenMutex);
- const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName),
- OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
- .RefreshToken = RefreshResult.RefreshToken,
- .AccessToken = RefreshResult.AccessToken});
+ const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName), std::move(Token));
IsNew = InsertResult.second;
}
@@ -124,7 +129,7 @@ public:
{
const OpenIdToken& Token = It->second;
- return {.AccessToken = fmt::format("Bearer {}", Token.AccessToken)};
+ return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime};
}
return {};
@@ -160,102 +165,118 @@ private:
void LoadState()
{
- FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv);
-
- if (Result.ErrorCode)
+ try
{
- return;
- }
+ FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv);
- IoBuffer Buffer = Result.Flatten();
+ if (Result.ErrorCode)
+ {
+ return;
+ }
- const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+ IoBuffer Buffer = Result.Flatten();
- if (ValidationError != CbValidateError::None)
- {
- ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
- return;
- }
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
- if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
- {
- for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ if (ValidationError != CbValidateError::None)
{
- CbObjectView ProviderObj = ProviderView.AsObjectView();
-
- std::string_view ProviderName = ProviderObj["Name"].AsString();
- std::string_view Url = ProviderObj["Url"].AsString();
- std::string_view ClientId = ProviderObj["ClientId"].AsString();
-
- AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
}
- for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
{
- CbObjectView TokenObj = TokenView.AsObjectView();
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
- std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
- std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString();
- std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
- std::string_view AccessToken = TokenObj["AccessToken"sv].AsString();
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
- const bool Ok = AddOpenIdToken({.ProviderName = ProviderName,
- .IdentityToken = IdentityToken,
- .RefreshToken = RefreshToken,
- .AccessToken = AccessToken});
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
- if (!Ok)
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
{
- ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
}
}
}
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what());
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+ m_OpenIdProviders.clear();
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+ m_OpenIdTokens.clear();
+ }
+ }
}
void SaveState()
{
- CbObjectWriter AuthState;
-
+ try
{
- std::unique_lock _(m_ProviderMutex);
+ CbObjectWriter AuthState;
- if (m_OpenIdProviders.size() > 0)
{
- AuthState.BeginArray("OpenIdProviders");
- for (const auto& Kv : m_OpenIdProviders)
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
{
- AuthState.BeginObject();
- AuthState << "Name"sv << Kv.second->Name;
- AuthState << "Url"sv << Kv.second->Url;
- AuthState << "ClientId"sv << Kv.second->ClientId;
- AuthState.EndObject();
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
}
- AuthState.EndArray();
}
- }
- {
- std::unique_lock _(m_TokenMutex);
-
- AuthState.BeginArray("OpenIdTokens");
- if (m_OpenIdTokens.size() > 0)
{
- for (const auto& Kv : m_OpenIdTokens)
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
{
- AuthState.BeginObject();
- AuthState << "ProviderName"sv << Kv.first;
- AuthState << "IdentityToken"sv << Kv.second.IdentityToken;
- AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
- AuthState << "AccessToken"sv << Kv.second.AccessToken;
- AuthState << "ExpireTime"sv << Kv.second.ExpireTime;
- AuthState.EndObject();
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState.EndObject();
+ }
}
+ AuthState.EndArray();
}
- AuthState.EndArray();
- }
- std::filesystem::create_directories(m_Config.RootDirectory);
- WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer());
+ std::filesystem::create_directories(m_Config.RootDirectory);
+ WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer());
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what());
+ }
}
struct OpenIdProvider
@@ -271,7 +292,7 @@ private:
std::string IdentityToken;
std::string RefreshToken;
std::string AccessToken;
- double ExpireTime{};
+ TimePoint ExpireTime{};
};
using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
diff --git a/zenserver/auth/authmgr.h b/zenserver/auth/authmgr.h
index 1138d9eff..59dc1725d 100644
--- a/zenserver/auth/authmgr.h
+++ b/zenserver/auth/authmgr.h
@@ -2,6 +2,7 @@
#include <zencore/string.h>
+#include <chrono>
#include <filesystem>
#include <memory>
@@ -24,17 +25,15 @@ public:
struct AddOpenIdTokenParams
{
std::string_view ProviderName;
- std::string_view IdentityToken;
std::string_view RefreshToken;
- std::string_view AccessToken;
};
-
virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0;
struct OpenIdAccessToken
{
- std::string AccessToken;
+ std::string AccessToken;
+ std::chrono::system_clock::time_point ExpireTime{};
};
virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0;
diff --git a/zenserver/auth/authservice.cpp b/zenserver/auth/authservice.cpp
index 4e6f496a6..47a757001 100644
--- a/zenserver/auth/authservice.cpp
+++ b/zenserver/auth/authservice.cpp
@@ -45,12 +45,8 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
const std::string IdentityToken = TokenInfo["IdentityToken"].string_value();
const std::string RefreshToken = TokenInfo["RefreshToken"].string_value();
- const std::string AccessToken = TokenInfo["AccessToken"].string_value();
- const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv,
- .IdentityToken = IdentityToken,
- .RefreshToken = RefreshToken,
- .AccessToken = AccessToken});
+ const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv, .RefreshToken = RefreshToken});
if (Ok)
{
diff --git a/zenserver/auth/oidc.cpp b/zenserver/auth/oidc.cpp
index 2f53f1bae..17b5bac08 100644
--- a/zenserver/auth/oidc.cpp
+++ b/zenserver/auth/oidc.cpp
@@ -115,13 +115,13 @@ OidcClient::RefreshToken(std::string_view RefreshToken)
return {.Reason = std::move(JsonError)};
}
- return {.TokenType = Json["token_type"].string_value(),
- .AccessToken = Json["access_token"].string_value(),
- .RefreshToken = Json["refresh_token"].string_value(),
- .IdentityToken = Json["id_token"].string_value(),
- .Scope = Json["scope"].string_value(),
- .ExpiresIn = Json["scope"].number_value(),
- .Ok = true};
+ return {.TokenType = Json["token_type"].string_value(),
+ .AccessToken = Json["access_token"].string_value(),
+ .RefreshToken = Json["refresh_token"].string_value(),
+ .IdentityToken = Json["id_token"].string_value(),
+ .Scope = Json["scope"].string_value(),
+ .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()),
+ .Ok = true};
}
} // namespace zen
diff --git a/zenserver/auth/oidc.h b/zenserver/auth/oidc.h
index b08181bfd..4ed06317b 100644
--- a/zenserver/auth/oidc.h
+++ b/zenserver/auth/oidc.h
@@ -39,7 +39,7 @@ public:
std::string IdentityToken;
std::string Scope;
std::string Reason;
- double ExpiresIn{};
+ int64_t ExpiresInSeconds{};
bool Ok = false;
};
diff --git a/zenserver/upstream/jupiter.cpp b/zenserver/upstream/jupiter.cpp
index b377ac629..6fc952bab 100644
--- a/zenserver/upstream/jupiter.cpp
+++ b/zenserver/upstream/jupiter.cpp
@@ -854,6 +854,25 @@ CloudCacheTokenProvider::MakeFromOAuthClientCredentials(const OAuthClientCredent
return std::make_unique<OAuthClientCredentialsTokenProvider>(Params);
}
+class CallbackTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {}
+
+ virtual ~CallbackTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); }
+
+private:
+ std::function<CloudCacheAccessToken()> m_Callback;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::MakeFromCallback(std::function<CloudCacheAccessToken()>&& Callback)
+{
+ return std::make_unique<CallbackTokenProvider>(std::move(Callback));
+}
+
CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider)
: m_Log(zen::logging::Get("jupiter"))
, m_ServiceUrl(Options.ServiceUrl)
diff --git a/zenserver/upstream/jupiter.h b/zenserver/upstream/jupiter.h
index 31224500a..1b9650bdf 100644
--- a/zenserver/upstream/jupiter.h
+++ b/zenserver/upstream/jupiter.h
@@ -157,6 +157,8 @@ public:
};
static std::unique_ptr<CloudCacheTokenProvider> MakeFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params);
+
+ static std::unique_ptr<CloudCacheTokenProvider> MakeFromCallback(std::function<CloudCacheAccessToken()>&& Callback);
};
struct CloudCacheClientOptions
diff --git a/zenserver/upstream/upstreamcache.cpp b/zenserver/upstream/upstreamcache.cpp
index 232ed3031..58c025b4f 100644
--- a/zenserver/upstream/upstreamcache.cpp
+++ b/zenserver/upstream/upstreamcache.cpp
@@ -111,8 +111,6 @@ namespace detail {
return {.State = UpstreamEndpointState::kOk};
}
- const AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken("Okta");
-
CloudCacheSession Session(m_Client);
const CloudCacheResult Result = Session.Authenticate();
@@ -1415,11 +1413,12 @@ private:
if (Status.State == UpstreamEndpointState::kOk)
{
- ZEN_INFO("health check endpoint '{} - {}' OK", Info.Name, Info.Url);
+ ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url);
}
else
{
- ZEN_WARN("health check endpoint '{} - {}' FAILED, reason '{}'", Info.Name, Info.Url, Status.Reason);
+ const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason);
+ ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason);
}
}
}
diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp
index 2c9610866..8d408eb90 100644
--- a/zenserver/zenserver.cpp
+++ b/zenserver/zenserver.cpp
@@ -793,6 +793,11 @@ ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions)
if (!Options.ServiceUrl.empty())
{
+ TokenProvider = CloudCacheTokenProvider::MakeFromCallback([this]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr->GetOpenIdAccessToken("Okta"sv);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+
std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint =
zen::MakeJupiterUpstreamEndpoint(Options, std::move(TokenProvider), *m_AuthMgr);
m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint));