aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/clients')
-rw-r--r--src/zenhttp/clients/httpclientcommon.h2
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp19
-rw-r--r--src/zenhttp/clients/httpclientcpr.h14
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp6
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp49
5 files changed, 64 insertions, 26 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
index e95e3a253..e8d969cc8 100644
--- a/src/zenhttp/clients/httpclientcommon.h
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -70,7 +70,7 @@ protected:
const HttpClientSettings m_ConnectionSettings;
std::function<bool()> m_CheckIfAbortFunction;
- const std::optional<HttpClientAccessToken> GetAccessToken();
+ std::optional<std::string> GetAccessToken();
RwLock m_AccessTokenLock;
HttpClientAccessToken m_CachedAccessToken;
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index a52b8f74b..bd6de3ff7 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -399,13 +399,13 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
//////////////////////////////////////////////////////////////////////////
CprHttpClient::Session
-CprHttpClient::AllocSession(const std::string_view BaseUrl,
- const std::string_view ResourcePath,
- const HttpClientSettings& ConnectionSettings,
- const KeyValueMap& AdditionalHeader,
- const KeyValueMap& Parameters,
- const std::string_view SessionId,
- std::optional<HttpClientAccessToken> AccessToken)
+CprHttpClient::AllocSession(const std::string_view BaseUrl,
+ const std::string_view ResourcePath,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<std::string> AccessToken)
{
ZEN_TRACE_CPU("CprHttpClient::AllocSession");
cpr::Session* CprSession = nullptr;
@@ -494,9 +494,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
{
CprSession->UpdateHeader({{"Connection", "close"}});
}
- if (AccessToken)
+
+ if (AccessToken.has_value())
{
- CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
+ CprSession->UpdateHeader({{"Authorization", AccessToken.value()}});
}
if (!Parameters->empty())
{
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
index 009e6fb7a..509ca5ae2 100644
--- a/src/zenhttp/clients/httpclientcpr.h
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -149,13 +149,13 @@ private:
Session& operator=(Session&&) = delete;
};
- Session AllocSession(const std::string_view BaseUrl,
- const std::string_view Url,
- const HttpClientSettings& ConnectionSettings,
- const KeyValueMap& AdditionalHeader,
- const KeyValueMap& Parameters,
- const std::string_view SessionId,
- std::optional<HttpClientAccessToken> AccessToken);
+ Session AllocSession(const std::string_view BaseUrl,
+ const std::string_view Url,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<std::string> AccessToken);
RwLock m_SessionLock;
std::vector<cpr::Session*> m_Sessions;
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
index ec9b7bac6..e76157254 100644
--- a/src/zenhttp/clients/httpclientcurl.cpp
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -290,7 +290,7 @@ HeaderContentType(ZenContentType ContentType)
static curl_slist*
BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
std::string_view SessionId,
- const std::optional<HttpClientAccessToken>& AccessToken,
+ const std::optional<std::string>& AccessToken,
const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
{
curl_slist* Headers = nullptr;
@@ -309,10 +309,10 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
Headers = curl_slist_append(Headers, SessionHeader.c_str());
}
- if (AccessToken)
+ if (AccessToken.has_value())
{
ExtendableStringBuilder<128> AuthHeader;
- AuthHeader << "Authorization: " << AccessToken->Value;
+ AuthHeader << "Authorization: " << AccessToken.value();
Headers = curl_slist_append(Headers, AuthHeader.c_str());
}
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 770213738..4337fcb79 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -243,13 +243,9 @@ struct HttpWsClient::Impl
<< "Sec-WebSocket-Version: 13\r\n";
// Add Authorization header if access token provider is set
- if (m_Settings.AccessTokenProvider)
+ if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value())
{
- HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
- if (Token.IsValid())
- {
- Request << "Authorization: Bearer " << Token.Value << "\r\n";
- }
+ Request << "Authorization: Bearer " << AccessToken.value() << "\r\n";
}
Request << "\r\n";
@@ -557,10 +553,51 @@ struct HttpWsClient::Impl
}
}
+ std::optional<std::string> GetAccessToken()
+ {
+ if (!m_Settings.AccessTokenProvider.has_value())
+ {
+ return {};
+ }
+ {
+ RwLock::SharedLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ }
+ RwLock::ExclusiveLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()();
+ if (NewAccessToken.IsValid())
+ {
+ m_CachedAccessToken = NewAccessToken;
+ }
+ else
+ {
+ if (m_CachedAccessToken.HasExpired())
+ {
+ ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired");
+ m_CachedAccessToken = {};
+ }
+ else
+ {
+ ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon");
+ }
+ }
+ return m_CachedAccessToken.GetValue();
+ }
+
IWsClientHandler& m_Handler;
HttpWsClientSettings m_Settings;
LoggerRef m_Log;
+ RwLock m_AccessTokenLock;
+ HttpClientAccessToken m_CachedAccessToken;
+
std::string m_Host;
std::string m_Port;
std::string m_Path;