From 59bc08385515997a34fe2b4b3cbbfd03dd9a7c5b Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Wed, 18 Mar 2026 22:28:14 +0100 Subject: improve auth token refresh (#863) Authentication callbacks are not thread safe, ensured call sites does single threaded calls --- src/zenhttp/clients/httpwsclient.cpp | 49 +++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 6 deletions(-) (limited to 'src/zenhttp/clients/httpwsclient.cpp') diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 770213738..4337fcb79 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -243,13 +243,9 @@ struct HttpWsClient::Impl << "Sec-WebSocket-Version: 13\r\n"; // Add Authorization header if access token provider is set - if (m_Settings.AccessTokenProvider) + if (std::optional AccessToken = GetAccessToken(); AccessToken.has_value()) { - HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); - if (Token.IsValid()) - { - Request << "Authorization: Bearer " << Token.Value << "\r\n"; - } + Request << "Authorization: Bearer " << AccessToken.value() << "\r\n"; } Request << "\r\n"; @@ -557,10 +553,51 @@ struct HttpWsClient::Impl } } + std::optional GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_Settings.AccessTokenProvider.value()(); + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else + { + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpWsClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } + } + return m_CachedAccessToken.GetValue(); + } + IWsClientHandler& m_Handler; HttpWsClientSettings m_Settings; LoggerRef m_Log; + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; + std::string m_Host; std::string m_Port; std::string m_Path; -- cgit v1.2.3