diff options
Diffstat (limited to 'src/zenhttp/clients')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 2 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 19 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 14 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 6 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 49 |
5 files changed, 64 insertions, 26 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index e95e3a253..e8d969cc8 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -70,7 +70,7 @@ protected: const HttpClientSettings m_ConnectionSettings; std::function<bool()> m_CheckIfAbortFunction; - const std::optional<HttpClientAccessToken> GetAccessToken(); + std::optional<std::string> GetAccessToken(); RwLock m_AccessTokenLock; HttpClientAccessToken m_CachedAccessToken; diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index a52b8f74b..bd6de3ff7 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -399,13 +399,13 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, ////////////////////////////////////////////////////////////////////////// CprHttpClient::Session -CprHttpClient::AllocSession(const std::string_view BaseUrl, - const std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken) +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<std::string> AccessToken) { ZEN_TRACE_CPU("CprHttpClient::AllocSession"); cpr::Session* CprSession = nullptr; @@ -494,9 +494,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"Connection", "close"}}); } - if (AccessToken) + + if (AccessToken.has_value()) { - CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + CprSession->UpdateHeader({{"Authorization", AccessToken.value()}}); } if (!Parameters->empty()) { diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 009e6fb7a..509ca5ae2 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -149,13 +149,13 @@ private: Session& operator=(Session&&) = delete; }; - Session AllocSession(const std::string_view BaseUrl, - const std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken); + Session AllocSession(const std::string_view BaseUrl, + const std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<std::string> AccessToken); RwLock m_SessionLock; std::vector<cpr::Session*> m_Sessions; diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index ec9b7bac6..e76157254 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -290,7 +290,7 @@ HeaderContentType(ZenContentType ContentType) static curl_slist* BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, std::string_view SessionId, - const std::optional<HttpClientAccessToken>& AccessToken, + const std::optional<std::string>& AccessToken, const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) { curl_slist* Headers = nullptr; @@ -309,10 +309,10 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, Headers = curl_slist_append(Headers, SessionHeader.c_str()); } - if (AccessToken) + if (AccessToken.has_value()) { ExtendableStringBuilder<128> AuthHeader; - AuthHeader << "Authorization: " << AccessToken->Value; + AuthHeader << "Authorization: " << AccessToken.value(); Headers = curl_slist_append(Headers, AuthHeader.c_str()); } diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 770213738..4337fcb79 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -243,13 +243,9 @@ struct HttpWsClient::Impl << "Sec-WebSocket-Version: 13\r\n"; // Add Authorization header if access token provider is set - if (m_Settings.AccessTokenProvider) + if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value()) { - HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); - if (Token.IsValid()) - { - Request << "Authorization: Bearer " << Token.Value << "\r\n"; - } + Request << "Authorization: Bearer " << AccessToken.value() << "\r\n"; } Request << "\r\n"; @@ -557,10 +553,51 @@ struct HttpWsClient::Impl } } + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()(); + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else + { + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } + } + return m_CachedAccessToken.GetValue(); + } + IWsClientHandler& m_Handler; HttpWsClientSettings m_Settings; LoggerRef m_Log; + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; + std::string m_Host; std::string m_Port; std::string m_Path; |