aboutsummaryrefslogtreecommitdiff
path: root/zenserver/auth/authmgr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenserver/auth/authmgr.cpp')
-rw-r--r--zenserver/auth/authmgr.cpp296
1 files changed, 296 insertions, 0 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp
new file mode 100644
index 000000000..28e128fc0
--- /dev/null
+++ b/zenserver/auth/authmgr.cpp
@@ -0,0 +1,296 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/authmgr.h>
+#include <auth/oidc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+
+#include <chrono>
+#include <condition_variable>
+#include <memory>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+#include <fmt/format.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+class AuthMgrImpl final : public AuthMgr
+{
+public:
+ AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); }
+
+ virtual ~AuthMgrImpl() { SaveState(); }
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
+ {
+ if (OpenIdProviderExist(Params.Name))
+ {
+ return;
+ }
+
+ std::unique_ptr<OidcClient> Client =
+ std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+
+ if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
+ {
+ ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
+ return;
+ }
+
+ std::string NewProviderName = std::string(Params.Name);
+
+ OpenIdProvider* NewProvider = nullptr;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.contains(NewProviderName))
+ {
+ return;
+ }
+
+ auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>());
+ NewProvider = InsertResult.first->second.get();
+ }
+
+ NewProvider->Name = std::string(Params.Name);
+ NewProvider->Url = std::string(Params.Url);
+ NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->HttpClient = std::move(Client);
+
+ ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
+ }
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
+ {
+ if (Params.ProviderName.empty())
+ {
+ ZEN_WARN("trying add OpenID token with invalid provider name");
+ return false;
+ }
+
+ if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty())
+ {
+ ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'");
+ return false;
+ }
+
+ auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
+
+ if (RefreshResult.Ok == false)
+ {
+ ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
+ return false;
+ }
+
+ bool IsNew = false;
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName),
+ OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = RefreshResult.AccessToken});
+
+ IsNew = InsertResult.second;
+ }
+
+ if (IsNew)
+ {
+ ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName);
+ }
+ else
+ {
+ ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName);
+ }
+
+ return true;
+ }
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end())
+ {
+ const OpenIdToken& Token = It->second;
+
+ return {.AccessToken = Token.AccessToken};
+ }
+
+ return {};
+ }
+
+private:
+ bool OpenIdProviderExist(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ return m_OpenIdProviders.contains(std::string(ProviderName));
+ }
+
+ OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+ return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ }
+
+ OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
+ {
+ if (OpenIdProviderExist(ProviderName) == false)
+ {
+ return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
+ }
+
+ OidcClient& Client = GetOpenIdClient(ProviderName);
+
+ return Client.RefreshToken(RefreshToken);
+ }
+
+ void Shutdown() { SaveState(); }
+
+ void LoadState()
+ {
+ FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv);
+
+ if (Result.ErrorCode)
+ {
+ return;
+ }
+
+ IoBuffer Buffer = Result.Flatten();
+
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
+ }
+
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ {
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
+
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
+
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
+
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ {
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+ std::string_view AccessToken = TokenObj["AccessToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName,
+ .IdentityToken = IdentityToken,
+ .RefreshToken = RefreshToken,
+ .AccessToken = AccessToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
+ }
+ }
+ }
+
+ void SaveState()
+ {
+ CbObjectWriter AuthState;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
+ {
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
+ }
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
+ {
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "IdentityToken"sv << Kv.second.IdentityToken;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState << "AccessToken"sv << Kv.second.AccessToken;
+ AuthState << "ExpireTime"sv << Kv.second.ExpireTime;
+ AuthState.EndObject();
+ }
+ }
+ AuthState.EndArray();
+ }
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+ WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer());
+ }
+
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ std::unique_ptr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ double ExpireTime{};
+ };
+
+ using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
+ using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthConfig m_Config;
+ spdlog::logger& m_Log;
+ OpenIdProviderMap m_OpenIdProviders;
+ OpenIdTokenMap m_OpenIdTokens;
+ std::mutex m_ProviderMutex;
+ std::shared_mutex m_TokenMutex;
+};
+
+std::unique_ptr<AuthMgr>
+MakeAuthMgr(const AuthConfig& Config)
+{
+ return std::make_unique<AuthMgrImpl>(Config);
+}
+
+} // namespace zen