diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 2 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 19 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 14 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 9 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 49 | ||||
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 42 | ||||
| -rw-r--r-- | src/zenhttp/httpclient_test.cpp | 10 | ||||
| -rw-r--r-- | src/zenhttp/httpclientauth.cpp | 22 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclient.h | 26 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclientauth.h | 4 | ||||
| -rw-r--r-- | src/zenremotestore/jupiter/jupitersession.cpp | 3 | ||||
| -rw-r--r-- | src/zenserver/storage/projectstore/httpprojectstore.cpp | 2 |
12 files changed, 135 insertions, 67 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index e95e3a253..e8d969cc8 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -70,7 +70,7 @@ protected: const HttpClientSettings m_ConnectionSettings; std::function<bool()> m_CheckIfAbortFunction; - const std::optional<HttpClientAccessToken> GetAccessToken(); + std::optional<std::string> GetAccessToken(); RwLock m_AccessTokenLock; HttpClientAccessToken m_CachedAccessToken; diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index a0f5cc38f..98a3ce612 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -398,13 +398,13 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, ////////////////////////////////////////////////////////////////////////// CprHttpClient::Session -CprHttpClient::AllocSession(const std::string_view BaseUrl, - const std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken) +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<std::string> AccessToken) { ZEN_TRACE_CPU("CprHttpClient::AllocSession"); cpr::Session* CprSession = nullptr; @@ -493,9 +493,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"Connection", "close"}}); } - if (AccessToken) + + if (AccessToken.has_value()) { - CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + CprSession->UpdateHeader({{"Authorization", AccessToken.value()}}); } if (!Parameters->empty()) { diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 009e6fb7a..509ca5ae2 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -149,13 +149,13 @@ private: Session& operator=(Session&&) = delete; }; - Session AllocSession(const std::string_view BaseUrl, - const std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken); + Session AllocSession(const std::string_view BaseUrl, + const std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<std::string> AccessToken); RwLock m_SessionLock; std::vector<cpr::Session*> m_Sessions; diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index 341adc5f7..31665781d 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -278,7 +278,7 @@ HeaderContentType(ZenContentType ContentType) static curl_slist* BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, std::string_view SessionId, - const std::optional<HttpClientAccessToken>& AccessToken, + const std::optional<std::string>& AccessToken, const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) { curl_slist* Headers = nullptr; @@ -295,10 +295,11 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, Headers = curl_slist_append(Headers, SessionHeader.c_str()); } - if (AccessToken) + if (AccessToken.has_value()) { - std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken.value(); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); } for (const auto& [Key, Value] : ExtraHeaders) diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 792848a6b..1360ebccb 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -241,13 +241,9 @@ struct HttpWsClient::Impl << "Sec-WebSocket-Version: 13\r\n"; // Add Authorization header if access token provider is set - if (m_Settings.AccessTokenProvider) + if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value()) { - HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); - if (Token.IsValid()) - { - Request << "Authorization: Bearer " << Token.Value << "\r\n"; - } + Request << "Authorization: Bearer " << AccessToken.value() << "\r\n"; } Request << "\r\n"; @@ -555,10 +551,51 @@ struct HttpWsClient::Impl } } + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()(); + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else + { + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } + } + return m_CachedAccessToken.GetValue(); + } + IWsClientHandler& m_Handler; HttpWsClientSettings m_Settings; LoggerRef m_Log; + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; + std::string m_Host; std::string m_Port; std::string m_Path; diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index deeeb6c85..4465e0c73 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -99,15 +99,10 @@ bool HttpClientBase::Authenticate() { ZEN_TRACE_CPU("HttpClientBase::Authenticate"); - std::optional<HttpClientAccessToken> Token = GetAccessToken(); - if (!Token) - { - return false; - } - return Token->IsValid(); + return GetAccessToken().has_value(); } -const std::optional<HttpClientAccessToken> +std::optional<std::string> HttpClientBase::GetAccessToken() { ZEN_TRACE_CPU("HttpClientBase::GetAccessToken"); @@ -117,18 +112,39 @@ HttpClientBase::GetAccessToken() } { RwLock::SharedLockScope _(m_AccessTokenLock); - if (m_CachedAccessToken.IsValid()) + if (!m_CachedAccessToken.NeedsRefresh()) { - return m_CachedAccessToken; + return m_CachedAccessToken.GetValue(); } } RwLock::ExclusiveLockScope _(m_AccessTokenLock); - if (m_CachedAccessToken.IsValid()) + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); + if (!NewAccessToken.IsValid()) + { + ZEN_WARN("HttpClient failed to refresh access token, retrying once"); + NewAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); + } + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else { - return m_CachedAccessToken; + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } } - m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); - return m_CachedAccessToken; + return m_CachedAccessToken.GetValue(); } ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 5f3ad2455..8858a6176 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -777,10 +777,7 @@ TEST_CASE("httpclient.authentication") { HttpClientSettings Settings; Settings.AccessTokenProvider = []() -> HttpClientAccessToken { - return HttpClientAccessToken{ - .Value = "valid-token", - .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1), - }; + return HttpClientAccessToken("valid-token", HttpClientAccessToken::Clock::now() + std::chrono::hours(1)); }; HttpClient Client = Fixture.MakeClient(Settings); CHECK(Client.Authenticate()); @@ -790,10 +787,7 @@ TEST_CASE("httpclient.authentication") { HttpClientSettings Settings; Settings.AccessTokenProvider = []() -> HttpClientAccessToken { - return HttpClientAccessToken{ - .Value = "expired-token", - .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1), - }; + return HttpClientAccessToken("expired-token", HttpClientAccessToken::Clock::now() - std::chrono::hours(1)); }; HttpClient Client = Fixture.MakeClient(Settings); CHECK(!Client.Authenticate()); diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index 02e1b57e2..6a3f18b7a 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -33,8 +33,7 @@ namespace zen { namespace httpclientauth { std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token) { - return CreateFromStaticToken( - HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = HttpClientAccessToken::TimePoint::max()}); + return CreateFromStaticToken(HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), HttpClientAccessToken::TimePoint::max())); } std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params) @@ -74,7 +73,7 @@ namespace zen { namespace httpclientauth { int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()); HttpClientAccessToken::TimePoint ExpireTime = HttpClientAccessToken::Clock::now() + seconds(ExpiresInSeconds); - return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; + return HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), ExpireTime); }; } @@ -82,7 +81,7 @@ namespace zen { namespace httpclientauth { { return [&AuthManager = AuthManager, OpenIdProvider = std::string(OpenIdProvider)]() { AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider); - return HttpClientAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + return HttpClientAccessToken(Token.AccessToken, Token.ExpireTime); }; } @@ -172,7 +171,7 @@ namespace zen { namespace httpclientauth { HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime); ExpireTime += std::chrono::milliseconds(Millisecond); - return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; + return HttpClientAccessToken(fmt::format("Bearer {}"sv, Token), ExpireTime); } else { @@ -192,16 +191,15 @@ namespace zen { namespace httpclientauth { { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), CloudHost = std::string(CloudHost), + Token = InitialToken, Quiet, - Hidden, - InitialToken]() mutable { - if (InitialToken.IsValid()) + Unattended, + Hidden]() mutable { + if (!Token.NeedsRefresh()) { - HttpClientAccessToken Result = InitialToken; - InitialToken = {}; - return Result; + return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, /* Unattended */ true, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); }; } return {}; diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 03c98af7e..df862b10f 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -65,14 +65,30 @@ struct HttpClientAccessToken static constexpr int64_t ExpireMarginInSeconds = 60 * 5; - std::string Value; - TimePoint ExpireTime; + HttpClientAccessToken() {} - bool IsValid() const + HttpClientAccessToken(std::string_view InValue, const TimePoint& InExpireTime) : Value(InValue), ExpireTime(InExpireTime) {} + + std::optional<std::string> GetValue() const { - return Value.empty() == false && - ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count(); + if (IsValid()) + { + return Value; + } + return {}; } + + bool NeedsRefresh() const + { + return Value.empty() == true || (Clock::now() + std::chrono::seconds(ExpireMarginInSeconds)) >= ExpireTime; + } + + bool HasExpired() const { return Clock::now() >= ExpireTime; } + bool IsValid() const { return !Value.empty() && !HasExpired(); } + +private: + std::string Value; + TimePoint ExpireTime; }; struct HttpClientSettings diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index 26f31ed2a..f1bccdca6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -10,6 +10,10 @@ namespace zen { class AuthMgr; namespace httpclientauth { + + // The std::function<HttpClientAccessToken()> instances returned from these functions are not guarateed to + // be thread safe so caller must make sure they are not called from multiple threads in parallell + std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token); std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token); diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp index b5531fa60..a9788cb4e 100644 --- a/src/zenremotestore/jupiter/jupitersession.cpp +++ b/src/zenremotestore/jupiter/jupitersession.cpp @@ -882,7 +882,8 @@ JupiterSession::GetBuildBlob(std::string_view Namespace, m_AllowRedirect ? "true"sv : "false"sv); HttpClient::Response Response = m_HttpClient.Download(Url, TempFolderPath, Headers); - if (Response.StatusCode == HttpResponseCode::RangeNotSatisfiable && Ranges.size() > 1) + if ((Response.StatusCode == HttpResponseCode::RangeNotSatisfiable || Response.StatusCode == HttpResponseCode::NotImplemented) && + Ranges.size() > 1) { // Requests to Jupiter that is not served via nginx (content not stored locally in the file system) can not serve multi-range // requests (asp.net limitation) This rejection is not implemented as of 2026-03-02, it is in the backlog (@joakim.lindqvist) diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index 2fa10a292..767ebde0e 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -546,7 +546,7 @@ namespace { Host.empty() ? OverrideHost : Host, /*Quiet*/ false, /*Unattended*/ false, - /*Hidden*/ true); + /*Hidden*/ false); TokenProviderMaybe) { TokenProvider = TokenProviderMaybe.value(); |