diff options
Diffstat (limited to 'src/zenhttp/clients/httpwsclient.cpp')
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 49 |
1 files changed, 43 insertions, 6 deletions
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 770213738..4337fcb79 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -243,13 +243,9 @@ struct HttpWsClient::Impl << "Sec-WebSocket-Version: 13\r\n"; // Add Authorization header if access token provider is set - if (m_Settings.AccessTokenProvider) + if (std::optional<std::string> AccessToken = GetAccessToken(); AccessToken.has_value()) { - HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); - if (Token.IsValid()) - { - Request << "Authorization: Bearer " << Token.Value << "\r\n"; - } + Request << "Authorization: Bearer " << AccessToken.value() << "\r\n"; } Request << "\r\n"; @@ -557,10 +553,51 @@ struct HttpWsClient::Impl } } + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()(); + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else + { + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } + } + return m_CachedAccessToken.GetValue(); + } + IWsClientHandler& m_Handler; HttpWsClientSettings m_Settings; LoggerRef m_Log; + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; + std::string m_Host; std::string m_Port; std::string m_Path; |