From 3242a47fc2cd8d46d2d6482b7d386c26adca1ea2 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Tue, 10 Dec 2024 09:08:54 +0100 Subject: auth fixes (#260) * fix so we can replace an openid provider that was read from disk file * fix OidcClient lifetime issues in authmg --- src/zenhttp/auth/authmgr.cpp | 70 +++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 30 deletions(-) (limited to 'src/zenhttp/auth/authmgr.cpp') diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index bf151ce6d..8da676908 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -100,20 +100,32 @@ public: virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { - if (OpenIdProviderExist(Params.Name)) + if (Params.Name.empty()) { - ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); + ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); return; } - if (Params.Name.empty()) { - ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); - return; + std::unique_lock _(m_ProviderMutex); + if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end()) + { + OpenIdProvider& ExistingProvider = *It->second; + if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url) + { + ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name); + return; + } + else + { + m_OpenIdProviders.erase(It); + m_OpenIdTokens.erase(std::string(Params.Name)); + ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name); + } + } } - std::unique_ptr Client = - std::make_unique(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + RefPtr Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { @@ -209,29 +221,43 @@ public: } private: + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + RefPtr HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return m_OpenIdProviders.contains(std::string(ProviderName)); } - OidcClient& GetOpenIdClient(std::string_view ProviderName) + OpenIdProvider GetOpenIdProvider(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + return *m_OpenIdProviders[std::string(ProviderName)]; } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { - if (OpenIdProviderExist(ProviderName) == false) + RefPtr Client = GetOpenIdProvider(ProviderName).HttpClient; + if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } - OidcClient& Client = GetOpenIdClient(ProviderName); - - return Client.RefreshToken(RefreshToken); + return Client->RefreshToken(RefreshToken); } void Shutdown() @@ -475,22 +501,6 @@ private: } }; - struct OpenIdProvider - { - std::string Name; - std::string Url; - std::string ClientId; - std::unique_ptr HttpClient; - }; - - struct OpenIdToken - { - std::string IdentityToken; - std::string RefreshToken; - std::string AccessToken; - TimePoint ExpireTime{}; - }; - using OpenIdProviderMap = std::unordered_map>; using OpenIdTokenMap = std::unordered_map; -- cgit v1.2.3