aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDan Engelbrecht <[email protected]>2026-03-18 22:28:14 +0100
committerGitHub Enterprise <[email protected]>2026-03-18 22:28:14 +0100
commit59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b (patch)
tree0a65fca5537909f41b5f8b0d87daa7dbcd967677 /src
parentUpdate libcurl to 8.19.0 (#862) (diff)
downloadzen-59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b.tar.xz
zen-59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b.zip
improve auth token refresh (#863)
Authentication callbacks are not thread safe, ensured call sites does single threaded calls
Diffstat (limited to 'src')
-rw-r--r--src/zenhttp/clients/httpclientcommon.h2
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp19
-rw-r--r--src/zenhttp/clients/httpclientcpr.h14
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp6
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp49
-rw-r--r--src/zenhttp/httpclient.cpp37
-rw-r--r--src/zenhttp/httpclient_test.cpp10
-rw-r--r--src/zenhttp/httpclientauth.cpp22
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h26
-rw-r--r--src/zenhttp/include/zenhttp/httpclientauth.h4
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.cpp2
11 files changed, 126 insertions, 65 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
index e95e3a253..e8d969cc8 100644
--- a/src/zenhttp/clients/httpclientcommon.h
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -70,7 +70,7 @@ protected:
const HttpClientSettings m_ConnectionSettings;
std::function<bool()> m_CheckIfAbortFunction;
- const std::optional<HttpClientAccessToken> GetAccessToken();
+ std::optional<std::string> GetAccessToken();
RwLock m_AccessTokenLock;
HttpClientAccessToken m_CachedAccessToken;
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index a52b8f74b..bd6de3ff7 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -399,13 +399,13 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
//////////////////////////////////////////////////////////////////////////
CprHttpClient::Session
-CprHttpClient::AllocSession(const std::string_view BaseUrl,
- const std::string_view ResourcePath,
- const HttpClientSettings& ConnectionSettings,
- const KeyValueMap& AdditionalHeader,
- const KeyValueMap& Parameters,
- const std::string_view SessionId,
- std::optional<HttpClientAccessToken> AccessToken)
+CprHttpClient::AllocSession(const std::string_view BaseUrl,
+ const std::string_view ResourcePath,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<std::string> AccessToken)
{
ZEN_TRACE_CPU("CprHttpClient::AllocSession");
cpr::Session* CprSession = nullptr;
@@ -494,9 +494,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
{
CprSession->UpdateHeader({{"Connection", "close"}});
}
- if (AccessToken)
+
+ if (AccessToken.has_value())
{
- CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
+ CprSession->UpdateHeader({{"Authorization", AccessToken.value()}});
}
if (!Parameters->empty())
{
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
index 009e6fb7a..509ca5ae2 100644
--- a/src/zenhttp/clients/httpclientcpr.h
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -149,13 +149,13 @@ private:
Session& operator=(Session&&) = delete;
};
- Session AllocSession(const std::string_view BaseUrl,
- const std::string_view Url,
- const HttpClientSettings& ConnectionSettings,
- const KeyValueMap& AdditionalHeader,
- const KeyValueMap& Parameters,
- const std::string_view SessionId,
- std::optional<HttpClientAccessToken> AccessToken);
+ Session AllocSession(const std::string_view BaseUrl,
+ const std::string_view Url,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<std::string> AccessToken);
RwLock m_SessionLock;
std::vector<cpr::Session*> m_Sessions;
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
index ec9b7bac6..e76157254 100644
--- a/src/zenhttp/clients/httpclientcurl.cpp
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -290,7 +290,7 @@ HeaderContentType(ZenContentType ContentType)
static curl_slist*
BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
std::string_view SessionId,
- const std::optional<HttpClientAccessToken>& AccessToken,
+ const std::optional<std::string>& AccessToken,
const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
{
curl_slist* Headers = nullptr;
@@ -309,10 +309,10 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
Headers = curl_slist_append(Headers, SessionHeader.c_str());
}
- if (AccessToken)
+ if (AccessToken.has_value())
{
ExtendableStringBuilder<128> AuthHeader;
- AuthHeader << "Authorization: " << AccessToken->Value;
+ AuthHeader << "Authorization: " << AccessToken.value();
Headers = curl_slist_append(Headers, AuthHeader.c_str());
}
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 770213738..4337fcb79 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -243,13 +243,9 @@ struct HttpWsClient::Impl
<< "Sec-WebSocket-Version: 13\r\n";
// Add Authorization header if access token provider is set
- if (m_Settings.AccessTokenProvider)
+ if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value())
{
- HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
- if (Token.IsValid())
- {
- Request << "Authorization: Bearer " << Token.Value << "\r\n";
- }
+ Request << "Authorization: Bearer " << AccessToken.value() << "\r\n";
}
Request << "\r\n";
@@ -557,10 +553,51 @@ struct HttpWsClient::Impl
}
}
+ std::optional<std::string> GetAccessToken()
+ {
+ if (!m_Settings.AccessTokenProvider.has_value())
+ {
+ return {};
+ }
+ {
+ RwLock::SharedLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ }
+ RwLock::ExclusiveLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()();
+ if (NewAccessToken.IsValid())
+ {
+ m_CachedAccessToken = NewAccessToken;
+ }
+ else
+ {
+ if (m_CachedAccessToken.HasExpired())
+ {
+ ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired");
+ m_CachedAccessToken = {};
+ }
+ else
+ {
+ ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon");
+ }
+ }
+ return m_CachedAccessToken.GetValue();
+ }
+
IWsClientHandler& m_Handler;
HttpWsClientSettings m_Settings;
LoggerRef m_Log;
+ RwLock m_AccessTokenLock;
+ HttpClientAccessToken m_CachedAccessToken;
+
std::string m_Host;
std::string m_Port;
std::string m_Path;
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 9f49802a0..3e81d4a8a 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -104,15 +104,10 @@ bool
HttpClientBase::Authenticate()
{
ZEN_TRACE_CPU("HttpClientBase::Authenticate");
- std::optional<HttpClientAccessToken> Token = GetAccessToken();
- if (!Token)
- {
- return false;
- }
- return Token->IsValid();
+ return GetAccessToken().has_value();
}
-const std::optional<HttpClientAccessToken>
+std::optional<std::string>
HttpClientBase::GetAccessToken()
{
ZEN_TRACE_CPU("HttpClientBase::GetAccessToken");
@@ -122,18 +117,34 @@ HttpClientBase::GetAccessToken()
}
{
RwLock::SharedLockScope _(m_AccessTokenLock);
- if (m_CachedAccessToken.IsValid())
+ if (!m_CachedAccessToken.NeedsRefresh())
{
- return m_CachedAccessToken;
+ return m_CachedAccessToken.GetValue();
}
}
RwLock::ExclusiveLockScope _(m_AccessTokenLock);
- if (m_CachedAccessToken.IsValid())
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ HttpClientAccessToken NewAccessToken = m_ConnectionSettings.AccessTokenProvider.value()();
+ if (NewAccessToken.IsValid())
+ {
+ m_CachedAccessToken = NewAccessToken;
+ }
+ else
{
- return m_CachedAccessToken;
+ if (m_CachedAccessToken.HasExpired())
+ {
+ ZEN_WARN("HttpClient refreshed access token is not valid, clearing the cached token as it has expired");
+ m_CachedAccessToken = {};
+ }
+ else
+ {
+ ZEN_WARN("HttpClient refreshed access token is not valid, keeping existing token, it will expire soon");
+ }
}
- m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()();
- return m_CachedAccessToken;
+ return m_CachedAccessToken.GetValue();
}
//////////////////////////////////////////////////////////////////////////
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
index 3ca586f87..7a657c464 100644
--- a/src/zenhttp/httpclient_test.cpp
+++ b/src/zenhttp/httpclient_test.cpp
@@ -813,10 +813,7 @@ TEST_CASE("httpclient.authentication")
{
HttpClientSettings Settings;
Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
- return HttpClientAccessToken{
- .Value = "valid-token",
- .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1),
- };
+ return HttpClientAccessToken("valid-token", HttpClientAccessToken::Clock::now() + std::chrono::hours(1));
};
HttpClient Client = Fixture.MakeClient(Settings);
CHECK(Client.Authenticate());
@@ -826,10 +823,7 @@ TEST_CASE("httpclient.authentication")
{
HttpClientSettings Settings;
Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
- return HttpClientAccessToken{
- .Value = "expired-token",
- .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1),
- };
+ return HttpClientAccessToken("expired-token", HttpClientAccessToken::Clock::now() - std::chrono::hours(1));
};
HttpClient Client = Fixture.MakeClient(Settings);
CHECK(!Client.Authenticate());
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
index 02e1b57e2..6a3f18b7a 100644
--- a/src/zenhttp/httpclientauth.cpp
+++ b/src/zenhttp/httpclientauth.cpp
@@ -33,8 +33,7 @@ namespace zen { namespace httpclientauth {
std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token)
{
- return CreateFromStaticToken(
- HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = HttpClientAccessToken::TimePoint::max()});
+ return CreateFromStaticToken(HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), HttpClientAccessToken::TimePoint::max()));
}
std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params)
@@ -74,7 +73,7 @@ namespace zen { namespace httpclientauth {
int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value());
HttpClientAccessToken::TimePoint ExpireTime = HttpClientAccessToken::Clock::now() + seconds(ExpiresInSeconds);
- return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime};
+ return HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), ExpireTime);
};
}
@@ -82,7 +81,7 @@ namespace zen { namespace httpclientauth {
{
return [&AuthManager = AuthManager, OpenIdProvider = std::string(OpenIdProvider)]() {
AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider);
- return HttpClientAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ return HttpClientAccessToken(Token.AccessToken, Token.ExpireTime);
};
}
@@ -172,7 +171,7 @@ namespace zen { namespace httpclientauth {
HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime);
ExpireTime += std::chrono::milliseconds(Millisecond);
- return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime};
+ return HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), ExpireTime);
}
else
{
@@ -192,16 +191,15 @@ namespace zen { namespace httpclientauth {
{
return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath),
CloudHost = std::string(CloudHost),
+ Token = InitialToken,
Quiet,
- Hidden,
- InitialToken]() mutable {
- if (InitialToken.IsValid())
+ Unattended,
+ Hidden]() mutable {
+ if (!Token.NeedsRefresh())
{
- HttpClientAccessToken Result = InitialToken;
- InitialToken = {};
- return Result;
+ return std::move(Token);
}
- return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, /* Unattended */ true, Quiet, Hidden);
+ return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
};
}
return {};
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index e878c900f..9531b9366 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -68,14 +68,30 @@ struct HttpClientAccessToken
static constexpr int64_t ExpireMarginInSeconds = 60 * 5;
- std::string Value;
- TimePoint ExpireTime;
+ HttpClientAccessToken() {}
- bool IsValid() const
+ HttpClientAccessToken(std::string_view InValue, const TimePoint& InExpireTime) : Value(InValue), ExpireTime(InExpireTime) {}
+
+ std::optional<std::string> GetValue() const
{
- return Value.empty() == false &&
- ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count();
+ if (IsValid())
+ {
+ return Value;
+ }
+ return {};
}
+
+ bool NeedsRefresh() const
+ {
+ return Value.empty() == true || (Clock::now() + std::chrono::seconds(ExpireMarginInSeconds)) >= ExpireTime;
+ }
+
+ bool HasExpired() const { return Clock::now() >= ExpireTime; }
+ bool IsValid() const { return !Value.empty() && !HasExpired(); }
+
+private:
+ std::string Value;
+ TimePoint ExpireTime;
};
struct HttpClientSettings
diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h
index 26f31ed2a..f1bccdca6 100644
--- a/src/zenhttp/include/zenhttp/httpclientauth.h
+++ b/src/zenhttp/include/zenhttp/httpclientauth.h
@@ -10,6 +10,10 @@ namespace zen {
class AuthMgr;
namespace httpclientauth {
+
+ // The std::function<HttpClientAccessToken()> instances returned from these functions are not guarateed to
+ // be thread safe so caller must make sure they are not called from multiple threads in parallell
+
std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token);
std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token);
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp
index 425caee97..03b8aa382 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.cpp
+++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp
@@ -546,7 +546,7 @@ namespace {
Host.empty() ? OverrideHost : Host,
/*Quiet*/ false,
/*Unattended*/ false,
- /*Hidden*/ true);
+ /*Hidden*/ false);
TokenProviderMaybe)
{
TokenProvider = TokenProviderMaybe.value();