aboutsummaryrefslogtreecommitdiff
path: root/zenserver/auth/authmgr.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /zenserver/auth/authmgr.cpp
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'zenserver/auth/authmgr.cpp')
-rw-r--r--zenserver/auth/authmgr.cpp506
1 files changed, 0 insertions, 506 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp
deleted file mode 100644
index 4cd6b3362..000000000
--- a/zenserver/auth/authmgr.cpp
+++ /dev/null
@@ -1,506 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <auth/authmgr.h>
-#include <auth/oidc.h>
-
-#include <zencore/compactbinary.h>
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinaryvalidation.h>
-#include <zencore/crypto.h>
-#include <zencore/filesystem.h>
-#include <zencore/logging.h>
-
-#include <condition_variable>
-#include <memory>
-#include <shared_mutex>
-#include <thread>
-#include <unordered_map>
-
-#include <fmt/format.h>
-
-namespace zen {
-
-using namespace std::literals;
-
-namespace details {
- IoBuffer ReadEncryptedFile(std::filesystem::path Path,
- const AesKey256Bit& Key,
- const AesIV128Bit& IV,
- std::optional<std::string>& Reason)
- {
- FileContents Result = ReadFile(Path);
-
- if (Result.ErrorCode)
- {
- return IoBuffer();
- }
-
- IoBuffer EncryptedBuffer = Result.Flatten();
-
- if (EncryptedBuffer.GetSize() == 0)
- {
- return IoBuffer();
- }
-
- std::vector<uint8_t> DecryptionBuffer;
- DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize);
-
- MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason);
-
- if (DecryptedView.IsEmpty())
- {
- return IoBuffer();
- }
-
- return IoBufferBuilder::MakeCloneFromMemory(DecryptedView);
- }
-
- void WriteEncryptedFile(std::filesystem::path Path,
- IoBuffer FileData,
- const AesKey256Bit& Key,
- const AesIV128Bit& IV,
- std::optional<std::string>& Reason)
- {
- if (FileData.GetSize() == 0)
- {
- return;
- }
-
- std::vector<uint8_t> EncryptionBuffer;
- EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize);
-
- MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason);
-
- if (EncryptedView.IsEmpty())
- {
- return;
- }
-
- WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize()));
- }
-} // namespace details
-
-class AuthMgrImpl final : public AuthMgr
-{
- using Clock = std::chrono::system_clock;
- using TimePoint = Clock::time_point;
- using Seconds = std::chrono::seconds;
-
-public:
- AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth"))
- {
- LoadState();
-
- m_BackgroundThread.Interval = Config.UpdateInterval;
- m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this);
- }
-
- virtual ~AuthMgrImpl() { Shutdown(); }
-
- virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
- {
- if (OpenIdProviderExist(Params.Name))
- {
- ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
- return;
- }
-
- if (Params.Name.empty())
- {
- ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
- return;
- }
-
- std::unique_ptr<OidcClient> Client =
- std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
-
- if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
- {
- ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
- return;
- }
-
- std::string NewProviderName = std::string(Params.Name);
-
- OpenIdProvider* NewProvider = nullptr;
-
- {
- std::unique_lock _(m_ProviderMutex);
-
- if (m_OpenIdProviders.contains(NewProviderName))
- {
- return;
- }
-
- auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>());
- NewProvider = InsertResult.first->second.get();
- }
-
- NewProvider->Name = std::string(Params.Name);
- NewProvider->Url = std::string(Params.Url);
- NewProvider->ClientId = std::string(Params.ClientId);
- NewProvider->HttpClient = std::move(Client);
-
- ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
- }
-
- virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
- {
- if (Params.ProviderName.empty())
- {
- ZEN_WARN("trying add OpenID token with invalid provider name");
- return false;
- }
-
- if (Params.RefreshToken.empty())
- {
- ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'");
- return false;
- }
-
- auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
-
- if (RefreshResult.Ok == false)
- {
- ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
- return false;
- }
-
- bool IsNew = false;
-
- {
- auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
- .RefreshToken = RefreshResult.RefreshToken,
- .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
- .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
-
- std::unique_lock _(m_TokenMutex);
-
- const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token));
-
- IsNew = InsertResult.second;
- }
-
- if (IsNew)
- {
- ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName);
- }
- else
- {
- ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName);
- }
-
- return true;
- }
-
- virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final
- {
- std::unique_lock _(m_TokenMutex);
-
- if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end())
- {
- const OpenIdToken& Token = It->second;
-
- return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime};
- }
-
- return {};
- }
-
-private:
- bool OpenIdProviderExist(std::string_view ProviderName)
- {
- std::unique_lock _(m_ProviderMutex);
-
- return m_OpenIdProviders.contains(std::string(ProviderName));
- }
-
- OidcClient& GetOpenIdClient(std::string_view ProviderName)
- {
- std::unique_lock _(m_ProviderMutex);
- return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
- }
-
- OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
- {
- if (OpenIdProviderExist(ProviderName) == false)
- {
- return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
- }
-
- OidcClient& Client = GetOpenIdClient(ProviderName);
-
- return Client.RefreshToken(RefreshToken);
- }
-
- void Shutdown()
- {
- BackgroundThread::Stop(m_BackgroundThread);
- SaveState();
- }
-
- void LoadState()
- {
- try
- {
- std::optional<std::string> Reason;
-
- IoBuffer Buffer =
- details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason);
-
- if (!Buffer)
- {
- if (Reason)
- {
- ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value());
- }
-
- return;
- }
-
- const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
-
- if (ValidationError != CbValidateError::None)
- {
- ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
- return;
- }
-
- if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
- {
- for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
- {
- CbObjectView ProviderObj = ProviderView.AsObjectView();
-
- std::string_view ProviderName = ProviderObj["Name"].AsString();
- std::string_view Url = ProviderObj["Url"].AsString();
- std::string_view ClientId = ProviderObj["ClientId"].AsString();
-
- AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
- }
-
- for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
- {
- CbObjectView TokenObj = TokenView.AsObjectView();
-
- std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
- std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
-
- const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken});
-
- if (!Ok)
- {
- ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
- }
- }
- }
- }
- catch (std::exception& Err)
- {
- ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what());
-
- {
- std::unique_lock _(m_ProviderMutex);
- m_OpenIdProviders.clear();
- }
-
- {
- std::unique_lock _(m_TokenMutex);
- m_OpenIdTokens.clear();
- }
- }
- }
-
- void SaveState()
- {
- try
- {
- CbObjectWriter AuthState;
-
- {
- std::unique_lock _(m_ProviderMutex);
-
- if (m_OpenIdProviders.size() > 0)
- {
- AuthState.BeginArray("OpenIdProviders");
- for (const auto& Kv : m_OpenIdProviders)
- {
- AuthState.BeginObject();
- AuthState << "Name"sv << Kv.second->Name;
- AuthState << "Url"sv << Kv.second->Url;
- AuthState << "ClientId"sv << Kv.second->ClientId;
- AuthState.EndObject();
- }
- AuthState.EndArray();
- }
- }
-
- {
- std::unique_lock _(m_TokenMutex);
-
- AuthState.BeginArray("OpenIdTokens");
- if (m_OpenIdTokens.size() > 0)
- {
- for (const auto& Kv : m_OpenIdTokens)
- {
- AuthState.BeginObject();
- AuthState << "ProviderName"sv << Kv.first;
- AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
- AuthState.EndObject();
- }
- }
- AuthState.EndArray();
- }
-
- std::filesystem::create_directories(m_Config.RootDirectory);
-
- std::optional<std::string> Reason;
-
- details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv,
- AuthState.Save().GetBuffer().AsIoBuffer(),
- m_Config.EncryptionKey,
- m_Config.EncryptionIV,
- Reason);
-
- if (Reason)
- {
- ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value());
- }
- }
- catch (std::exception& Err)
- {
- ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what());
- }
- }
-
- void BackgroundThreadEntry()
- {
- for (;;)
- {
- std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread);
-
- if (m_BackgroundThread.Running.load() == false)
- {
- break;
- }
-
- if (SignalStatus != std::cv_status::timeout)
- {
- continue;
- }
-
- {
- // Refresh Open ID token(s)
-
- std::vector<OpenIdTokenMap::value_type> ExpiredTokens;
-
- {
- std::unique_lock _(m_TokenMutex);
-
- for (const auto& Kv : m_OpenIdTokens)
- {
- const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now());
- const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2);
-
- if (Expired)
- {
- ExpiredTokens.push_back(Kv);
- }
- }
- }
-
- ZEN_DEBUG("refreshing '{}' OpenID token(s)", ExpiredTokens.size());
-
- for (const auto& Kv : ExpiredTokens)
- {
- OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken);
-
- if (RefreshResult.Ok)
- {
- ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first);
-
- auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
- .RefreshToken = RefreshResult.RefreshToken,
- .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
- .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
-
- {
- std::unique_lock _(m_TokenMutex);
- m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token));
- }
- }
- else
- {
- ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason);
- }
- }
- }
- }
- }
-
- struct BackgroundThread
- {
- std::chrono::seconds Interval{10};
- std::mutex Mutex;
- std::condition_variable Signal;
- std::atomic_bool Running{true};
- std::thread Thread;
-
- static void Stop(BackgroundThread& State)
- {
- if (State.Running.load())
- {
- State.Running.store(false);
- State.Signal.notify_one();
- }
-
- if (State.Thread.joinable())
- {
- State.Thread.join();
- }
- }
-
- static std::cv_status WaitForSignal(BackgroundThread& State)
- {
- std::unique_lock Lock(State.Mutex);
- return State.Signal.wait_for(Lock, State.Interval);
- }
- };
-
- struct OpenIdProvider
- {
- std::string Name;
- std::string Url;
- std::string ClientId;
- std::unique_ptr<OidcClient> HttpClient;
- };
-
- struct OpenIdToken
- {
- std::string IdentityToken;
- std::string RefreshToken;
- std::string AccessToken;
- TimePoint ExpireTime{};
- };
-
- using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
- using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
-
- spdlog::logger& Log() { return m_Log; }
-
- AuthConfig m_Config;
- spdlog::logger& m_Log;
- BackgroundThread m_BackgroundThread;
- OpenIdProviderMap m_OpenIdProviders;
- OpenIdTokenMap m_OpenIdTokens;
- std::mutex m_ProviderMutex;
- std::shared_mutex m_TokenMutex;
-};
-
-std::unique_ptr<AuthMgr>
-AuthMgr::Create(const AuthConfig& Config)
-{
- return std::make_unique<AuthMgrImpl>(Config);
-}
-
-} // namespace zen