aboutsummaryrefslogtreecommitdiff
path: root/zenserver/auth/authmgr.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-01-28 10:16:34 +0100
committerPer Larsson <[email protected]>2022-01-28 10:16:34 +0100
commitb34fcd781777c522b155be69239967b2dcfd1c36 (patch)
tree4de81ad72d94ce29857439171c76bd7bab551745 /zenserver/auth/authmgr.cpp
parentAdd OpenID auth to auth mgr. (diff)
downloadzen-b34fcd781777c522b155be69239967b2dcfd1c36.tar.xz
zen-b34fcd781777c522b155be69239967b2dcfd1c36.zip
Extended auth mgr to restore OpenID provider(s) and token(s).
Diffstat (limited to 'zenserver/auth/authmgr.cpp')
-rw-r--r--zenserver/auth/authmgr.cpp195
1 files changed, 179 insertions, 16 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp
index 4c97693f9..20bc7c988 100644
--- a/zenserver/auth/authmgr.cpp
+++ b/zenserver/auth/authmgr.cpp
@@ -1,25 +1,50 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include <auth/authmgr.h>
+#include <auth/oidc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
#include <zencore/logging.h>
#include <chrono>
#include <condition_variable>
+#include <memory>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
+#include <fmt/format.h>
+
namespace zen {
+using namespace std::literals;
+
class AuthMgrImpl final : public AuthMgr
{
public:
- AuthMgrImpl(const AuthConfig& Config) : m_Log(logging::Get("auth")) { ZEN_UNUSED(Config); }
+ AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); }
- virtual ~AuthMgrImpl() {}
+ virtual ~AuthMgrImpl() { SaveState(); }
virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
{
+ if (OpenIdProviderExist(Params.Name))
+ {
+ return;
+ }
+
+ std::unique_ptr<OidcClient> Client =
+ std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+
+ if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
+ {
+ ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
+ return;
+ }
+
std::string NewProviderName = std::string(Params.Name);
OpenIdProvider* NewProvider = nullptr;
@@ -36,9 +61,10 @@ public:
NewProvider = InsertResult.first->second.get();
}
- NewProvider->Name = std::string(Params.Name);
- NewProvider->Url = std::string(Params.Url);
- NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->Name = std::string(Params.Name);
+ NewProvider->Url = std::string(Params.Url);
+ NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->HttpClient = std::move(Client);
ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
}
@@ -53,7 +79,15 @@ public:
if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty())
{
- ZEN_WARN("trying add invalid OpenID token");
+ ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'");
+ return false;
+ }
+
+ auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
+
+ if (RefreshResult.Ok == false)
+ {
+ ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
return false;
}
@@ -63,9 +97,9 @@ public:
std::unique_lock _(m_TokenMutex);
const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName),
- OpenIdToken{.IdentityToken = std::string(Params.IdentityToken),
- .RefreshToken = Params.RefreshToken,
- .AccessToken = Params.AccessToken});
+ OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = RefreshResult.AccessToken});
IsNew = InsertResult.second;
}
@@ -83,19 +117,147 @@ public:
}
private:
+ bool OpenIdProviderExist(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ return m_OpenIdProviders.contains(std::string(ProviderName));
+ }
+
+ OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+ return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ }
+
+ OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
+ {
+ if (OpenIdProviderExist(ProviderName) == false)
+ {
+ return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
+ }
+
+ OidcClient& Client = GetOpenIdClient(ProviderName);
+
+ return Client.RefreshToken(RefreshToken);
+ }
+
+ void Shutdown() { SaveState(); }
+
+ void LoadState()
+ {
+ FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv);
+
+ if (Result.ErrorCode)
+ {
+ return;
+ }
+
+ IoBuffer Buffer = Result.Flatten();
+
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
+ }
+
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ {
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
+
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
+
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
+
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ {
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+ std::string_view AccessToken = TokenObj["AccessToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName,
+ .IdentityToken = IdentityToken,
+ .RefreshToken = RefreshToken,
+ .AccessToken = AccessToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
+ }
+ }
+ }
+
+ void SaveState()
+ {
+ CbObjectWriter AuthState;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
+ {
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
+ }
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
+ {
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "IdentityToken"sv << Kv.second.IdentityToken;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState << "AccessToken"sv << Kv.second.AccessToken;
+ AuthState << "ExpireTime"sv << Kv.second.ExpireTime;
+ AuthState.EndObject();
+ }
+ }
+ AuthState.EndArray();
+ }
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+ WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer());
+ }
+
struct OpenIdProvider
{
- std::string Name;
- std::string Url;
- std::string ClientId;
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ std::unique_ptr<OidcClient> HttpClient;
};
struct OpenIdToken
{
- std::string_view IdentityToken;
- std::string_view RefreshToken;
- std::string_view AccessToken;
- double ExpireTime{};
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ double ExpireTime{};
};
using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
@@ -103,6 +265,7 @@ private:
spdlog::logger& Log() { return m_Log; }
+ AuthConfig m_Config;
spdlog::logger& m_Log;
OpenIdProviderMap m_OpenIdProviders;
OpenIdTokenMap m_OpenIdTokens;