aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/auth/authmgr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/auth/authmgr.cpp')
-rw-r--r--src/zenhttp/auth/authmgr.cpp70
1 files changed, 40 insertions, 30 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp
index bf151ce6d..8da676908 100644
--- a/src/zenhttp/auth/authmgr.cpp
+++ b/src/zenhttp/auth/authmgr.cpp
@@ -100,20 +100,32 @@ public:
virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
{
- if (OpenIdProviderExist(Params.Name))
+ if (Params.Name.empty())
{
- ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
+ ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
return;
}
- if (Params.Name.empty())
{
- ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
- return;
+ std::unique_lock _(m_ProviderMutex);
+ if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end())
+ {
+ OpenIdProvider& ExistingProvider = *It->second;
+ if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url)
+ {
+ ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name);
+ return;
+ }
+ else
+ {
+ m_OpenIdProviders.erase(It);
+ m_OpenIdTokens.erase(std::string(Params.Name));
+ ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name);
+ }
+ }
}
- std::unique_ptr<OidcClient> Client =
- std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+ RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}));
if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
{
@@ -209,29 +221,43 @@ public:
}
private:
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ RefPtr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ TimePoint ExpireTime{};
+ };
+
bool OpenIdProviderExist(std::string_view ProviderName)
{
std::unique_lock _(m_ProviderMutex);
-
return m_OpenIdProviders.contains(std::string(ProviderName));
}
- OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ OpenIdProvider GetOpenIdProvider(std::string_view ProviderName)
{
std::unique_lock _(m_ProviderMutex);
- return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ return *m_OpenIdProviders[std::string(ProviderName)];
}
OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
{
- if (OpenIdProviderExist(ProviderName) == false)
+ RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient;
+ if (!Client)
{
return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
}
- OidcClient& Client = GetOpenIdClient(ProviderName);
-
- return Client.RefreshToken(RefreshToken);
+ return Client->RefreshToken(RefreshToken);
}
void Shutdown()
@@ -475,22 +501,6 @@ private:
}
};
- struct OpenIdProvider
- {
- std::string Name;
- std::string Url;
- std::string ClientId;
- std::unique_ptr<OidcClient> HttpClient;
- };
-
- struct OpenIdToken
- {
- std::string IdentityToken;
- std::string RefreshToken;
- std::string AccessToken;
- TimePoint ExpireTime{};
- };
-
using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;