aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-01-28 10:16:34 +0100
committerPer Larsson <[email protected]>2022-01-28 10:16:34 +0100
commitb34fcd781777c522b155be69239967b2dcfd1c36 (patch)
tree4de81ad72d94ce29857439171c76bd7bab551745
parentAdd OpenID auth to auth mgr. (diff)
downloadzen-b34fcd781777c522b155be69239967b2dcfd1c36.tar.xz
zen-b34fcd781777c522b155be69239967b2dcfd1c36.zip
Extended auth mgr to restore OpenID provider(s) and token(s).
-rw-r--r--zenserver/auth/authmgr.cpp195
-rw-r--r--zenserver/auth/authservice.cpp2
-rw-r--r--zenserver/auth/oidc.cpp127
-rw-r--r--zenserver/auth/oidc.h74
4 files changed, 382 insertions, 16 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp
index 4c97693f9..20bc7c988 100644
--- a/zenserver/auth/authmgr.cpp
+++ b/zenserver/auth/authmgr.cpp
@@ -1,25 +1,50 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include <auth/authmgr.h>
+#include <auth/oidc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
#include <zencore/logging.h>
#include <chrono>
#include <condition_variable>
+#include <memory>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
+#include <fmt/format.h>
+
namespace zen {
+using namespace std::literals;
+
class AuthMgrImpl final : public AuthMgr
{
public:
- AuthMgrImpl(const AuthConfig& Config) : m_Log(logging::Get("auth")) { ZEN_UNUSED(Config); }
+ AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); }
- virtual ~AuthMgrImpl() {}
+ virtual ~AuthMgrImpl() { SaveState(); }
virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
{
+ if (OpenIdProviderExist(Params.Name))
+ {
+ return;
+ }
+
+ std::unique_ptr<OidcClient> Client =
+ std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+
+ if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
+ {
+ ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
+ return;
+ }
+
std::string NewProviderName = std::string(Params.Name);
OpenIdProvider* NewProvider = nullptr;
@@ -36,9 +61,10 @@ public:
NewProvider = InsertResult.first->second.get();
}
- NewProvider->Name = std::string(Params.Name);
- NewProvider->Url = std::string(Params.Url);
- NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->Name = std::string(Params.Name);
+ NewProvider->Url = std::string(Params.Url);
+ NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->HttpClient = std::move(Client);
ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
}
@@ -53,7 +79,15 @@ public:
if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty())
{
- ZEN_WARN("trying add invalid OpenID token");
+ ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'");
+ return false;
+ }
+
+ auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
+
+ if (RefreshResult.Ok == false)
+ {
+ ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
return false;
}
@@ -63,9 +97,9 @@ public:
std::unique_lock _(m_TokenMutex);
const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName),
- OpenIdToken{.IdentityToken = std::string(Params.IdentityToken),
- .RefreshToken = Params.RefreshToken,
- .AccessToken = Params.AccessToken});
+ OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = RefreshResult.AccessToken});
IsNew = InsertResult.second;
}
@@ -83,19 +117,147 @@ public:
}
private:
+ bool OpenIdProviderExist(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ return m_OpenIdProviders.contains(std::string(ProviderName));
+ }
+
+ OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+ return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ }
+
+ OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
+ {
+ if (OpenIdProviderExist(ProviderName) == false)
+ {
+ return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
+ }
+
+ OidcClient& Client = GetOpenIdClient(ProviderName);
+
+ return Client.RefreshToken(RefreshToken);
+ }
+
+ void Shutdown() { SaveState(); }
+
+ void LoadState()
+ {
+ FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv);
+
+ if (Result.ErrorCode)
+ {
+ return;
+ }
+
+ IoBuffer Buffer = Result.Flatten();
+
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
+ }
+
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ {
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
+
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
+
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
+
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ {
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+ std::string_view AccessToken = TokenObj["AccessToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName,
+ .IdentityToken = IdentityToken,
+ .RefreshToken = RefreshToken,
+ .AccessToken = AccessToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
+ }
+ }
+ }
+
+ void SaveState()
+ {
+ CbObjectWriter AuthState;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
+ {
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
+ }
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
+ {
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "IdentityToken"sv << Kv.second.IdentityToken;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState << "AccessToken"sv << Kv.second.AccessToken;
+ AuthState << "ExpireTime"sv << Kv.second.ExpireTime;
+ AuthState.EndObject();
+ }
+ }
+ AuthState.EndArray();
+ }
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+ WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer());
+ }
+
struct OpenIdProvider
{
- std::string Name;
- std::string Url;
- std::string ClientId;
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ std::unique_ptr<OidcClient> HttpClient;
};
struct OpenIdToken
{
- std::string_view IdentityToken;
- std::string_view RefreshToken;
- std::string_view AccessToken;
- double ExpireTime{};
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ double ExpireTime{};
};
using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
@@ -103,6 +265,7 @@ private:
spdlog::logger& Log() { return m_Log; }
+ AuthConfig m_Config;
spdlog::logger& m_Log;
OpenIdProviderMap m_OpenIdProviders;
OpenIdTokenMap m_OpenIdTokens;
diff --git a/zenserver/auth/authservice.cpp b/zenserver/auth/authservice.cpp
index 20ea252fa..4e6f496a6 100644
--- a/zenserver/auth/authservice.cpp
+++ b/zenserver/auth/authservice.cpp
@@ -6,7 +6,9 @@
#include <zencore/compactbinarybuilder.h>
#include <zencore/string.h>
+ZEN_THIRD_PARTY_INCLUDES_START
#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
diff --git a/zenserver/auth/oidc.cpp b/zenserver/auth/oidc.cpp
new file mode 100644
index 000000000..2f53f1bae
--- /dev/null
+++ b/zenserver/auth/oidc.cpp
@@ -0,0 +1,127 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/oidc.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace details {
+
+ using StringArray = std::vector<std::string>;
+
+ StringArray ToStringArray(const json11::Json JsonArray)
+ {
+ StringArray Result;
+
+ const auto& Items = JsonArray.array_items();
+
+ for (const auto& Item : Items)
+ {
+ Result.push_back(Item.string_value());
+ }
+
+ return Result;
+ }
+
+} // namespace details
+
+using namespace std::literals;
+
+OidcClient::OidcClient(const OidcClient::Options& Options)
+{
+ m_BaseUrl = std::string(Options.BaseUrl);
+ m_ClientId = std::string(Options.ClientId);
+}
+
+OidcClient::InitResult
+OidcClient::Initialize()
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_BaseUrl << "/.well-known/openid-configuration"sv;
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+
+ cpr::Response Response = Session.Get();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = std::move(Response.reason)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ m_Config = {.Issuer = Json["issuer"].string_value(),
+ .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(),
+ .TokenEndpoint = Json["token_endpoint"].string_value(),
+ .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
+ .RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .JwksUri = Json["jwks_uri"].string_value(),
+ .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
+ .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
+ .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]),
+ .SupportedScopes = details::ToStringArray(Json["scopes_supported"]),
+ .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]),
+ .SupportedClaims = details::ToStringArray(Json["claims_supported"])};
+
+ return {.Ok = true};
+}
+
+OidcClient::RefreshTokenResult
+OidcClient::RefreshToken(std::string_view RefreshToken)
+{
+ const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()});
+ Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}});
+ Session.SetBody(cpr::Body{Body.data(), Body.size()});
+
+ cpr::Response Response = Session.Post();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = std::move(Response.reason)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ return {.TokenType = Json["token_type"].string_value(),
+ .AccessToken = Json["access_token"].string_value(),
+ .RefreshToken = Json["refresh_token"].string_value(),
+ .IdentityToken = Json["id_token"].string_value(),
+ .Scope = Json["scope"].string_value(),
+ .ExpiresIn = Json["scope"].number_value(),
+ .Ok = true};
+}
+
+} // namespace zen
diff --git a/zenserver/auth/oidc.h b/zenserver/auth/oidc.h
new file mode 100644
index 000000000..b08181bfd
--- /dev/null
+++ b/zenserver/auth/oidc.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/string.h>
+
+#include <vector>
+
+namespace zen {
+
+class OidcClient
+{
+public:
+ struct Options
+ {
+ std::string_view BaseUrl;
+ std::string_view ClientId;
+ };
+
+ OidcClient(const Options& Options);
+ ~OidcClient() = default;
+
+ OidcClient(const OidcClient&) = delete;
+ OidcClient& operator=(const OidcClient&) = delete;
+
+ struct Result
+ {
+ std::string Reason;
+ bool Ok = false;
+ };
+
+ using InitResult = Result;
+
+ InitResult Initialize();
+
+ struct RefreshTokenResult
+ {
+ std::string TokenType;
+ std::string AccessToken;
+ std::string RefreshToken;
+ std::string IdentityToken;
+ std::string Scope;
+ std::string Reason;
+ double ExpiresIn{};
+ bool Ok = false;
+ };
+
+ RefreshTokenResult RefreshToken(std::string_view RefreshToken);
+
+private:
+ using StringArray = std::vector<std::string>;
+
+ struct OpenIdConfiguration
+ {
+ std::string Issuer;
+ std::string AuthorizationEndpoint;
+ std::string TokenEndpoint;
+ std::string UserInfoEndpoint;
+ std::string RegistrationEndpoint;
+ std::string EndSessionEndpoint;
+ std::string DeviceAuthorizationEndpoint;
+ std::string JwksUri;
+ StringArray SupportedResponseTypes;
+ StringArray SupportedResponseModes;
+ StringArray SupportedGrantTypes;
+ StringArray SupportedScopes;
+ StringArray SupportedTokenEndpointAuthMethods;
+ StringArray SupportedClaims;
+ };
+
+ std::string m_BaseUrl;
+ std::string m_ClientId;
+ OpenIdConfiguration m_Config;
+};
+
+} // namespace zen