diff options
| author | Dan Engelbrecht <[email protected]> | 2024-12-10 09:08:54 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2024-12-10 09:08:54 +0100 |
| commit | 3242a47fc2cd8d46d2d6482b7d386c26adca1ea2 (patch) | |
| tree | dd232ab9430afbc5193214c27abf92ef7bbfb873 /src | |
| parent | 5.5.16-pre0 (diff) | |
| download | zen-3242a47fc2cd8d46d2d6482b7d386c26adca1ea2.tar.xz zen-3242a47fc2cd8d46d2d6482b7d386c26adca1ea2.zip | |
auth fixes (#260)
* fix so we can replace an openid provider that was read from disk file
* fix OidcClient lifetime issues in authmg
Diffstat (limited to 'src')
| -rw-r--r-- | src/zenhttp/auth/authmgr.cpp | 70 | ||||
| -rw-r--r-- | src/zenhttp/auth/authservice.cpp | 2 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/auth/oidc.h | 3 |
3 files changed, 43 insertions, 32 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index bf151ce6d..8da676908 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -100,20 +100,32 @@ public: virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { - if (OpenIdProviderExist(Params.Name)) + if (Params.Name.empty()) { - ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); + ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); return; } - if (Params.Name.empty()) { - ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); - return; + std::unique_lock _(m_ProviderMutex); + if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end()) + { + OpenIdProvider& ExistingProvider = *It->second; + if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url) + { + ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name); + return; + } + else + { + m_OpenIdProviders.erase(It); + m_OpenIdTokens.erase(std::string(Params.Name)); + ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name); + } + } } - std::unique_ptr<OidcClient> Client = - std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { @@ -209,29 +221,43 @@ public: } private: + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + RefPtr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return m_OpenIdProviders.contains(std::string(ProviderName)); } - OidcClient& GetOpenIdClient(std::string_view ProviderName) + OpenIdProvider GetOpenIdProvider(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + return *m_OpenIdProviders[std::string(ProviderName)]; } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { - if (OpenIdProviderExist(ProviderName) == false) + RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; + if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } - OidcClient& Client = GetOpenIdClient(ProviderName); - - return Client.RefreshToken(RefreshToken); + return Client->RefreshToken(RefreshToken); } void Shutdown() @@ -475,22 +501,6 @@ private: } }; - struct OpenIdProvider - { - std::string Name; - std::string Url; - std::string ClientId; - std::unique_ptr<OidcClient> HttpClient; - }; - - struct OpenIdToken - { - std::string IdentityToken; - std::string RefreshToken; - std::string AccessToken; - TimePoint ExpireTime{}; - }; - using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; diff --git a/src/zenhttp/auth/authservice.cpp b/src/zenhttp/auth/authservice.cpp index 6ed587770..f89ca91da 100644 --- a/src/zenhttp/auth/authservice.cpp +++ b/src/zenhttp/auth/authservice.cpp @@ -56,7 +56,7 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) if (Ok) { - ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest); + ServerRequest.WriteResponse(HttpResponseCode::OK); } else { diff --git a/src/zenhttp/include/zenhttp/auth/oidc.h b/src/zenhttp/include/zenhttp/auth/oidc.h index f43ae3cd7..6f9c3198e 100644 --- a/src/zenhttp/include/zenhttp/auth/oidc.h +++ b/src/zenhttp/include/zenhttp/auth/oidc.h @@ -2,13 +2,14 @@ #pragma once +#include <zenbase/refcount.h> #include <zencore/string.h> #include <vector> namespace zen { -class OidcClient +class OidcClient : public RefCounted { public: struct Options |