aboutsummaryrefslogtreecommitdiff
path: root/zenserver/auth
diff options
context:
space:
mode:
Diffstat (limited to 'zenserver/auth')
-rw-r--r--zenserver/auth/authmgr.cpp296
-rw-r--r--zenserver/auth/authmgr.h50
-rw-r--r--zenserver/auth/authservice.cpp57
-rw-r--r--zenserver/auth/authservice.h5
-rw-r--r--zenserver/auth/oidc.cpp127
-rw-r--r--zenserver/auth/oidc.h74
6 files changed, 605 insertions, 4 deletions
diff --git a/zenserver/auth/authmgr.cpp b/zenserver/auth/authmgr.cpp
new file mode 100644
index 000000000..28e128fc0
--- /dev/null
+++ b/zenserver/auth/authmgr.cpp
@@ -0,0 +1,296 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/authmgr.h>
+#include <auth/oidc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+
+#include <chrono>
+#include <condition_variable>
+#include <memory>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+#include <fmt/format.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+class AuthMgrImpl final : public AuthMgr
+{
+public:
+ AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); }
+
+ virtual ~AuthMgrImpl() { SaveState(); }
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
+ {
+ if (OpenIdProviderExist(Params.Name))
+ {
+ return;
+ }
+
+ std::unique_ptr<OidcClient> Client =
+ std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+
+ if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
+ {
+ ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
+ return;
+ }
+
+ std::string NewProviderName = std::string(Params.Name);
+
+ OpenIdProvider* NewProvider = nullptr;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.contains(NewProviderName))
+ {
+ return;
+ }
+
+ auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>());
+ NewProvider = InsertResult.first->second.get();
+ }
+
+ NewProvider->Name = std::string(Params.Name);
+ NewProvider->Url = std::string(Params.Url);
+ NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->HttpClient = std::move(Client);
+
+ ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
+ }
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
+ {
+ if (Params.ProviderName.empty())
+ {
+ ZEN_WARN("trying add OpenID token with invalid provider name");
+ return false;
+ }
+
+ if (Params.IdentityToken.empty() || Params.RefreshToken.empty() || Params.AccessToken.empty())
+ {
+ ZEN_WARN("add OpenId token FAILED, reason 'Token invalid'");
+ return false;
+ }
+
+ auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
+
+ if (RefreshResult.Ok == false)
+ {
+ ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
+ return false;
+ }
+
+ bool IsNew = false;
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ const auto InsertResult = m_OpenIdTokens.try_emplace(std::string(Params.ProviderName),
+ OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = RefreshResult.AccessToken});
+
+ IsNew = InsertResult.second;
+ }
+
+ if (IsNew)
+ {
+ ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName);
+ }
+ else
+ {
+ ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName);
+ }
+
+ return true;
+ }
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end())
+ {
+ const OpenIdToken& Token = It->second;
+
+ return {.AccessToken = Token.AccessToken};
+ }
+
+ return {};
+ }
+
+private:
+ bool OpenIdProviderExist(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ return m_OpenIdProviders.contains(std::string(ProviderName));
+ }
+
+ OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+ return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ }
+
+ OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
+ {
+ if (OpenIdProviderExist(ProviderName) == false)
+ {
+ return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
+ }
+
+ OidcClient& Client = GetOpenIdClient(ProviderName);
+
+ return Client.RefreshToken(RefreshToken);
+ }
+
+ void Shutdown() { SaveState(); }
+
+ void LoadState()
+ {
+ FileContents Result = ReadFile(m_Config.RootDirectory / "authstate"sv);
+
+ if (Result.ErrorCode)
+ {
+ return;
+ }
+
+ IoBuffer Buffer = Result.Flatten();
+
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
+ }
+
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ {
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
+
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
+
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
+
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ {
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view IdentityToken = TokenObj["IdentityToken"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+ std::string_view AccessToken = TokenObj["AccessToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName,
+ .IdentityToken = IdentityToken,
+ .RefreshToken = RefreshToken,
+ .AccessToken = AccessToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
+ }
+ }
+ }
+
+ void SaveState()
+ {
+ CbObjectWriter AuthState;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
+ {
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
+ }
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
+ {
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "IdentityToken"sv << Kv.second.IdentityToken;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState << "AccessToken"sv << Kv.second.AccessToken;
+ AuthState << "ExpireTime"sv << Kv.second.ExpireTime;
+ AuthState.EndObject();
+ }
+ }
+ AuthState.EndArray();
+ }
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+ WriteFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer());
+ }
+
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ std::unique_ptr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ double ExpireTime{};
+ };
+
+ using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
+ using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthConfig m_Config;
+ spdlog::logger& m_Log;
+ OpenIdProviderMap m_OpenIdProviders;
+ OpenIdTokenMap m_OpenIdTokens;
+ std::mutex m_ProviderMutex;
+ std::shared_mutex m_TokenMutex;
+};
+
+std::unique_ptr<AuthMgr>
+MakeAuthMgr(const AuthConfig& Config)
+{
+ return std::make_unique<AuthMgrImpl>(Config);
+}
+
+} // namespace zen
diff --git a/zenserver/auth/authmgr.h b/zenserver/auth/authmgr.h
new file mode 100644
index 000000000..1138d9eff
--- /dev/null
+++ b/zenserver/auth/authmgr.h
@@ -0,0 +1,50 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/string.h>
+
+#include <filesystem>
+#include <memory>
+
+namespace zen {
+
+class AuthMgr
+{
+public:
+ virtual ~AuthMgr() = default;
+
+ struct AddOpenIdProviderParams
+ {
+ std::string_view Name;
+ std::string_view Url;
+ std::string_view ClientId;
+ };
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) = 0;
+
+ struct AddOpenIdTokenParams
+ {
+ std::string_view ProviderName;
+ std::string_view IdentityToken;
+ std::string_view RefreshToken;
+ std::string_view AccessToken;
+ };
+
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0;
+
+ struct OpenIdAccessToken
+ {
+ std::string AccessToken;
+ };
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0;
+};
+
+struct AuthConfig
+{
+ std::filesystem::path RootDirectory;
+};
+
+std::unique_ptr<AuthMgr> MakeAuthMgr(const AuthConfig& Config);
+
+} // namespace zen
diff --git a/zenserver/auth/authservice.cpp b/zenserver/auth/authservice.cpp
index eecad45bf..4e6f496a6 100644
--- a/zenserver/auth/authservice.cpp
+++ b/zenserver/auth/authservice.cpp
@@ -1,19 +1,70 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <auth/authmgr.h>
#include <auth/authservice.h>
+
+#include <zencore/compactbinarybuilder.h>
#include <zencore/string.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
namespace zen {
using namespace std::literals;
-HttpAuthService::HttpAuthService()
+HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
{
m_Router.RegisterRoute(
"token",
- [](HttpRouterRequest& RouterRequest) {
+ [this](HttpRouterRequest& RouterRequest) {
HttpServerRequest& ServerRequest = RouterRequest.ServerRequest();
- ServerRequest.WriteResponse(HttpResponseCode::OK);
+
+ const HttpContentType ContentType = ServerRequest.RequestContentType();
+
+ if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ const IoBuffer Body = ServerRequest.ReadPayload();
+
+ std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize());
+ std::string JsonError;
+ json11::Json TokenInfo = json11::Json::parse(JsonText, JsonError);
+
+ if (!JsonError.empty())
+ {
+ CbObjectWriter Response;
+ Response << "Result"sv << false;
+ Response << "Error"sv << JsonError;
+
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save());
+ }
+
+ const std::string IdentityToken = TokenInfo["IdentityToken"].string_value();
+ const std::string RefreshToken = TokenInfo["RefreshToken"].string_value();
+ const std::string AccessToken = TokenInfo["AccessToken"].string_value();
+
+ const bool Ok = m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = "Okta"sv,
+ .IdentityToken = IdentityToken,
+ .RefreshToken = RefreshToken,
+ .AccessToken = AccessToken});
+
+ if (Ok)
+ {
+ ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest);
+ }
+ else
+ {
+ CbObjectWriter Response;
+ Response << "Result"sv << false;
+ Response << "Error"sv
+ << "Invalid token"sv;
+
+ ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save());
+ }
},
HttpVerb::kPost);
}
diff --git a/zenserver/auth/authservice.h b/zenserver/auth/authservice.h
index 30b2b5864..64b86e21f 100644
--- a/zenserver/auth/authservice.h
+++ b/zenserver/auth/authservice.h
@@ -6,16 +6,19 @@
namespace zen {
+class AuthMgr;
+
class HttpAuthService final : public zen::HttpService
{
public:
- HttpAuthService();
+ HttpAuthService(AuthMgr& AuthMgr);
virtual ~HttpAuthService();
virtual const char* BaseUri() const override;
virtual void HandleRequest(zen::HttpServerRequest& Request) override;
private:
+ AuthMgr& m_AuthMgr;
HttpRequestRouter m_Router;
};
diff --git a/zenserver/auth/oidc.cpp b/zenserver/auth/oidc.cpp
new file mode 100644
index 000000000..2f53f1bae
--- /dev/null
+++ b/zenserver/auth/oidc.cpp
@@ -0,0 +1,127 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/oidc.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace details {
+
+ using StringArray = std::vector<std::string>;
+
+ StringArray ToStringArray(const json11::Json JsonArray)
+ {
+ StringArray Result;
+
+ const auto& Items = JsonArray.array_items();
+
+ for (const auto& Item : Items)
+ {
+ Result.push_back(Item.string_value());
+ }
+
+ return Result;
+ }
+
+} // namespace details
+
+using namespace std::literals;
+
+OidcClient::OidcClient(const OidcClient::Options& Options)
+{
+ m_BaseUrl = std::string(Options.BaseUrl);
+ m_ClientId = std::string(Options.ClientId);
+}
+
+OidcClient::InitResult
+OidcClient::Initialize()
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_BaseUrl << "/.well-known/openid-configuration"sv;
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+
+ cpr::Response Response = Session.Get();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = std::move(Response.reason)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ m_Config = {.Issuer = Json["issuer"].string_value(),
+ .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(),
+ .TokenEndpoint = Json["token_endpoint"].string_value(),
+ .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
+ .RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .JwksUri = Json["jwks_uri"].string_value(),
+ .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
+ .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
+ .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]),
+ .SupportedScopes = details::ToStringArray(Json["scopes_supported"]),
+ .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]),
+ .SupportedClaims = details::ToStringArray(Json["claims_supported"])};
+
+ return {.Ok = true};
+}
+
+OidcClient::RefreshTokenResult
+OidcClient::RefreshToken(std::string_view RefreshToken)
+{
+ const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()});
+ Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}});
+ Session.SetBody(cpr::Body{Body.data(), Body.size()});
+
+ cpr::Response Response = Session.Post();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = std::move(Response.reason)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ return {.TokenType = Json["token_type"].string_value(),
+ .AccessToken = Json["access_token"].string_value(),
+ .RefreshToken = Json["refresh_token"].string_value(),
+ .IdentityToken = Json["id_token"].string_value(),
+ .Scope = Json["scope"].string_value(),
+ .ExpiresIn = Json["scope"].number_value(),
+ .Ok = true};
+}
+
+} // namespace zen
diff --git a/zenserver/auth/oidc.h b/zenserver/auth/oidc.h
new file mode 100644
index 000000000..b08181bfd
--- /dev/null
+++ b/zenserver/auth/oidc.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/string.h>
+
+#include <vector>
+
+namespace zen {
+
+class OidcClient
+{
+public:
+ struct Options
+ {
+ std::string_view BaseUrl;
+ std::string_view ClientId;
+ };
+
+ OidcClient(const Options& Options);
+ ~OidcClient() = default;
+
+ OidcClient(const OidcClient&) = delete;
+ OidcClient& operator=(const OidcClient&) = delete;
+
+ struct Result
+ {
+ std::string Reason;
+ bool Ok = false;
+ };
+
+ using InitResult = Result;
+
+ InitResult Initialize();
+
+ struct RefreshTokenResult
+ {
+ std::string TokenType;
+ std::string AccessToken;
+ std::string RefreshToken;
+ std::string IdentityToken;
+ std::string Scope;
+ std::string Reason;
+ double ExpiresIn{};
+ bool Ok = false;
+ };
+
+ RefreshTokenResult RefreshToken(std::string_view RefreshToken);
+
+private:
+ using StringArray = std::vector<std::string>;
+
+ struct OpenIdConfiguration
+ {
+ std::string Issuer;
+ std::string AuthorizationEndpoint;
+ std::string TokenEndpoint;
+ std::string UserInfoEndpoint;
+ std::string RegistrationEndpoint;
+ std::string EndSessionEndpoint;
+ std::string DeviceAuthorizationEndpoint;
+ std::string JwksUri;
+ StringArray SupportedResponseTypes;
+ StringArray SupportedResponseModes;
+ StringArray SupportedGrantTypes;
+ StringArray SupportedScopes;
+ StringArray SupportedTokenEndpointAuthMethods;
+ StringArray SupportedClaims;
+ };
+
+ std::string m_BaseUrl;
+ std::string m_ClientId;
+ OpenIdConfiguration m_Config;
+};
+
+} // namespace zen