diff options
Diffstat (limited to 'src/zenhttp/auth/authmgr.cpp')
| -rw-r--r-- | src/zenhttp/auth/authmgr.cpp | 70 |
1 files changed, 40 insertions, 30 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index bf151ce6d..8da676908 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -100,20 +100,32 @@ public: virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { - if (OpenIdProviderExist(Params.Name)) + if (Params.Name.empty()) { - ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); + ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); return; } - if (Params.Name.empty()) { - ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); - return; + std::unique_lock _(m_ProviderMutex); + if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end()) + { + OpenIdProvider& ExistingProvider = *It->second; + if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url) + { + ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name); + return; + } + else + { + m_OpenIdProviders.erase(It); + m_OpenIdTokens.erase(std::string(Params.Name)); + ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name); + } + } } - std::unique_ptr<OidcClient> Client = - std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { @@ -209,29 +221,43 @@ public: } private: + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + RefPtr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return m_OpenIdProviders.contains(std::string(ProviderName)); } - OidcClient& GetOpenIdClient(std::string_view ProviderName) + OpenIdProvider GetOpenIdProvider(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + return *m_OpenIdProviders[std::string(ProviderName)]; } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { - if (OpenIdProviderExist(ProviderName) == false) + RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; + if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } - OidcClient& Client = GetOpenIdClient(ProviderName); - - return Client.RefreshToken(RefreshToken); + return Client->RefreshToken(RefreshToken); } void Shutdown() @@ -475,22 +501,6 @@ private: } }; - struct OpenIdProvider - { - std::string Name; - std::string Url; - std::string ClientId; - std::unique_ptr<OidcClient> HttpClient; - }; - - struct OpenIdToken - { - std::string IdentityToken; - std::string RefreshToken; - std::string AccessToken; - TimePoint ExpireTime{}; - }; - using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; |