aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDan Engelbrecht <[email protected]>2024-12-10 09:08:54 +0100
committerGitHub Enterprise <[email protected]>2024-12-10 09:08:54 +0100
commit3242a47fc2cd8d46d2d6482b7d386c26adca1ea2 (patch)
treedd232ab9430afbc5193214c27abf92ef7bbfb873 /src
parent5.5.16-pre0 (diff)
downloadzen-3242a47fc2cd8d46d2d6482b7d386c26adca1ea2.tar.xz
zen-3242a47fc2cd8d46d2d6482b7d386c26adca1ea2.zip
auth fixes (#260)
* fix so we can replace an openid provider that was read from disk file * fix OidcClient lifetime issues in authmg
Diffstat (limited to 'src')
-rw-r--r--src/zenhttp/auth/authmgr.cpp70
-rw-r--r--src/zenhttp/auth/authservice.cpp2
-rw-r--r--src/zenhttp/include/zenhttp/auth/oidc.h3
3 files changed, 43 insertions, 32 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp
index bf151ce6d..8da676908 100644
--- a/src/zenhttp/auth/authmgr.cpp
+++ b/src/zenhttp/auth/authmgr.cpp
@@ -100,20 +100,32 @@ public:
virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
{
- if (OpenIdProviderExist(Params.Name))
+ if (Params.Name.empty())
{
- ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
+ ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
return;
}
- if (Params.Name.empty())
{
- ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
- return;
+ std::unique_lock _(m_ProviderMutex);
+ if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end())
+ {
+ OpenIdProvider& ExistingProvider = *It->second;
+ if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url)
+ {
+ ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name);
+ return;
+ }
+ else
+ {
+ m_OpenIdProviders.erase(It);
+ m_OpenIdTokens.erase(std::string(Params.Name));
+ ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name);
+ }
+ }
}
- std::unique_ptr<OidcClient> Client =
- std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+ RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}));
if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
{
@@ -209,29 +221,43 @@ public:
}
private:
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ RefPtr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ TimePoint ExpireTime{};
+ };
+
bool OpenIdProviderExist(std::string_view ProviderName)
{
std::unique_lock _(m_ProviderMutex);
-
return m_OpenIdProviders.contains(std::string(ProviderName));
}
- OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ OpenIdProvider GetOpenIdProvider(std::string_view ProviderName)
{
std::unique_lock _(m_ProviderMutex);
- return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ return *m_OpenIdProviders[std::string(ProviderName)];
}
OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
{
- if (OpenIdProviderExist(ProviderName) == false)
+ RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient;
+ if (!Client)
{
return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
}
- OidcClient& Client = GetOpenIdClient(ProviderName);
-
- return Client.RefreshToken(RefreshToken);
+ return Client->RefreshToken(RefreshToken);
}
void Shutdown()
@@ -475,22 +501,6 @@ private:
}
};
- struct OpenIdProvider
- {
- std::string Name;
- std::string Url;
- std::string ClientId;
- std::unique_ptr<OidcClient> HttpClient;
- };
-
- struct OpenIdToken
- {
- std::string IdentityToken;
- std::string RefreshToken;
- std::string AccessToken;
- TimePoint ExpireTime{};
- };
-
using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
diff --git a/src/zenhttp/auth/authservice.cpp b/src/zenhttp/auth/authservice.cpp
index 6ed587770..f89ca91da 100644
--- a/src/zenhttp/auth/authservice.cpp
+++ b/src/zenhttp/auth/authservice.cpp
@@ -56,7 +56,7 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
if (Ok)
{
- ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest);
+ ServerRequest.WriteResponse(HttpResponseCode::OK);
}
else
{
diff --git a/src/zenhttp/include/zenhttp/auth/oidc.h b/src/zenhttp/include/zenhttp/auth/oidc.h
index f43ae3cd7..6f9c3198e 100644
--- a/src/zenhttp/include/zenhttp/auth/oidc.h
+++ b/src/zenhttp/include/zenhttp/auth/oidc.h
@@ -2,13 +2,14 @@
#pragma once
+#include <zenbase/refcount.h>
#include <zencore/string.h>
#include <vector>
namespace zen {
-class OidcClient
+class OidcClient : public RefCounted
{
public:
struct Options