diff options
| author | Dan Engelbrecht <[email protected]> | 2026-03-18 22:28:14 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-18 22:28:14 +0100 |
| commit | 59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b (patch) | |
| tree | 0a65fca5537909f41b5f8b0d87daa7dbcd967677 /src | |
| parent | Update libcurl to 8.19.0 (#862) (diff) | |
| download | zen-59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b.tar.xz zen-59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b.zip | |
improve auth token refresh (#863)
Authentication callbacks are not thread safe, ensured call sites does single threaded calls
Diffstat (limited to 'src')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 2 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 19 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 14 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 6 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 49 | ||||
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 37 | ||||
| -rw-r--r-- | src/zenhttp/httpclient_test.cpp | 10 | ||||
| -rw-r--r-- | src/zenhttp/httpclientauth.cpp | 22 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclient.h | 26 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclientauth.h | 4 | ||||
| -rw-r--r-- | src/zenserver/storage/projectstore/httpprojectstore.cpp | 2 |
11 files changed, 126 insertions, 65 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index e95e3a253..e8d969cc8 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -70,7 +70,7 @@ protected: const HttpClientSettings m_ConnectionSettings; std::function<bool()> m_CheckIfAbortFunction; - const std::optional<HttpClientAccessToken> GetAccessToken(); + std::optional<std::string> GetAccessToken(); RwLock m_AccessTokenLock; HttpClientAccessToken m_CachedAccessToken; diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index a52b8f74b..bd6de3ff7 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -399,13 +399,13 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, ////////////////////////////////////////////////////////////////////////// CprHttpClient::Session -CprHttpClient::AllocSession(const std::string_view BaseUrl, - const std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken) +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<std::string> AccessToken) { ZEN_TRACE_CPU("CprHttpClient::AllocSession"); cpr::Session* CprSession = nullptr; @@ -494,9 +494,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"Connection", "close"}}); } - if (AccessToken) + + if (AccessToken.has_value()) { - CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + CprSession->UpdateHeader({{"Authorization", AccessToken.value()}}); } if (!Parameters->empty()) { diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 009e6fb7a..509ca5ae2 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -149,13 +149,13 @@ private: Session& operator=(Session&&) = delete; }; - Session AllocSession(const std::string_view BaseUrl, - const std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken); + Session AllocSession(const std::string_view BaseUrl, + const std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<std::string> AccessToken); RwLock m_SessionLock; std::vector<cpr::Session*> m_Sessions; diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index ec9b7bac6..e76157254 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -290,7 +290,7 @@ HeaderContentType(ZenContentType ContentType) static curl_slist* BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, std::string_view SessionId, - const std::optional<HttpClientAccessToken>& AccessToken, + const std::optional<std::string>& AccessToken, const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) { curl_slist* Headers = nullptr; @@ -309,10 +309,10 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, Headers = curl_slist_append(Headers, SessionHeader.c_str()); } - if (AccessToken) + if (AccessToken.has_value()) { ExtendableStringBuilder<128> AuthHeader; - AuthHeader << "Authorization: " << AccessToken->Value; + AuthHeader << "Authorization: " << AccessToken.value(); Headers = curl_slist_append(Headers, AuthHeader.c_str()); } diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 770213738..4337fcb79 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -243,13 +243,9 @@ struct HttpWsClient::Impl << "Sec-WebSocket-Version: 13\r\n"; // Add Authorization header if access token provider is set - if (m_Settings.AccessTokenProvider) + if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value()) { - HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); - if (Token.IsValid()) - { - Request << "Authorization: Bearer " << Token.Value << "\r\n"; - } + Request << "Authorization: Bearer " << AccessToken.value() << "\r\n"; } Request << "\r\n"; @@ -557,10 +553,51 @@ struct HttpWsClient::Impl } } + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()(); + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else + { + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } + } + return m_CachedAccessToken.GetValue(); + } + IWsClientHandler& m_Handler; HttpWsClientSettings m_Settings; LoggerRef m_Log; + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; + std::string m_Host; std::string m_Port; std::string m_Path; diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 9f49802a0..3e81d4a8a 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -104,15 +104,10 @@ bool HttpClientBase::Authenticate() { ZEN_TRACE_CPU("HttpClientBase::Authenticate"); - std::optional<HttpClientAccessToken> Token = GetAccessToken(); - if (!Token) - { - return false; - } - return Token->IsValid(); + return GetAccessToken().has_value(); } -const std::optional<HttpClientAccessToken> +std::optional<std::string> HttpClientBase::GetAccessToken() { ZEN_TRACE_CPU("HttpClientBase::GetAccessToken"); @@ -122,18 +117,34 @@ HttpClientBase::GetAccessToken() } { RwLock::SharedLockScope _(m_AccessTokenLock); - if (m_CachedAccessToken.IsValid()) + if (!m_CachedAccessToken.NeedsRefresh()) { - return m_CachedAccessToken; + return m_CachedAccessToken.GetValue(); } } RwLock::ExclusiveLockScope _(m_AccessTokenLock); - if (m_CachedAccessToken.IsValid()) + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else { - return m_CachedAccessToken; + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } } - m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); - return m_CachedAccessToken; + return m_CachedAccessToken.GetValue(); } ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 3ca586f87..7a657c464 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -813,10 +813,7 @@ TEST_CASE("httpclient.authentication") { HttpClientSettings Settings; Settings.AccessTokenProvider = []() -> HttpClientAccessToken { - return HttpClientAccessToken{ - .Value = "valid-token", - .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1), - }; + return HttpClientAccessToken("valid-token", HttpClientAccessToken::Clock::now() + std::chrono::hours(1)); }; HttpClient Client = Fixture.MakeClient(Settings); CHECK(Client.Authenticate()); @@ -826,10 +823,7 @@ TEST_CASE("httpclient.authentication") { HttpClientSettings Settings; Settings.AccessTokenProvider = []() -> HttpClientAccessToken { - return HttpClientAccessToken{ - .Value = "expired-token", - .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1), - }; + return HttpClientAccessToken("expired-token", HttpClientAccessToken::Clock::now() - std::chrono::hours(1)); }; HttpClient Client = Fixture.MakeClient(Settings); CHECK(!Client.Authenticate()); diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index 02e1b57e2..6a3f18b7a 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -33,8 +33,7 @@ namespace zen { namespace httpclientauth { std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token) { - return CreateFromStaticToken( - HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = HttpClientAccessToken::TimePoint::max()}); + return CreateFromStaticToken(HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), HttpClientAccessToken::TimePoint::max())); } std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params) @@ -74,7 +73,7 @@ namespace zen { namespace httpclientauth { int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()); HttpClientAccessToken::TimePoint ExpireTime = HttpClientAccessToken::Clock::now() + seconds(ExpiresInSeconds); - return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; + return HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), ExpireTime); }; } @@ -82,7 +81,7 @@ namespace zen { namespace httpclientauth { { return [&AuthManager = AuthManager, OpenIdProvider = std::string(OpenIdProvider)]() { AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider); - return HttpClientAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + return HttpClientAccessToken(Token.AccessToken, Token.ExpireTime); }; } @@ -172,7 +171,7 @@ namespace zen { namespace httpclientauth { HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime); ExpireTime += std::chrono::milliseconds(Millisecond); - return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; + return HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), ExpireTime); } else { @@ -192,16 +191,15 @@ namespace zen { namespace httpclientauth { { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), CloudHost = std::string(CloudHost), + Token = InitialToken, Quiet, - Hidden, - InitialToken]() mutable { - if (InitialToken.IsValid()) + Unattended, + Hidden]() mutable { + if (!Token.NeedsRefresh()) { - HttpClientAccessToken Result = InitialToken; - InitialToken = {}; - return Result; + return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, /* Unattended */ true, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); }; } return {}; diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index e878c900f..9531b9366 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -68,14 +68,30 @@ struct HttpClientAccessToken static constexpr int64_t ExpireMarginInSeconds = 60 * 5; - std::string Value; - TimePoint ExpireTime; + HttpClientAccessToken() {} - bool IsValid() const + HttpClientAccessToken(std::string_view InValue, const TimePoint& InExpireTime) : Value(InValue), ExpireTime(InExpireTime) {} + + std::optional<std::string> GetValue() const { - return Value.empty() == false && - ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count(); + if (IsValid()) + { + return Value; + } + return {}; } + + bool NeedsRefresh() const + { + return Value.empty() == true || (Clock::now() + std::chrono::seconds(ExpireMarginInSeconds)) >= ExpireTime; + } + + bool HasExpired() const { return Clock::now() >= ExpireTime; } + bool IsValid() const { return !Value.empty() && !HasExpired(); } + +private: + std::string Value; + TimePoint ExpireTime; }; struct HttpClientSettings diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index 26f31ed2a..f1bccdca6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -10,6 +10,10 @@ namespace zen { class AuthMgr; namespace httpclientauth { + + // The std::function<HttpClientAccessToken()> instances returned from these functions are not guarateed to + // be thread safe so caller must make sure they are not called from multiple threads in parallell + std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token); std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token); diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index 425caee97..03b8aa382 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -546,7 +546,7 @@ namespace { Host.empty() ? OverrideHost : Host, /*Quiet*/ false, /*Unattended*/ false, - /*Hidden*/ true); + /*Hidden*/ false); TokenProviderMaybe) { TokenProvider = TokenProviderMaybe.value(); |