aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients/httpwsclient.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/clients/httpwsclient.cpp')
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp49
1 files changed, 43 insertions, 6 deletions
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 770213738..4337fcb79 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -243,13 +243,9 @@ struct HttpWsClient::Impl
<< "Sec-WebSocket-Version: 13\r\n";
// Add Authorization header if access token provider is set
- if (m_Settings.AccessTokenProvider)
+ if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value())
{
- HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
- if (Token.IsValid())
- {
- Request << "Authorization: Bearer " << Token.Value << "\r\n";
- }
+ Request << "Authorization: Bearer " << AccessToken.value() << "\r\n";
}
Request << "\r\n";
@@ -557,10 +553,51 @@ struct HttpWsClient::Impl
}
}
+ std::optional<std::string> GetAccessToken()
+ {
+ if (!m_Settings.AccessTokenProvider.has_value())
+ {
+ return {};
+ }
+ {
+ RwLock::SharedLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ }
+ RwLock::ExclusiveLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()();
+ if (NewAccessToken.IsValid())
+ {
+ m_CachedAccessToken = NewAccessToken;
+ }
+ else
+ {
+ if (m_CachedAccessToken.HasExpired())
+ {
+ ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired");
+ m_CachedAccessToken = {};
+ }
+ else
+ {
+ ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon");
+ }
+ }
+ return m_CachedAccessToken.GetValue();
+ }
+
IWsClientHandler& m_Handler;
HttpWsClientSettings m_Settings;
LoggerRef m_Log;
+ RwLock m_AccessTokenLock;
+ HttpClientAccessToken m_CachedAccessToken;
+
std::string m_Host;
std::string m_Port;
std::string m_Path;