diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 10:01:47 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 10:01:47 +0200 |
| commit | 075d17f8ada47e990fe94606c3d21df409223465 (patch) | |
| tree | e50549b766a2f3c354798a54ff73404217b4c9af /src/zenserver | |
| parent | fix: bundle shouldn't append content zip to zen (diff) | |
| download | zen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz zen-075d17f8ada47e990fe94606c3d21df409223465.zip | |
moved source directories into `/src` (#264)
* moved source directories into `/src`
* updated bundle.lua for new `src` path
* moved some docs, icon
* removed old test trees
Diffstat (limited to 'src/zenserver')
67 files changed, 27610 insertions, 0 deletions
diff --git a/src/zenserver/admin/admin.cpp b/src/zenserver/admin/admin.cpp new file mode 100644 index 000000000..7aa1b48d1 --- /dev/null +++ b/src/zenserver/admin/admin.cpp @@ -0,0 +1,101 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "admin.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> +#include <zenstore/gc.h> + +#include <chrono> + +namespace zen { + +HttpAdminService::HttpAdminService(GcScheduler& Scheduler) : m_GcScheduler(Scheduler) +{ + using namespace std::literals; + + m_Router.RegisterRoute( + "health", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "gc", + [this](HttpRouterRequest& Req) { + const GcSchedulerStatus Status = m_GcScheduler.Status(); + + CbObjectWriter Response; + Response << "Status"sv << (GcSchedulerStatus::kIdle == Status ? "Idle"sv : "Running"sv); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "gc", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + GcScheduler::TriggerParams GcParams; + + if (auto Param = Params.GetValue("smallobjects"); Param.empty() == false) + { + GcParams.CollectSmallObjects = Param == "true"sv; + } + + if (auto Param = Params.GetValue("maxcacheduration"); Param.empty() == false) + { + if (auto Value = ParseInt<uint64_t>(Param)) + { + GcParams.MaxCacheDuration = std::chrono::seconds(Value.value()); + } + } + + if (auto Param = Params.GetValue("disksizesoftlimit"); Param.empty() == false) + { + if (auto Value = ParseInt<uint64_t>(Param)) + { + GcParams.DiskSizeSoftLimit = Value.value(); + } + } + + const bool Started = m_GcScheduler.Trigger(GcParams); + + CbObjectWriter Response; + Response << "Status"sv << (Started ? 
"Started"sv : "Running"sv); + HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "", + [](HttpRouterRequest& Req) { + CbObject Payload = Req.ServerRequest().ReadPayloadObject(); + + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kPost); +} + +HttpAdminService::~HttpAdminService() +{ +} + +const char* +HttpAdminService::BaseUri() const +{ + return "/admin/"; +} + +void +HttpAdminService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/admin/admin.h b/src/zenserver/admin/admin.h new file mode 100644 index 000000000..9463ffbb3 --- /dev/null +++ b/src/zenserver/admin/admin.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zenhttp/httpserver.h> + +namespace zen { + +class GcScheduler; + +class HttpAdminService : public zen::HttpService +{ +public: + HttpAdminService(GcScheduler& Scheduler); + ~HttpAdminService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + HttpRequestRouter m_Router; + GcScheduler& m_GcScheduler; +}; + +} // namespace zen diff --git a/src/zenserver/auth/authmgr.cpp b/src/zenserver/auth/authmgr.cpp new file mode 100644 index 000000000..4cd6b3362 --- /dev/null +++ b/src/zenserver/auth/authmgr.cpp @@ -0,0 +1,506 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <auth/authmgr.h> +#include <auth/oidc.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/crypto.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> + +#include <condition_variable> +#include <memory> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +#include <fmt/format.h> + +namespace zen { + +using namespace std::literals; + +namespace details { + IoBuffer ReadEncryptedFile(std::filesystem::path Path, + const AesKey256Bit& Key, + const AesIV128Bit& IV, + std::optional<std::string>& Reason) + { + FileContents Result = ReadFile(Path); + + if (Result.ErrorCode) + { + return IoBuffer(); + } + + IoBuffer EncryptedBuffer = Result.Flatten(); + + if (EncryptedBuffer.GetSize() == 0) + { + return IoBuffer(); + } + + std::vector<uint8_t> DecryptionBuffer; + DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize); + + MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); + + if (DecryptedView.IsEmpty()) + { + return IoBuffer(); + } + + return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); + } + + void WriteEncryptedFile(std::filesystem::path Path, + IoBuffer FileData, + const AesKey256Bit& Key, + const AesIV128Bit& IV, + std::optional<std::string>& Reason) + { + if (FileData.GetSize() == 0) + { + return; + } + + std::vector<uint8_t> EncryptionBuffer; + EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize); + + MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason); + + if (EncryptedView.IsEmpty()) + { + return; + } + + WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize())); + } +} // namespace details + +class AuthMgrImpl final : public AuthMgr +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + using Seconds = 
std::chrono::seconds; + +public: + AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) + { + LoadState(); + + m_BackgroundThread.Interval = Config.UpdateInterval; + m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this); + } + + virtual ~AuthMgrImpl() { Shutdown(); } + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final + { + if (OpenIdProviderExist(Params.Name)) + { + ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); + return; + } + + if (Params.Name.empty()) + { + ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); + return; + } + + std::unique_ptr<OidcClient> Client = + std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + + if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) + { + ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); + return; + } + + std::string NewProviderName = std::string(Params.Name); + + OpenIdProvider* NewProvider = nullptr; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.contains(NewProviderName)) + { + return; + } + + auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>()); + NewProvider = InsertResult.first->second.get(); + } + + NewProvider->Name = std::string(Params.Name); + NewProvider->Url = std::string(Params.Url); + NewProvider->ClientId = std::string(Params.ClientId); + NewProvider->HttpClient = std::move(Client); + + ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); + } + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final + { + if (Params.ProviderName.empty()) + { + ZEN_WARN("trying add OpenID token with invalid provider name"); + return false; + } + + if (Params.RefreshToken.empty()) + { + ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'"); + return false; + } + + auto RefreshResult = 
RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); + + if (RefreshResult.Ok == false) + { + ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); + return false; + } + + bool IsNew = false; + + { + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + + std::unique_lock _(m_TokenMutex); + + const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token)); + + IsNew = InsertResult.second; + } + + if (IsNew) + { + ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); + } + else + { + ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); + } + + return true; + } + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final + { + std::unique_lock _(m_TokenMutex); + + if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) + { + const OpenIdToken& Token = It->second; + + return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + } + + return {}; + } + +private: + bool OpenIdProviderExist(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + + return m_OpenIdProviders.contains(std::string(ProviderName)); + } + + OidcClient& GetOpenIdClient(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + } + + OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) + { + if (OpenIdProviderExist(ProviderName) == false) + { + return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; + } + + OidcClient& Client = GetOpenIdClient(ProviderName); + + return Client.RefreshToken(RefreshToken); + } + + void 
Shutdown() + { + BackgroundThread::Stop(m_BackgroundThread); + SaveState(); + } + + void LoadState() + { + try + { + std::optional<std::string> Reason; + + IoBuffer Buffer = + details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason); + + if (!Buffer) + { + if (Reason) + { + ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value()); + } + + return; + } + + const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + return; + } + + if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) + { + for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + { + CbObjectView ProviderObj = ProviderView.AsObjectView(); + + std::string_view ProviderName = ProviderObj["Name"].AsString(); + std::string_view Url = ProviderObj["Url"].AsString(); + std::string_view ClientId = ProviderObj["ClientId"].AsString(); + + AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + } + + for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) + { + CbObjectView TokenObj = TokenView.AsObjectView(); + + std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); + std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); + + const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken}); + + if (!Ok) + { + ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + } + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); + + { + std::unique_lock _(m_ProviderMutex); + m_OpenIdProviders.clear(); + } + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.clear(); + } + } + } + + void SaveState() + { + try + { + CbObjectWriter AuthState; + + { + std::unique_lock _(m_ProviderMutex); + + if 
(m_OpenIdProviders.size() > 0) + { + AuthState.BeginArray("OpenIdProviders"); + for (const auto& Kv : m_OpenIdProviders) + { + AuthState.BeginObject(); + AuthState << "Name"sv << Kv.second->Name; + AuthState << "Url"sv << Kv.second->Url; + AuthState << "ClientId"sv << Kv.second->ClientId; + AuthState.EndObject(); + } + AuthState.EndArray(); + } + } + + { + std::unique_lock _(m_TokenMutex); + + AuthState.BeginArray("OpenIdTokens"); + if (m_OpenIdTokens.size() > 0) + { + for (const auto& Kv : m_OpenIdTokens) + { + AuthState.BeginObject(); + AuthState << "ProviderName"sv << Kv.first; + AuthState << "RefreshToken"sv << Kv.second.RefreshToken; + AuthState.EndObject(); + } + } + AuthState.EndArray(); + } + + std::filesystem::create_directories(m_Config.RootDirectory); + + std::optional<std::string> Reason; + + details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, + AuthState.Save().GetBuffer().AsIoBuffer(), + m_Config.EncryptionKey, + m_Config.EncryptionIV, + Reason); + + if (Reason) + { + ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value()); + } + } + catch (std::exception& Err) + { + ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what()); + } + } + + void BackgroundThreadEntry() + { + for (;;) + { + std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread); + + if (m_BackgroundThread.Running.load() == false) + { + break; + } + + if (SignalStatus != std::cv_status::timeout) + { + continue; + } + + { + // Refresh Open ID token(s) + + std::vector<OpenIdTokenMap::value_type> ExpiredTokens; + + { + std::unique_lock _(m_TokenMutex); + + for (const auto& Kv : m_OpenIdTokens) + { + const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now()); + const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2); + + if (Expired) + { + ExpiredTokens.push_back(Kv); + } + } + } + + ZEN_DEBUG("refreshing '{}' OpenID token(s)", ExpiredTokens.size()); + + for (const auto& Kv : 
ExpiredTokens) + { + OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken); + + if (RefreshResult.Ok) + { + ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first); + + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token)); + } + } + else + { + ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason); + } + } + } + } + } + + struct BackgroundThread + { + std::chrono::seconds Interval{10}; + std::mutex Mutex; + std::condition_variable Signal; + std::atomic_bool Running{true}; + std::thread Thread; + + static void Stop(BackgroundThread& State) + { + if (State.Running.load()) + { + State.Running.store(false); + State.Signal.notify_one(); + } + + if (State.Thread.joinable()) + { + State.Thread.join(); + } + } + + static std::cv_status WaitForSignal(BackgroundThread& State) + { + std::unique_lock Lock(State.Mutex); + return State.Signal.wait_for(Lock, State.Interval); + } + }; + + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + std::unique_ptr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + + using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; + using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; + + spdlog::logger& Log() { return m_Log; } + + AuthConfig m_Config; + spdlog::logger& m_Log; + BackgroundThread m_BackgroundThread; + OpenIdProviderMap m_OpenIdProviders; + OpenIdTokenMap m_OpenIdTokens; + std::mutex m_ProviderMutex; + 
std::shared_mutex m_TokenMutex; +}; + +std::unique_ptr<AuthMgr> +AuthMgr::Create(const AuthConfig& Config) +{ + return std::make_unique<AuthMgrImpl>(Config); +} + +} // namespace zen diff --git a/src/zenserver/auth/authmgr.h b/src/zenserver/auth/authmgr.h new file mode 100644 index 000000000..054588ab9 --- /dev/null +++ b/src/zenserver/auth/authmgr.h @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/crypto.h> +#include <zencore/iobuffer.h> +#include <zencore/string.h> + +#include <chrono> +#include <filesystem> +#include <memory> + +namespace zen { + +struct AuthConfig +{ + std::filesystem::path RootDirectory; + std::chrono::seconds UpdateInterval{30}; + AesKey256Bit EncryptionKey; + AesIV128Bit EncryptionIV; +}; + +class AuthMgr +{ +public: + virtual ~AuthMgr() = default; + + struct AddOpenIdProviderParams + { + std::string_view Name; + std::string_view Url; + std::string_view ClientId; + }; + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) = 0; + + struct AddOpenIdTokenParams + { + std::string_view ProviderName; + std::string_view RefreshToken; + }; + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0; + + struct OpenIdAccessToken + { + std::string AccessToken; + std::chrono::system_clock::time_point ExpireTime{}; + }; + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0; + + static std::unique_ptr<AuthMgr> Create(const AuthConfig& Config); +}; + +} // namespace zen diff --git a/src/zenserver/auth/authservice.cpp b/src/zenserver/auth/authservice.cpp new file mode 100644 index 000000000..1cc679540 --- /dev/null +++ b/src/zenserver/auth/authservice.cpp @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <auth/authservice.h> + +#include <auth/authmgr.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) +{ + m_Router.RegisterRoute( + "oidc/refreshtoken", + [this](HttpRouterRequest& RouterRequest) { + HttpServerRequest& ServerRequest = RouterRequest.ServerRequest(); + + const HttpContentType ContentType = ServerRequest.RequestContentType(); + + if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); + } + + const IoBuffer Body = ServerRequest.ReadPayload(); + + std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()); + std::string JsonError; + json11::Json TokenInfo = json11::Json::parse(JsonText, JsonError); + + if (!JsonError.empty()) + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv << JsonError; + + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } + + const std::string RefreshToken = TokenInfo["RefreshToken"].string_value(); + std::string ProviderName = TokenInfo["ProviderName"].string_value(); + + if (ProviderName.empty()) + { + ProviderName = "Default"sv; + } + + const bool Ok = + m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = ProviderName, .RefreshToken = RefreshToken}); + + if (Ok) + { + ServerRequest.WriteResponse(Ok ? 
HttpResponseCode::OK : HttpResponseCode::BadRequest); + } + else + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv + << "Invalid token"sv; + + ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } + }, + HttpVerb::kPost); +} + +HttpAuthService::~HttpAuthService() +{ +} + +const char* +HttpAuthService::BaseUri() const +{ + return "/auth/"; +} + +void +HttpAuthService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/auth/authservice.h b/src/zenserver/auth/authservice.h new file mode 100644 index 000000000..64b86e21f --- /dev/null +++ b/src/zenserver/auth/authservice.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +class AuthMgr; + +class HttpAuthService final : public zen::HttpService +{ +public: + HttpAuthService(AuthMgr& AuthMgr); + virtual ~HttpAuthService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + AuthMgr& m_AuthMgr; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/auth/oidc.cpp b/src/zenserver/auth/oidc.cpp new file mode 100644 index 000000000..d2265c22f --- /dev/null +++ b/src/zenserver/auth/oidc.cpp @@ -0,0 +1,127 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <auth/oidc.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace details { + + using StringArray = std::vector<std::string>; + + StringArray ToStringArray(const json11::Json JsonArray) + { + StringArray Result; + + const auto& Items = JsonArray.array_items(); + + for (const auto& Item : Items) + { + Result.push_back(Item.string_value()); + } + + return Result; + } + +} // namespace details + +using namespace std::literals; + +OidcClient::OidcClient(const OidcClient::Options& Options) +{ + m_BaseUrl = std::string(Options.BaseUrl); + m_ClientId = std::string(Options.ClientId); +} + +OidcClient::InitResult +OidcClient::Initialize() +{ + ExtendableStringBuilder<256> Uri; + Uri << m_BaseUrl << "/.well-known/openid-configuration"sv; + + cpr::Session Session; + + Session.SetOption(cpr::Url{Uri.c_str()}); + + cpr::Response Response = Session.Get(); + + if (Response.error) + { + return {.Reason = std::move(Response.error.message)}; + } + + if (Response.status_code != 200) + { + return {.Reason = std::move(Response.reason)}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {.Reason = std::move(JsonError)}; + } + + m_Config = {.Issuer = Json["issuer"].string_value(), + .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(), + .TokenEndpoint = Json["token_endpoint"].string_value(), + .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), + .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .JwksUri = Json["jwks_uri"].string_value(), + .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), + .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), + .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]), + .SupportedScopes = 
details::ToStringArray(Json["scopes_supported"]), + .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]), + .SupportedClaims = details::ToStringArray(Json["claims_supported"])}; + + return {.Ok = true}; +} + +OidcClient::RefreshTokenResult +OidcClient::RefreshToken(std::string_view RefreshToken) +{ + const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + + cpr::Session Session; + + Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()}); + Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}); + Session.SetBody(cpr::Body{Body.data(), Body.size()}); + + cpr::Response Response = Session.Post(); + + if (Response.error) + { + return {.Reason = std::move(Response.error.message)}; + } + + if (Response.status_code != 200) + { + return {.Reason = fmt::format("{} ({})", Response.reason, Response.text)}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {.Reason = std::move(JsonError)}; + } + + return {.TokenType = Json["token_type"].string_value(), + .AccessToken = Json["access_token"].string_value(), + .RefreshToken = Json["refresh_token"].string_value(), + .IdentityToken = Json["id_token"].string_value(), + .Scope = Json["scope"].string_value(), + .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()), + .Ok = true}; +} + +} // namespace zen diff --git a/src/zenserver/auth/oidc.h b/src/zenserver/auth/oidc.h new file mode 100644 index 000000000..f43ae3cd7 --- /dev/null +++ b/src/zenserver/auth/oidc.h @@ -0,0 +1,76 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/string.h> + +#include <vector> + +namespace zen { + +class OidcClient +{ +public: + struct Options + { + std::string_view BaseUrl; + std::string_view ClientId; + }; + + OidcClient(const Options& Options); + ~OidcClient() = default; + + OidcClient(const OidcClient&) = delete; + OidcClient& operator=(const OidcClient&) = delete; + + struct Result + { + std::string Reason; + bool Ok = false; + }; + + using InitResult = Result; + + InitResult Initialize(); + + struct RefreshTokenResult + { + std::string TokenType; + std::string AccessToken; + std::string RefreshToken; + std::string IdentityToken; + std::string Scope; + std::string Reason; + int64_t ExpiresInSeconds{}; + bool Ok = false; + }; + + RefreshTokenResult RefreshToken(std::string_view RefreshToken); + +private: + using StringArray = std::vector<std::string>; + + struct OpenIdConfiguration + { + std::string Issuer; + std::string AuthorizationEndpoint; + std::string TokenEndpoint; + std::string UserInfoEndpoint; + std::string RegistrationEndpoint; + std::string EndSessionEndpoint; + std::string DeviceAuthorizationEndpoint; + std::string JwksUri; + StringArray SupportedResponseTypes; + StringArray SupportedResponseModes; + StringArray SupportedGrantTypes; + StringArray SupportedScopes; + StringArray SupportedTokenEndpointAuthMethods; + StringArray SupportedClaims; + }; + + std::string m_BaseUrl; + std::string m_ClientId; + OpenIdConfiguration m_Config; +}; + +} // namespace zen diff --git a/src/zenserver/cache/cachetracking.cpp b/src/zenserver/cache/cachetracking.cpp new file mode 100644 index 000000000..9119e3122 --- /dev/null +++ b/src/zenserver/cache/cachetracking.cpp @@ -0,0 +1,376 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "cachetracking.h" + +#if ZEN_USE_CACHE_TRACKER + +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryvalue.h> +# include <zencore/endian.h> +# include <zencore/filesystem.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/string.h> + +# include <zencore/testing.h> +# include <zencore/testutils.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# pragma comment(lib, "Rpcrt4.lib") // RocksDB made me do this +# include <fmt/format.h> +# include <rocksdb/db.h> +# include <tsl/robin_map.h> +# include <tsl/robin_set.h> +# include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace rocksdb = ROCKSDB_NAMESPACE; + +static constinit auto Epoch = std::chrono::time_point<std::chrono::system_clock>{}; + +static uint64_t +GetCurrentCacheTimeStamp() +{ + auto Duration = std::chrono::system_clock::now() - Epoch; + uint64_t Millis = std::chrono::duration_cast<std::chrono::milliseconds>(Duration).count(); + + return Millis; +} + +struct CacheAccessSnapshot +{ +public: + void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) + { + BucketTracker* Tracker = GetBucket(std::string(BucketSegment)); + + Tracker->Track(HashKey); + } + + bool SerializeSnapshot(CbObjectWriter& Cbo) + { + bool Serialized = false; + RwLock::ExclusiveLockScope _(m_Lock); + + for (const auto& Kv : m_BucketMapping) + { + if (m_Buckets[Kv.second]->Size()) + { + Cbo.BeginArray(Kv.first); + m_Buckets[Kv.second]->SerializeSnapshotAndClear(Cbo); + Cbo.EndArray(); + Serialized = true; + } + } + + return Serialized; + } + +private: + struct BucketTracker + { + mutable RwLock Lock; + tsl::robin_set<IoHash> AccessedKeys; + + void Track(const IoHash& HashKey) + { + if (RwLock::SharedLockScope _(Lock); AccessedKeys.contains(HashKey)) + { + return; + } + + RwLock::ExclusiveLockScope _(Lock); + + AccessedKeys.insert(HashKey); + } + + void SerializeSnapshotAndClear(CbObjectWriter& Cbo) + { + RwLock::ExclusiveLockScope 
_(Lock); + + for (const IoHash& Hash : AccessedKeys) + { + Cbo.AddHash(Hash); + } + + AccessedKeys.clear(); + } + + size_t Size() const + { + RwLock::SharedLockScope _(Lock); + return AccessedKeys.size(); + } + }; + + BucketTracker* GetBucket(const std::string& BucketName) + { + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end()) + { + _.ReleaseNow(); + + return AddNewBucket(BucketName); + } + else + { + return m_Buckets[It->second].get(); + } + } + + BucketTracker* AddNewBucket(const std::string& BucketName) + { + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end()) + { + const uint32_t BucketIndex = gsl::narrow<uint32_t>(m_Buckets.size()); + m_Buckets.emplace_back(std::make_unique<BucketTracker>()); + m_BucketMapping[BucketName] = BucketIndex; + + return m_Buckets[BucketIndex].get(); + } + else + { + return m_Buckets[It->second].get(); + } + } + + RwLock m_Lock; + std::vector<std::unique_ptr<BucketTracker>> m_Buckets; + tsl::robin_map<std::string, uint32_t> m_BucketMapping; +}; + +struct ZenCacheTracker::Impl +{ + Impl(std::filesystem::path StateDirectory) + { + std::filesystem::path StatsDbPath{StateDirectory / ".zdb"}; + + std::string RocksdbPath = StatsDbPath.string(); + + ZEN_DEBUG("opening tracker db at '{}'", RocksdbPath); + + rocksdb::DB* Db = nullptr; + rocksdb::DBOptions Options; + Options.create_if_missing = true; + + std::vector<std::string> ExistingColumnFamilies; + rocksdb::Status Status = rocksdb::DB::ListColumnFamilies(Options, RocksdbPath, &ExistingColumnFamilies); + + std::vector<rocksdb::ColumnFamilyDescriptor> ColumnDescriptors; + + if (Status.IsPathNotFound()) + { + ColumnDescriptors.emplace_back(rocksdb::ColumnFamilyDescriptor{rocksdb::kDefaultColumnFamilyName, {}}); + } + else if (Status.ok()) + { + for (const std::string& Column : ExistingColumnFamilies) + { + rocksdb::ColumnFamilyDescriptor ColumnFamily; + 
ColumnFamily.name = Column; + ColumnDescriptors.push_back(ColumnFamily); + } + } + else + { + throw std::runtime_error(fmt::format("column family iteration failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str()); + } + + Status = rocksdb::DB::Open(Options, RocksdbPath, ColumnDescriptors, &m_RocksDbColumnHandles, &Db); + + if (!Status.ok()) + { + throw std::runtime_error(fmt::format("database open failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str()); + } + + m_RocksDb.reset(Db); + } + + ~Impl() + { + for (auto* Column : m_RocksDbColumnHandles) + { + delete Column; + } + + m_RocksDbColumnHandles.clear(); + } + + struct KeyStruct + { + uint64_t TimestampLittleEndian; + }; + + void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) { m_CurrentSnapshot.TrackAccess(BucketSegment, HashKey); } + + void SaveSnapshot() + { + CbObjectWriter Cbo; + + if (m_CurrentSnapshot.SerializeSnapshot(Cbo)) + { + IoBuffer SnapshotBuffer = Cbo.Save().GetBuffer().AsIoBuffer(); + + const KeyStruct Key{.TimestampLittleEndian = ToNetworkOrder(GetCurrentCacheTimeStamp())}; + rocksdb::Slice KeySlice{(const char*)&Key, sizeof Key}; + rocksdb::Slice ValueSlice{(char*)SnapshotBuffer.Data(), SnapshotBuffer.Size()}; + + rocksdb::WriteOptions Wo; + m_RocksDb->Put(Wo, KeySlice, ValueSlice); + } + } + + void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback) + { + rocksdb::ManagedSnapshot Snap(m_RocksDb.get()); + + rocksdb::ReadOptions Ro; + Ro.snapshot = Snap.snapshot(); + + std::unique_ptr<rocksdb::Iterator> It{m_RocksDb->NewIterator(Ro)}; + + const KeyStruct ZeroKey{.TimestampLittleEndian = 0}; + rocksdb::Slice ZeroKeySlice{(const char*)&ZeroKey, sizeof ZeroKey}; + + It->Seek(ZeroKeySlice); + + while (It->Valid()) + { + rocksdb::Slice KeySlice = It->key(); + rocksdb::Slice ValueSlice = It->value(); + + if (KeySlice.size() == sizeof(KeyStruct)) + { + IoBuffer ValueBuffer(IoBuffer::Wrap, ValueSlice.data(), ValueSlice.size()); + + 
CbObject Value = LoadCompactBinaryObject(ValueBuffer); + + uint64_t Key = FromNetworkOrder(*reinterpret_cast<const uint64_t*>(KeySlice.data())); + + Callback(Key, Value); + } + + It->Next(); + } + } + + std::unique_ptr<rocksdb::DB> m_RocksDb; + std::vector<rocksdb::ColumnFamilyHandle*> m_RocksDbColumnHandles; + CacheAccessSnapshot m_CurrentSnapshot; +}; + +ZenCacheTracker::ZenCacheTracker(std::filesystem::path StateDirectory) : m_Impl(new Impl(StateDirectory)) +{ +} + +ZenCacheTracker::~ZenCacheTracker() +{ + delete m_Impl; +} + +void +ZenCacheTracker::TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) +{ + m_Impl->TrackAccess(BucketSegment, HashKey); +} + +void +ZenCacheTracker::SaveSnapshot() +{ + m_Impl->SaveSnapshot(); +} + +void +ZenCacheTracker::IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback) +{ + m_Impl->IterateSnapshots(std::move(Callback)); +} + +# if ZEN_WITH_TESTS + +TEST_CASE("z$.tracker") +{ + using namespace std::literals; + + const uint64_t t0 = GetCurrentCacheTimeStamp(); + + ScopedTemporaryDirectory TempDir; + + ZenCacheTracker Zcs(TempDir.Path()); + + tsl::robin_set<IoHash> KeyHashes; + + for (int i = 0; i < 10000; ++i) + { + IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i); + + KeyHashes.insert(KeyHash); + + Zcs.TrackAccess("foo"sv, KeyHash); + } + + for (int i = 0; i < 10000; ++i) + { + IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i); + + Zcs.TrackAccess("foo"sv, KeyHash); + } + + Zcs.SaveSnapshot(); + + for (int n = 0; n < 10; ++n) + { + for (int i = 0; i < 1000; ++i) + { + const int Index = i + n * 1000; + IoHash KeyHash = IoHash::HashBuffer(&Index, sizeof Index); + + Zcs.TrackAccess("foo"sv, KeyHash); + } + + Zcs.SaveSnapshot(); + } + + Zcs.SaveSnapshot(); + + const uint64_t t1 = GetCurrentCacheTimeStamp(); + + int SnapshotCount = 0; + + Zcs.IterateSnapshots([&](uint64_t TimeStamp, CbObject Snapshot) { + CHECK(TimeStamp >= t0); + CHECK(TimeStamp <= t1); + + for (auto& Field : 
Snapshot) + { + CHECK_EQ(Field.GetName(), "foo"sv); + + const CbArray& Array = Field.AsArray(); + + for (const auto& Element : Array) + { + CHECK(KeyHashes.contains(Element.GetValue().AsHash())); + } + } + + ++SnapshotCount; + }); + + CHECK_EQ(SnapshotCount, 11); +} + +# endif + +void +cachetracker_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_USE_CACHE_TRACKER diff --git a/src/zenserver/cache/cachetracking.h b/src/zenserver/cache/cachetracking.h new file mode 100644 index 000000000..fdfe1a4c7 --- /dev/null +++ b/src/zenserver/cache/cachetracking.h @@ -0,0 +1,41 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> + +#include <stdint.h> +#include <filesystem> +#include <functional> + +namespace zen { + +#define ZEN_USE_CACHE_TRACKER 0 +#if ZEN_USE_CACHE_TRACKER + +class CbObject; + +/** + */ + +class ZenCacheTracker +{ +public: + ZenCacheTracker(std::filesystem::path StateDirectory); + ~ZenCacheTracker(); + + void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey); + void SaveSnapshot(); + void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback); + +private: + struct Impl; + + Impl* m_Impl = nullptr; +}; + +void cachetracker_forcelink(); + +#endif // ZEN_USE_CACHE_TRACKER + +} // namespace zen diff --git a/src/zenserver/cache/structuredcache.cpp b/src/zenserver/cache/structuredcache.cpp new file mode 100644 index 000000000..90e905bf6 --- /dev/null +++ b/src/zenserver/cache/structuredcache.cpp @@ -0,0 +1,3159 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "structuredcache.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/compress.h> +#include <zencore/enumflags.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zencore/workthreadpool.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/httpshared.h> +#include <zenutil/cache/cache.h> +#include <zenutil/cache/rpcrecording.h> + +#include "monitoring/httpstats.h" +#include "structuredcachestore.h" +#include "upstream/jupiter.h" +#include "upstream/upstreamcache.h" +#include "upstream/zen.h" +#include "zenstore/cidstore.h" +#include "zenstore/scrubcontext.h" + +#include <algorithm> +#include <atomic> +#include <filesystem> +#include <queue> +#include <thread> + +#include <cpr/cpr.h> +#include <gsl/gsl-lite.hpp> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +CachePolicy +ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams) +{ + std::string_view PolicyText = QueryParams.GetValue("Policy"sv); + return !PolicyText.empty() ? zen::ParseCachePolicy(PolicyText) : CachePolicy::Default; +} + +CacheRecordPolicy +LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default) +{ + OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object); + return Policy ? 
std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy); +} + +struct AttachmentCount +{ + uint32_t New = 0; + uint32_t Valid = 0; + uint32_t Invalid = 0; + uint32_t Total = 0; +}; + +struct PutRequestData +{ + std::string Namespace; + CacheKey Key; + CbObjectView RecordObject; + CacheRecordPolicy Policy; +}; + +namespace { + static constinit std::string_view HttpZCacheRPCPrefix = "$rpc"sv; + static constinit std::string_view HttpZCacheUtilStartRecording = "exec$/start-recording"sv; + static constinit std::string_view HttpZCacheUtilStopRecording = "exec$/stop-recording"sv; + static constinit std::string_view HttpZCacheUtilReplayRecording = "exec$/replay-recording"sv; + static constinit std::string_view HttpZCacheDetailsPrefix = "details$"sv; + + struct HttpRequestData + { + std::optional<std::string> Namespace; + std::optional<std::string> Bucket; + std::optional<IoHash> HashKey; + std::optional<IoHash> ValueContentId; + }; + + constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + + std::optional<std::string> GetValidNamespaceName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Namespace is invalid, empty namespace is not allowed"); + return {}; + } + + if (Name.length() > 64) + { + ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet)) + { + ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<std::string> GetValidBucketName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed"); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet)) + { + ZEN_WARN("Bucket name '{}' is invalid, invalid 
characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<IoHash> GetValidIoHash(std::string_view Hash) + { + if (Hash.length() != IoHash::StringLength) + { + return {}; + } + + IoHash KeyHash; + if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash)) + { + return {}; + } + return KeyHash; + } + + bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data) + { + std::vector<std::string_view> Tokens; + uint32_t TokenCount = ForEachStrTok(Key, '/', [&](const std::string_view& Token) { + Tokens.push_back(Token); + return true; + }); + + switch (TokenCount) + { + case 0: + return true; + case 1: + Data.Namespace = GetValidNamespaceName(Tokens[0]); + return Data.Namespace.has_value(); + case 2: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + Data.Namespace = ZenCacheStore::DefaultNamespace; + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + return true; + } + case 3: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key/valueid request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + Data.ValueContentId = GetValidIoHash(Tokens[2]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + Data.Namespace = ZenCacheStore::DefaultNamespace; + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if 
(!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey) + { + return false; + } + return true; + } + case 4: + { + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey.has_value()) + { + return false; + } + + Data.ValueContentId = GetValidIoHash(Tokens[3]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + return true; + } + default: + return false; + } + } + + std::optional<std::string> GetRpcRequestNamespace(const CbObjectView Params) + { + CbFieldView NamespaceField = Params["Namespace"sv]; + if (!NamespaceField) + { + return std::string(ZenCacheStore::DefaultNamespace); + } + + if (NamespaceField.HasError()) + { + return {}; + } + if (!NamespaceField.IsString()) + { + return {}; + } + return GetValidNamespaceName(NamespaceField.AsString()); + } + + bool GetRpcRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key) + { + CbFieldView BucketField = KeyView["Bucket"sv]; + if (BucketField.HasError()) + { + return false; + } + if (!BucketField.IsString()) + { + return false; + } + std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString()); + if (!Bucket.has_value()) + { + return false; + } + CbFieldView HashField = KeyView["Hash"sv]; + if (HashField.HasError()) + { + return false; + } + if (!HashField.IsHash()) + { + return false; + } + IoHash Hash = HashField.AsHash(); + Key = CacheKey::Create(*Bucket, Hash); + return true; + } + +} // namespace + +////////////////////////////////////////////////////////////////////////// + +HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCacheStore, + CidStore& InCidStore, + HttpStatsService& StatsService, + HttpStatusService& StatusService, + UpstreamCache& UpstreamCache) +: 
m_Log(logging::Get("cache")) +, m_CacheStore(InCacheStore) +, m_StatsService(StatsService) +, m_StatusService(StatusService) +, m_CidStore(InCidStore) +, m_UpstreamCache(UpstreamCache) +{ + m_StatsService.RegisterHandler("z$", *this); + m_StatusService.RegisterHandler("z$", *this); +} + +HttpStructuredCacheService::~HttpStructuredCacheService() +{ + ZEN_INFO("closing structured cache"); + m_RequestRecorder.reset(); + + m_StatsService.UnregisterHandler("z$", *this); + m_StatusService.UnregisterHandler("z$", *this); +} + +const char* +HttpStructuredCacheService::BaseUri() const +{ + return "/z$/"; +} + +void +HttpStructuredCacheService::Flush() +{ + m_CacheStore.Flush(); +} + +void +HttpStructuredCacheService::Scrub(ScrubContext& Ctx) +{ + if (m_LastScrubTime == Ctx.ScrubTimestamp()) + { + return; + } + + m_LastScrubTime = Ctx.ScrubTimestamp(); + + m_CidStore.Scrub(Ctx); + m_CacheStore.Scrub(Ctx); +} + +void +HttpStructuredCacheService::HandleDetailsRequest(HttpServerRequest& Request) +{ + std::string_view Key = Request.RelativeUri(); + std::vector<std::string> Tokens; + uint32_t TokenCount = ForEachStrTok(Key, '/', [&Tokens](std::string_view Token) { + Tokens.push_back(std::string(Token)); + return true; + }); + std::string FilterNamespace; + std::string FilterBucket; + std::string FilterValue; + switch (TokenCount) + { + case 1: + break; + case 2: + { + FilterNamespace = Tokens[1]; + if (FilterNamespace.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + } + break; + case 3: + { + FilterNamespace = Tokens[1]; + if (FilterNamespace.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + FilterBucket = Tokens[2]; + if (FilterBucket.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + } + break; + case 4: + { + FilterNamespace = Tokens[1]; + if (FilterNamespace.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // 
invalid URL + } + FilterBucket = Tokens[2]; + if (FilterBucket.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + FilterValue = Tokens[3]; + if (FilterValue.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + } + break; + default: + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + + HttpServerRequest::QueryParams Params = Request.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + std::chrono::seconds NowSeconds = std::chrono::duration_cast<std::chrono::seconds>(GcClock::Now().time_since_epoch()); + CacheValueDetails ValueDetails = m_CacheStore.GetValueDetails(FilterNamespace, FilterBucket, FilterValue); + + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + if (AttachmentDetails) + { + CSVWriter << "Namespace, Bucket, Key, Cid, Size"; + } + else if (Details) + { + CSVWriter << "Namespace, Bucket, Key, Size, RawSize, RawHash, ContentType, Age, AttachmentsCount, AttachmentsSize"; + } + else + { + CSVWriter << "Namespace, Bucket, Key"; + } + for (const auto& NamespaceIt : ValueDetails.Namespaces) + { + const std::string& Namespace = NamespaceIt.first; + for (const auto& BucketIt : NamespaceIt.second.Buckets) + { + const std::string& Bucket = BucketIt.first; + for (const auto& ValueIt : BucketIt.second.Values) + { + if (AttachmentDetails) + { + for (const IoHash& Hash : ValueIt.second.Attachments) + { + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + CSVWriter << "\r\n" + << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << Hash.ToHexString() + << ", " << gsl::narrow<uint64_t>(Payload.GetSize()); + } + } + else if (Details) + { + std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>( + 
GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch()); + CSVWriter << "\r\n" + << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << ValueIt.second.Size << "," + << ValueIt.second.RawSize << "," << ValueIt.second.RawHash.ToHexString() << ", " + << ToString(ValueIt.second.ContentType) << ", " << (NowSeconds.count() - LastAccessedSeconds.count()) + << ", " << gsl::narrow<uint64_t>(ValueIt.second.Attachments.size()); + size_t AttachmentsSize = 0; + for (const IoHash& Hash : ValueIt.second.Attachments) + { + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + AttachmentsSize += Payload.GetSize(); + } + CSVWriter << ", " << gsl::narrow<uint64_t>(AttachmentsSize); + } + else + { + CSVWriter << "\r\n" << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString(); + } + } + } + } + return Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("namespaces"); + { + for (const auto& NamespaceIt : ValueDetails.Namespaces) + { + const std::string& Namespace = NamespaceIt.first; + Cbo.BeginObject(); + { + Cbo.AddString("name", Namespace); + Cbo.BeginArray("buckets"); + { + for (const auto& BucketIt : NamespaceIt.second.Buckets) + { + const std::string& Bucket = BucketIt.first; + Cbo.BeginObject(); + { + Cbo.AddString("name", Bucket); + Cbo.BeginArray("values"); + { + for (const auto& ValueIt : BucketIt.second.Values) + { + std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>( + GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch()); + Cbo.BeginObject(); + { + Cbo.AddHash("key", ValueIt.first); + if (Details) + { + Cbo.AddInteger("size", ValueIt.second.Size); + if (ValueIt.second.Size > 0 && ValueIt.second.RawSize != 0 && + ValueIt.second.RawSize != ValueIt.second.Size) + { + Cbo.AddInteger("rawsize", ValueIt.second.RawSize); + Cbo.AddHash("rawhash", ValueIt.second.RawHash); + } + 
Cbo.AddString("contenttype", ToString(ValueIt.second.ContentType)); + Cbo.AddInteger("age", NowSeconds.count() - LastAccessedSeconds.count()); + if (ValueIt.second.Attachments.size() > 0) + { + if (AttachmentDetails) + { + Cbo.BeginArray("attachments"); + { + for (const IoHash& Hash : ValueIt.second.Attachments) + { + Cbo.BeginObject(); + Cbo.AddHash("cid", Hash); + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + Cbo.AddInteger("size", gsl::narrow<uint64_t>(Payload.GetSize())); + Cbo.EndObject(); + } + } + Cbo.EndArray(); + } + else + { + Cbo.AddInteger("attachmentcount", + gsl::narrow<uint64_t>(ValueIt.second.Attachments.size())); + size_t AttachmentsSize = 0; + for (const IoHash& Hash : ValueIt.second.Attachments) + { + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + AttachmentsSize += Payload.GetSize(); + } + Cbo.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize)); + } + } + } + } + Cbo.EndObject(); + } + } + Cbo.EndArray(); + } + Cbo.EndObject(); + } + } + Cbo.EndArray(); + } + Cbo.EndObject(); + } + } + Cbo.EndArray(); + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } +} + +void +HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) +{ + metrics::OperationTiming::Scope $(m_HttpRequests); + + std::string_view Key = Request.RelativeUri(); + if (Key == HttpZCacheRPCPrefix) + { + return HandleRpcRequest(Request); + } + + if (Key == HttpZCacheUtilStartRecording) + { + m_RequestRecorder.reset(); + HttpServerRequest::QueryParams Params = Request.GetQueryParams(); + std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path"))); + m_RequestRecorder = cache::MakeDiskRequestRecorder(RecordPath); + Request.WriteResponse(HttpResponseCode::OK); + return; + } + if (Key == HttpZCacheUtilStopRecording) + { + m_RequestRecorder.reset(); + Request.WriteResponse(HttpResponseCode::OK); + return; + } + if (Key == HttpZCacheUtilReplayRecording) + { + m_RequestRecorder.reset(); + 
HttpServerRequest::QueryParams Params = Request.GetQueryParams(); + std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path"))); + uint32_t ThreadCount = std::thread::hardware_concurrency(); + if (auto Param = Params.GetValue("thread_count"); Param.empty() == false) + { + if (auto Value = ParseInt<uint64_t>(Param)) + { + ThreadCount = gsl::narrow<uint32_t>(Value.value()); + } + } + std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(RecordPath, false)); + ReplayRequestRecorder(*Replayer, ThreadCount < 1 ? 1 : ThreadCount); + Request.WriteResponse(HttpResponseCode::OK); + return; + } + if (Key.starts_with(HttpZCacheDetailsPrefix)) + { + HandleDetailsRequest(Request); + return; + } + + HttpRequestData RequestData; + if (!HttpRequestParseRelativeUri(Key, RequestData)) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + + if (RequestData.ValueContentId.has_value()) + { + ZEN_ASSERT(RequestData.Namespace.has_value()); + ZEN_ASSERT(RequestData.Bucket.has_value()); + ZEN_ASSERT(RequestData.HashKey.has_value()); + CacheRef Ref = {.Namespace = RequestData.Namespace.value(), + .BucketSegment = RequestData.Bucket.value(), + .HashKey = RequestData.HashKey.value(), + .ValueContentId = RequestData.ValueContentId.value()}; + return HandleCacheChunkRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams())); + } + + if (RequestData.HashKey.has_value()) + { + ZEN_ASSERT(RequestData.Namespace.has_value()); + ZEN_ASSERT(RequestData.Bucket.has_value()); + CacheRef Ref = {.Namespace = RequestData.Namespace.value(), + .BucketSegment = RequestData.Bucket.value(), + .HashKey = RequestData.HashKey.value(), + .ValueContentId = IoHash::Zero}; + return HandleCacheRecordRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams())); + } + + if (RequestData.Bucket.has_value()) + { + ZEN_ASSERT(RequestData.Namespace.has_value()); + return HandleCacheBucketRequest(Request, 
RequestData.Namespace.value(), RequestData.Bucket.value()); + } + + if (RequestData.Namespace.has_value()) + { + return HandleCacheNamespaceRequest(Request, RequestData.Namespace.value()); + } + return HandleCacheRequest(Request); +} + +void +HttpStructuredCacheService::HandleCacheRequest(HttpServerRequest& Request) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + ZenCacheStore::Info Info = m_CacheStore.GetInfo(); + + CbObjectWriter ResponseWriter; + + ResponseWriter.BeginObject("Configuration"); + { + ExtendableStringBuilder<128> BasePathString; + BasePathString << Info.Config.BasePath.u8string(); + ResponseWriter.AddString("BasePath"sv, BasePathString.ToView()); + ResponseWriter.AddBool("AllowAutomaticCreationOfNamespaces", Info.Config.AllowAutomaticCreationOfNamespaces); + } + ResponseWriter.EndObject(); + + std::sort(begin(Info.NamespaceNames), end(Info.NamespaceNames), [](std::string_view L, std::string_view R) { + return L.compare(R) < 0; + }); + ResponseWriter.BeginArray("Namespaces"); + for (const std::string& NamespaceName : Info.NamespaceNames) + { + ResponseWriter.AddString(NamespaceName); + } + ResponseWriter.EndArray(); + ResponseWriter.BeginObject("StorageSize"); + { + ResponseWriter.AddInteger("DiskSize", Info.StorageSize.DiskSize); + ResponseWriter.AddInteger("MemorySize", Info.StorageSize.MemorySize); + } + + ResponseWriter.EndObject(); + + ResponseWriter.AddInteger("DiskEntryCount", Info.DiskEntryCount); + ResponseWriter.AddInteger("MemoryEntryCount", Info.MemoryEntryCount); + + return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + break; + } +} + +void +HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view NamespaceName) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + std::optional<ZenCacheNamespace::Info> Info = m_CacheStore.GetNamespaceInfo(NamespaceName); + if (!Info.has_value()) + { 
+ return Request.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter ResponseWriter; + + ResponseWriter.BeginObject("Configuration"); + { + ExtendableStringBuilder<128> BasePathString; + BasePathString << Info->Config.RootDir.u8string(); + ResponseWriter.AddString("RootDir"sv, BasePathString.ToView()); + ResponseWriter.AddInteger("DiskLayerThreshold"sv, Info->Config.DiskLayerThreshold); + } + ResponseWriter.EndObject(); + + std::sort(begin(Info->BucketNames), end(Info->BucketNames), [](std::string_view L, std::string_view R) { + return L.compare(R) < 0; + }); + + ResponseWriter.BeginArray("Buckets"sv); + for (const std::string& BucketName : Info->BucketNames) + { + ResponseWriter.AddString(BucketName); + } + ResponseWriter.EndArray(); + + ResponseWriter.BeginObject("StorageSize"sv); + { + ResponseWriter.AddInteger("DiskSize"sv, Info->DiskLayerInfo.TotalSize); + ResponseWriter.AddInteger("MemorySize"sv, Info->MemoryLayerInfo.TotalSize); + } + ResponseWriter.EndObject(); + + ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount); + ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount); + + return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + break; + + case HttpVerb::kDelete: + // Drop namespace + { + if (m_CacheStore.DropNamespace(NamespaceName)) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + else + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + } + break; + + default: + break; + } +} + +void +HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, + std::string_view NamespaceName, + std::string_view BucketName) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + std::optional<ZenCacheNamespace::BucketInfo> Info = m_CacheStore.GetBucketInfo(NamespaceName, BucketName); + if (!Info.has_value()) + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + 
CbObjectWriter ResponseWriter; + + ResponseWriter.BeginObject("StorageSize"); + { + ResponseWriter.AddInteger("DiskSize", Info->DiskLayerInfo.TotalSize); + ResponseWriter.AddInteger("MemorySize", Info->MemoryLayerInfo.TotalSize); + } + ResponseWriter.EndObject(); + + ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount); + ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount); + + return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + break; + + case HttpVerb::kDelete: + // Drop bucket + { + if (m_CacheStore.DropBucket(NamespaceName, BucketName)) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + else + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + } + break; + + default: + break; + } +} + +void +HttpStructuredCacheService::HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + HandleGetCacheRecord(Request, Ref, PolicyFromUrl); + } + break; + + case HttpVerb::kPut: + HandlePutCacheRecord(Request, Ref, PolicyFromUrl); + break; + default: + break; + } +} + +void +HttpStructuredCacheService::HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + const ZenContentType AcceptType = Request.AcceptContentType(); + const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData); + const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord); + + bool Success = false; + ZenCacheValue ClientResultValue; + if (!EnumHasAnyFlags(PolicyFromUrl, CachePolicy::Query)) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + + Stopwatch Timer; + + if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal) && + m_CacheStore.Get(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue)) + { + Success = true; + ZenContentType ContentType = 
ClientResultValue.Value.GetContentType(); + + if (AcceptType == ZenContentType::kCbPackage) + { + if (ContentType == ZenContentType::kCbObject) + { + CbPackage Package; + uint32_t MissingCount = 0; + + CbObjectView CacheRecord(ClientResultValue.Value.Data()); + CacheRecord.IterateAttachments([this, &MissingCount, &Package, SkipData](CbFieldView AttachmentHash) { + if (SkipData) + { + if (!m_CidStore.ContainsChunk(AttachmentHash.AsHash())) + { + MissingCount++; + } + } + else + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(AttachmentHash.AsHash())) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + Package.AddAttachment(CbAttachment(Compressed, AttachmentHash.AsHash())); + } + else + { + MissingCount++; + } + } + }); + + Success = MissingCount == 0 || PartialRecord; + + if (Success) + { + Package.SetObject(LoadCompactBinaryObject(ClientResultValue.Value)); + + BinaryWriter MemStream; + Package.Save(MemStream); + + ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + ClientResultValue.Value.SetContentType(HttpContentType::kCbPackage); + } + } + else + { + Success = false; + } + } + else if (AcceptType != ClientResultValue.Value.GetContentType() && AcceptType != ZenContentType::kUnknownContentType && + AcceptType != ZenContentType::kBinary) + { + Success = false; + } + } + + if (Success) + { + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (LOCAL) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(ClientResultValue.Value.Size()), + ToString(ClientResultValue.Value.GetContentType()), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + m_CacheStats.HitCount++; + if (SkipData && AcceptType != ZenContentType::kCbPackage && AcceptType != ZenContentType::kCbObject) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + else + { + // kCbPackage handled SkipData when constructing the ClientResultValue, kcbObject ignores SkipData + return 
Request.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value); + } + } + else if (!EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryRemote)) + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + // Issue upstream query asynchronously in order to keep requests flowing without + // hogging I/O servicing threads with blocking work + + uint64_t LocalElapsedTimeUs = Timer.GetElapsedTimeUs(); + + Request.WriteResponseAsync([this, AcceptType, PolicyFromUrl, Ref, LocalElapsedTimeUs](HttpServerRequest& AsyncRequest) { + Stopwatch Timer; + bool Success = false; + const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord); + const bool QueryLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal); + const bool StoreLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreLocal); + const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData); + ZenCacheValue ClientResultValue; + + metrics::OperationTiming::Scope $(m_UpstreamGetRequestTiming); + + if (GetUpstreamCacheSingleResult UpstreamResult = + m_UpstreamCache.GetCacheRecord(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, AcceptType); + UpstreamResult.Status.Success) + { + Success = true; + + ClientResultValue.Value = UpstreamResult.Value; + ClientResultValue.Value.SetContentType(AcceptType); + + if (AcceptType == ZenContentType::kBinary || AcceptType == ZenContentType::kCbObject) + { + if (AcceptType == ZenContentType::kCbObject) + { + const CbValidateError ValidationResult = ValidateCompactBinary(UpstreamResult.Value, CbValidateMode::All); + if (ValidationResult != CbValidateError::None) + { + Success = false; + ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid compact binary object from upstream", 
+ Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType)); + } + + // We do not do anything to the returned object for SkipData, only package attachments are cut when skipping data + } + + if (Success && StoreLocal) + { + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue); + } + } + else if (AcceptType == ZenContentType::kCbPackage) + { + CbPackage Package; + if (Package.TryLoad(ClientResultValue.Value)) + { + CbObject CacheRecord = Package.GetObject(); + AttachmentCount Count; + size_t NumAttachments = Package.GetAttachments().size(); + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + AttachmentsToStoreLocally.reserve(NumAttachments); + + CacheRecord.IterateAttachments( + [this, &Package, &Ref, &AttachmentsToStoreLocally, &Count, QueryLocal, StoreLocal, SkipData](CbFieldView HashView) { + IoHash Hash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(Hash)) + { + if (Attachment->IsCompressedBinary()) + { + if (StoreLocal) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + } + Count.Valid++; + } + else + { + ZEN_WARN("Uncompressed value '{}' from upstream cache record '{}/{}'", + Hash, + Ref.BucketSegment, + Ref.HashKey); + Count.Invalid++; + } + } + else if (QueryLocal) + { + if (SkipData) + { + if (m_CidStore.ContainsChunk(Hash)) + { + Count.Valid++; + } + } + else if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Hash)) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + if (Compressed) + { + Package.AddAttachment(CbAttachment(Compressed, Hash)); + Count.Valid++; + } + else + { + ZEN_WARN("Uncompressed value '{}' stored in local cache '{}/{}'", + Hash, + Ref.BucketSegment, + Ref.HashKey); + Count.Invalid++; + } + } + } + Count.Total++; + }); + + if ((Count.Valid == Count.Total) || PartialRecord) + { + ZenCacheValue CacheValue; + CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer(); + 
CacheValue.Value.SetContentType(ZenContentType::kCbObject); + + if (StoreLocal) + { + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue); + } + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } + } + + BinaryWriter MemStream; + if (SkipData) + { + // Save a package containing only the object. + CbPackage(Package.GetObject()).Save(MemStream); + } + else + { + Package.Save(MemStream); + } + + ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + ClientResultValue.Value.SetContentType(ZenContentType::kCbPackage); + } + else + { + Success = false; + ZEN_WARN("Get - '{}/{}' '{}' FAILED, attachments missing in upstream package", + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType)); + } + } + else + { + Success = false; + ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid upstream package", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType)); + } + } + } + + if (Success) + { + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (UPSTREAM) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(ClientResultValue.Value.Size()), + ToString(ClientResultValue.Value.GetContentType()), + NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000)); + + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + + if (SkipData && AcceptType == ZenContentType::kBinary) + { + AsyncRequest.WriteResponse(HttpResponseCode::OK); + } + else + { + // Other methods modify ClientResultValue to a version that has skipped the data but keeps the Object and optionally + // metadata. 
+ AsyncRequest.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value); + } + } + else + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType), + NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000)); + m_CacheStats.MissCount++; + AsyncRequest.WriteResponse(HttpResponseCode::NotFound); + } + }); +} + +void +HttpStructuredCacheService::HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + IoBuffer Body = Request.ReadPayload(); + + if (!Body || Body.Size() == 0) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + const HttpContentType ContentType = Request.RequestContentType(); + + Body.SetContentType(ContentType); + + Stopwatch Timer; + + if (ContentType == HttpContentType::kBinary || ContentType == HttpContentType::kCompressedBinary) + { + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = Body.GetSize(); + if (ContentType == HttpContentType::kCompressedBinary) + { + if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize)) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Payload is not a valid compressed binary"sv); + } + } + else + { + RawHash = IoHash::HashBuffer(SharedBuffer(Body)); + } + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body, .RawSize = RawSize, .RawHash = RawHash}); + + if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreRemote)) + { + m_UpstreamCache.EnqueueUpstream({.Type = ContentType, .Namespace = Ref.Namespace, .Key = {Ref.BucketSegment, Ref.HashKey}}); + } + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.Size()), + ToString(ContentType), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + Request.WriteResponse(HttpResponseCode::Created); + } + else if 
(ContentType == HttpContentType::kCbObject) + { + const CbValidateError ValidationResult = ValidateCompactBinary(MemoryView(Body.GetData(), Body.GetSize()), CbValidateMode::All); + + if (ValidationResult != CbValidateError::None) + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid compact binary", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(ContentType)); + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Compact binary validation failed"sv); + } + + Body.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body}); + + CbObjectView CacheRecord(Body.Data()); + std::vector<IoHash> ValidAttachments; + int32_t TotalCount = 0; + + CacheRecord.IterateAttachments([this, &TotalCount, &ValidAttachments](CbFieldView AttachmentHash) { + const IoHash Hash = AttachmentHash.AsHash(); + if (m_CidStore.ContainsChunk(Hash)) + { + ValidAttachments.emplace_back(Hash); + } + TotalCount++; + }); + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' attachments '{}/{}' (valid/total) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.Size()), + ToString(ContentType), + TotalCount, + ValidAttachments.size(), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = TotalCount != static_cast<int32_t>(ValidAttachments.size()); + + CachePolicy Policy = PolicyFromUrl; + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbObject, + .Namespace = Ref.Namespace, + .Key = {Ref.BucketSegment, Ref.HashKey}, + .ValueContentIds = std::move(ValidAttachments)}); + } + + Request.WriteResponse(HttpResponseCode::Created); + } + else if (ContentType == HttpContentType::kCbPackage) + { + CbPackage Package; + + if (!Package.TryLoad(Body)) + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid package", + Ref.Namespace, + 
Ref.BucketSegment, + Ref.HashKey, + ToString(ContentType)); + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"sv); + } + CachePolicy Policy = PolicyFromUrl; + + CbObject CacheRecord = Package.GetObject(); + + AttachmentCount Count; + size_t NumAttachments = Package.GetAttachments().size(); + std::vector<IoHash> ValidAttachments; + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + ValidAttachments.reserve(NumAttachments); + AttachmentsToStoreLocally.reserve(NumAttachments); + + CacheRecord.IterateAttachments([this, &Ref, &Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count](CbFieldView HashView) { + const IoHash Hash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(Hash)) + { + if (Attachment->IsCompressedBinary()) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + ValidAttachments.emplace_back(Hash); + Count.Valid++; + } + else + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(HttpContentType::kCbPackage), + Hash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(Hash)) + { + ValidAttachments.emplace_back(Hash); + Count.Valid++; + } + Count.Total++; + }); + + if (Count.Invalid > 0) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid attachment(s)"sv); + } + + ZenCacheValue CacheValue; + CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer(); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue); + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } 
+ } + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}', attachments '{}/{}/{}' (new/valid/total) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.GetSize()), + ToString(ContentType), + Count.New, + Count.Valid, + Count.Total, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage, + .Namespace = Ref.Namespace, + .Key = {Ref.BucketSegment, Ref.HashKey}, + .ValueContentIds = std::move(ValidAttachments)}); + } + + Request.WriteResponse(HttpResponseCode::Created); + } + else + { + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Content-Type invalid"sv); + } +} + +void +HttpStructuredCacheService::HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + HandleGetCacheChunk(Request, Ref, PolicyFromUrl); + break; + case HttpVerb::kPut: + HandlePutCacheChunk(Request, Ref, PolicyFromUrl); + break; + default: + break; + } +} + +void +HttpStructuredCacheService::HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + Stopwatch Timer; + + IoBuffer Value = m_CidStore.FindChunkByCid(Ref.ValueContentId); + const UpstreamEndpointInfo* Source = nullptr; + CachePolicy Policy = PolicyFromUrl; + { + const bool QueryUpstream = !Value && EnumHasAllFlags(Policy, CachePolicy::QueryRemote); + + if (QueryUpstream) + { + if (GetUpstreamCacheSingleResult UpstreamResult = + m_UpstreamCache.GetCacheChunk(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, Ref.ValueContentId); + UpstreamResult.Status.Success) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(UpstreamResult.Value, RawHash, RawSize)) + { + if 
(RawHash == Ref.ValueContentId) + { + m_CidStore.AddChunk(UpstreamResult.Value, RawHash); + Source = UpstreamResult.Source; + } + else + { + ZEN_WARN("got missmatching upstream cache value"); + } + } + else + { + ZEN_WARN("got uncompressed upstream cache value"); + } + } + } + } + + if (!Value) + { + ZEN_DEBUG("GETCACHECHUNK MISS - '{}/{}/{}/{}' '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + Ref.ValueContentId, + ToString(Request.AcceptContentType()), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + ZEN_DEBUG("GETCACHECHUNK HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + Ref.ValueContentId, + NiceBytes(Value.Size()), + ToString(Value.GetContentType()), + Source ? Source->Url : "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + m_CacheStats.HitCount++; + if (Source) + { + m_CacheStats.UpstreamHitCount++; + } + + if (EnumHasAllFlags(Policy, CachePolicy::SkipData)) + { + Request.WriteResponse(HttpResponseCode::OK); + } + else + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); + } +} + +void +HttpStructuredCacheService::HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + // Note: Individual cacherecord values are not propagated upstream until a valid cache record has been stored + ZEN_UNUSED(PolicyFromUrl); + + Stopwatch Timer; + + IoBuffer Body = Request.ReadPayload(); + + if (!Body || Body.Size() == 0) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + Body.SetContentType(Request.RequestContentType()); + + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize)) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Attachments must be compressed"sv); + } + + if (RawHash != Ref.ValueContentId) + { 
+ return Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "ValueContentId does not match attachment hash"sv); + } + + CidStore::InsertResult Result = m_CidStore.AddChunk(Body, RawHash); + + ZEN_DEBUG("PUTCACHECHUNK - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + Ref.ValueContentId, + NiceBytes(Body.Size()), + ToString(Body.GetContentType()), + Result.New ? "NEW" : "OLD", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const HttpResponseCode ResponseCode = Result.New ? HttpResponseCode::Created : HttpResponseCode::OK; + + Request.WriteResponse(ResponseCode); +} + +CbPackage +HttpStructuredCacheService::HandleRpcRequest(const ZenContentType ContentType, + IoBuffer&& Body, + uint32_t& OutAcceptMagic, + RpcAcceptOptions& OutAcceptFlags, + int& OutTargetProcessId) +{ + CbPackage Package; + CbObjectView Object; + CbObject ObjectBuffer; + if (ContentType == ZenContentType::kCbObject) + { + ObjectBuffer = LoadCompactBinaryObject(std::move(Body)); + Object = ObjectBuffer; + } + else + { + Package = ParsePackageMessage(Body); + Object = Package.GetObject(); + } + OutAcceptMagic = Object["Accept"sv].AsUInt32(); + OutAcceptFlags = static_cast<RpcAcceptOptions>(Object["AcceptFlags"sv].AsUInt16(0u)); + OutTargetProcessId = Object["Pid"sv].AsInt32(0); + + const std::string_view Method = Object["Method"sv].AsString(); + + if (Method == "PutCacheRecords"sv) + { + return HandleRpcPutCacheRecords(Package); + } + else if (Method == "GetCacheRecords"sv) + { + return HandleRpcGetCacheRecords(Object); + } + else if (Method == "PutCacheValues"sv) + { + return HandleRpcPutCacheValues(Package); + } + else if (Method == "GetCacheValues"sv) + { + return HandleRpcGetCacheValues(Object); + } + else if (Method == "GetCacheChunks"sv) + { + return HandleRpcGetCacheChunks(Object); + } + return CbPackage{}; +} + +void +HttpStructuredCacheService::ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t 
ThreadCount) +{ + WorkerThreadPool WorkerPool(ThreadCount); + uint64_t RequestCount = Replayer.GetRequestCount(); + Stopwatch Timer; + auto _ = MakeGuard([&]() { ZEN_INFO("Replayed {} requests in {}", RequestCount, NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); }); + Latch JobLatch(RequestCount); + ZEN_INFO("Replaying {} requests", RequestCount); + for (uint64_t RequestIndex = 0; RequestIndex < RequestCount; ++RequestIndex) + { + WorkerPool.ScheduleWork([this, &JobLatch, &Replayer, RequestIndex]() { + IoBuffer Body; + std::pair<ZenContentType, ZenContentType> ContentType = Replayer.GetRequest(RequestIndex, Body); + if (Body) + { + uint32_t AcceptMagic = 0; + RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone; + int TargetPid = 0; + CbPackage RpcResult = HandleRpcRequest(ContentType.first, std::move(Body), AcceptMagic, AcceptFlags, TargetPid); + if (AcceptMagic == kCbPkgMagic) + { + FormatFlags Flags = FormatFlags::kDefault; + if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences)) + { + Flags |= FormatFlags::kAllowLocalReferences; + if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences)) + { + Flags |= FormatFlags::kDenyPartialLocalReferences; + } + } + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetPid); + ZEN_ASSERT(RpcResponseBuffer.GetSize() > 0); + } + else + { + BinaryWriter MemStream; + RpcResult.Save(MemStream); + IoBuffer RpcResponseBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()); + ZEN_ASSERT(RpcResponseBuffer.Size() > 0); + } + } + JobLatch.CountDown(); + }); + } + while (!JobLatch.Wait(10000)) + { + ZEN_INFO("Replayed {} of {} requests, elapsed {}", + RequestCount - JobLatch.Remaining(), + RequestCount, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + } +} + +void +HttpStructuredCacheService::HandleRpcRequest(HttpServerRequest& Request) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kPost: + { + const HttpContentType ContentType = 
Request.RequestContentType(); + const HttpContentType AcceptType = Request.AcceptContentType(); + + if ((ContentType != HttpContentType::kCbObject && ContentType != HttpContentType::kCbPackage) || + AcceptType != HttpContentType::kCbPackage) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + Request.WriteResponseAsync( + [this, Body = Request.ReadPayload(), ContentType, AcceptType](HttpServerRequest& AsyncRequest) mutable { + std::uint64_t RequestIndex = + m_RequestRecorder ? m_RequestRecorder->RecordRequest(ContentType, AcceptType, Body) : ~0ull; + uint32_t AcceptMagic = 0; + RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone; + int TargetProcessId = 0; + CbPackage RpcResult = HandleRpcRequest(ContentType, std::move(Body), AcceptMagic, AcceptFlags, TargetProcessId); + if (RpcResult.IsNull()) + { + AsyncRequest.WriteResponse(HttpResponseCode::BadRequest); + return; + } + if (AcceptMagic == kCbPkgMagic) + { + FormatFlags Flags = FormatFlags::kDefault; + if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences)) + { + Flags |= FormatFlags::kAllowLocalReferences; + if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences)) + { + Flags |= FormatFlags::kDenyPartialLocalReferences; + } + } + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetProcessId); + if (RequestIndex != ~0ull) + { + ZEN_ASSERT(m_RequestRecorder); + m_RequestRecorder->RecordResponse(RequestIndex, HttpContentType::kCbPackage, RpcResponseBuffer); + } + AsyncRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer); + } + else + { + BinaryWriter MemStream; + RpcResult.Save(MemStream); + + if (RequestIndex != ~0ull) + { + ZEN_ASSERT(m_RequestRecorder); + m_RequestRecorder->RecordResponse(RequestIndex, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); + } + AsyncRequest.WriteResponse(HttpResponseCode::OK, + 
HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); + } + }); + } + break; + default: + Request.WriteResponse(HttpResponseCode::BadRequest); + break; + } +} + +CbPackage +HttpStructuredCacheService::HandleRpcPutCacheRecords(const CbPackage& BatchRequest) +{ + ZEN_TRACE_CPU("Z$::RpcPutCacheRecords"); + CbObjectView BatchObject = BatchRequest.GetObject(); + ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv); + + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + CachePolicy DefaultPolicy; + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::vector<bool> Results; + for (CbFieldView RequestField : Params["Requests"sv]) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView(); + CbObjectView KeyView = RecordObject["Key"sv].AsObjectView(); + + CacheKey Key; + if (!GetRpcRequestCacheKey(KeyView, Key)) + { + return CbPackage{}; + } + CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + PutRequestData PutRequest{*Namespace, std::move(Key), RecordObject, std::move(Policy)}; + + PutResult Result = PutCacheRecord(PutRequest, &BatchRequest); + + if (Result == PutResult::Invalid) + { + return CbPackage{}; + } + Results.push_back(Result == PutResult::Success); + } + if (Results.empty()) + { + return CbPackage{}; + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + return RpcResponse; +} + +HttpStructuredCacheService::PutResult 
+HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPackage* Package) +{ + CbObjectView Record = Request.RecordObject; + uint64_t RecordObjectSize = Record.GetSize(); + uint64_t TransferredSize = RecordObjectSize; + + AttachmentCount Count; + size_t NumAttachments = Package->GetAttachments().size(); + std::vector<IoHash> ValidAttachments; + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + ValidAttachments.reserve(NumAttachments); + AttachmentsToStoreLocally.reserve(NumAttachments); + + Stopwatch Timer; + + Request.RecordObject.IterateAttachments( + [this, &Request, Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count, &TransferredSize](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr) + { + if (Attachment->IsCompressedBinary()) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + else + { + ZEN_WARN("PUTCACEHRECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed", + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ToString(HttpContentType::kCbPackage), + ValueHash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(ValueHash)) + { + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + Count.Total++; + }); + + if (Count.Invalid > 0) + { + return PutResult::Invalid; + } + + ZenCacheValue CacheValue; + CacheValue.Value = IoBuffer(Record.GetSize()); + Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize())); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue); + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = 
m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } + TransferredSize += Chunk.GetCompressedSize(); + } + + ZEN_DEBUG("PUTCACEHRECORD - '{}/{}/{}' {}, attachments '{}/{}/{}' (new/valid/total) in {}", + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + NiceBytes(TransferredSize), + Count.New, + Count.Valid, + Count.Total, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage, + .Namespace = Request.Namespace, + .Key = Request.Key, + .ValueContentIds = std::move(ValidAttachments)}); + } + return PutResult::Success; +} + +CbPackage +HttpStructuredCacheService::HandleRpcGetCacheRecords(CbObjectView RpcRequest) +{ + ZEN_TRACE_CPU("Z$::RpcGetCacheRecords"); + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + + struct ValueRequestData + { + Oid ValueId; + IoHash ContentId; + CompressedBuffer Payload; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool ReadFromUpstream = false; + }; + struct RecordRequestData + { + CacheKeyRequest Upstream; + CbObjectView RecordObject; + IoBuffer RecordCacheValue; + CacheRecordPolicy DownstreamPolicy; + std::vector<ValueRequestData> Values; + bool Complete = false; + const UpstreamEndpointInfo* Source = nullptr; + uint64_t ElapsedTimeUs; + }; + + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? 
ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + std::vector<RecordRequestData> Requests; + std::vector<size_t> UpstreamIndexes; + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + + auto ParseValues = [](RecordRequestData& Request) { + CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView(); + Request.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Request.Values.push_back({ValueId, RawHash}); + Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId); + } + } + }; + + for (CbFieldView RequestField : RequestsArray) + { + Stopwatch Timer; + RecordRequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + CacheKey& Key = Request.Upstream.Key; + if (!GetRpcRequestCacheKey(KeyObject, Key)) + { + return CbPackage{}; + } + + Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + const CacheRecordPolicy& Policy = Request.DownstreamPolicy; + + ZenCacheValue CacheValue; + bool NeedUpstreamAttachment = false; + bool FoundLocalInvalid = false; + ZenCacheValue RecordCacheValue; + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) && + m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, RecordCacheValue)) + { + Request.RecordCacheValue = std::move(RecordCacheValue.Value); + if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject) + { + 
FoundLocalInvalid = true; + } + else + { + Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData()); + ParseValues(Request); + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal)) + { + // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we + // didn't ask for it and thus the record is complete in its absence. + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Value.Exists = true; + } + else + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + Request.Complete = false; + } + } + else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + if (m_CidStore.ContainsChunk(Value.ContentId)) + { + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } + } + else + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId)) + { + ZEN_ASSERT(Chunk.GetSize() > 0); + Value.Payload = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } + } + } + } + } + if (!Request.Complete) + { + bool NeedUpstreamRecord = + !Request.RecordObject && !FoundLocalInvalid && EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote); + if (NeedUpstreamRecord || NeedUpstreamAttachment) + { + UpstreamIndexes.push_back(Requests.size() - 1); + } + } + Request.ElapsedTimeUs = Timer.GetElapsedTimeUs(); + } + if (Requests.empty()) + { + return CbPackage{}; + } + + if (!UpstreamIndexes.empty()) + { + std::vector<CacheKeyRequest*> UpstreamRequests; + 
UpstreamRequests.reserve(UpstreamIndexes.size()); + for (size_t Index : UpstreamIndexes) + { + RecordRequestData& Request = Requests[Index]; + UpstreamRequests.push_back(&Request.Upstream); + + if (Request.Values.size()) + { + // We will be returning the local object and know all the value Ids that exist in it + // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have. + CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta; + CacheRecordPolicyBuilder Builder(UpstreamBasePolicy); + for (ValueRequestData& Value : Request.Values) + { + CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy); + UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None; + Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy); + } + Request.Upstream.Policy = Builder.Build(); + } + else + { + // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants, + // and convert the CacheRecordPolicy to an upstream policy + Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream(); + } + } + + const auto OnCacheRecordGetComplete = [this, Namespace, &ParseValues](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + + RecordRequestData& Request = + *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream)); + Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + const CacheKey& Key = Request.Upstream.Key; + Stopwatch Timer; + auto TimeGuard = MakeGuard([&Timer, &Request]() { Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); }); + if (!Request.RecordObject) + { + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject); + 
Request.RecordObject = ObjectBuffer; + if (EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal)) + { + m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}}); + } + ParseValues(Request); + Request.Source = Params.Source; + } + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + if (Value.Exists) + { + continue; + } + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Request.Complete = false; + continue; + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + Request.Source = Params.Source; + Value.Exists = true; + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Value.Payload = Compressed; + } + } + else + { + ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}/{}'", + Value.ContentId, + *Namespace, + Key.Bucket, + Key.Hash); + } + } + if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Request.Complete = false; + } + // Request.Complete does not need to be set to false for upstream SkipData attachments. + // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment + // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of + // any missing SkipData attachments. 
+ } + Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } + }; + + m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete)); + } + + CbPackage ResponsePackage; + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"sv); + for (RecordRequestData& Request : Requests) + { + const CacheKey& Key = Request.Upstream.Key; + if (Request.Complete || + (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))) + { + ResponseObject << Request.RecordObject; + for (ValueRequestData& Value : Request.Values) + { + if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload) + { + ResponsePackage.AddAttachment(CbAttachment(Value.Payload, Value.ContentId)); + } + } + + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {}{} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(Request.RecordCacheValue.Size()), + Request.Complete ? ""sv : " (PARTIAL)"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount += Request.Source ? 1 : 0; + } + else + { + ResponseObject.AddNull(); + + if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHERECORD DISABLEDQUERY - '{}/{}/{}' in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + } + else + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}'{} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + Request.RecordObject ? ""sv : " (PARTIAL)"sv, + Request.Source ? 
Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.MissCount++; + } + } + } + ResponseObject.EndArray(); + ResponsePackage.SetObject(ResponseObject.Save()); + return ResponsePackage; +} + +CbPackage +HttpStructuredCacheService::HandleRpcPutCacheValues(const CbPackage& BatchRequest) +{ + CbObjectView BatchObject = BatchRequest.GetObject(); + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + std::vector<bool> Results; + for (CbFieldView RequestField : Params["Requests"sv]) + { + Stopwatch Timer; + + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyView = RequestObject["Key"sv].AsObjectView(); + + CacheKey Key; + if (!GetRpcRequestCacheKey(KeyView, Key)) + { + return CbPackage{}; + } + + PolicyText = RequestObject["Policy"sv].AsString(); + CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment(); + uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64(); + bool Succeeded = false; + uint64_t TransferredSize = 0; + + if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash)) + { + if (Attachment->IsCompressedBinary()) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + // TODO: Implement upstream puts of CacheValues with StoreLocal == false. + // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream. 
+ Policy |= CachePolicy::StoreLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + if (RawSize == 0) + { + RawSize = Chunk.DecodeRawSize(); + } + m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = Value, .RawSize = RawSize, .RawHash = RawHash}); + TransferredSize = Chunk.GetCompressedSize(); + } + Succeeded = true; + } + else + { + ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}/{}' FAILED, value is not compressed", *Namespace, Key.Bucket, Key.Hash, RawHash); + return CbPackage{}; + } + } + else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + ZenCacheValue ExistingValue; + if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, ExistingValue) && + IsCompressedBinary(ExistingValue.Value.GetContentType())) + { + Succeeded = true; + } + } + // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put. + // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream. + + if (Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCompressedBinary, .Namespace = *Namespace, .Key = Key}); + } + Results.push_back(Succeeded); + ZEN_DEBUG("PUTCACHEVALUES - '{}/{}/{}' {}, '{}' in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(TransferredSize), + Succeeded ? 
"Added"sv : "Invalid", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + } + if (Results.empty()) + { + return CbPackage{}; + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + + return RpcResponse; +} + +CbPackage +HttpStructuredCacheService::HandleRpcGetCacheValues(CbObjectView RpcRequest) +{ + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + + struct RequestData + { + CacheKey Key; + CachePolicy Policy; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + CompressedBuffer Result; + }; + std::vector<RequestData> Requests; + + std::vector<size_t> RemoteRequestIndexes; + + for (CbFieldView RequestField : Params["Requests"sv]) + { + Stopwatch Timer; + + RequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + if (!GetRpcRequestCacheKey(KeyObject, Request.Key)) + { + return CbPackage{}; + } + + PolicyText = RequestObject["Policy"sv].AsString(); + Request.Policy = !PolicyText.empty() ? 
ParseCachePolicy(PolicyText) : DefaultPolicy; + + CacheKey& Key = Request.Key; + CachePolicy Policy = Request.Policy; + + ZenCacheValue CacheValue; + if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, CacheValue) && IsCompressedBinary(CacheValue.Value.GetContentType())) + { + Request.RawHash = CacheValue.RawHash; + Request.RawSize = CacheValue.RawSize; + Request.Result = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value)); + } + } + if (Request.Result) + { + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(Request.Result.GetCompressed().GetSize()), + "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.HitCount++; + } + else if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + RemoteRequestIndexes.push_back(Requests.size() - 1); + } + else if (!EnumHasAnyFlags(Policy, CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}/{}'", *Namespace, Key.Bucket, Key.Hash); + } + else + { + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + } + } + + if (!RemoteRequestIndexes.empty()) + { + std::vector<CacheValueRequest> RequestedRecordsData; + std::vector<CacheValueRequest*> CacheValueRequests; + RequestedRecordsData.reserve(RemoteRequestIndexes.size()); + CacheValueRequests.reserve(RemoteRequestIndexes.size()); + for (size_t Index : RemoteRequestIndexes) + { + RequestData& Request = Requests[Index]; + RequestedRecordsData.push_back({.Key = {Request.Key.Bucket, Request.Key.Hash}, .Policy = ConvertToUpstream(Request.Policy)}); + CacheValueRequests.push_back(&RequestedRecordsData.back()); + } + Stopwatch Timer; + m_UpstreamCache.GetCacheValues( + *Namespace, + CacheValueRequests, + 
[this, Namespace, &RequestedRecordsData, &Requests, &RemoteRequestIndexes, &Timer](CacheValueGetCompleteParams&& Params) { + CacheValueRequest& ChunkRequest = Params.Request; + if (Params.RawHash != IoHash::Zero) + { + size_t RequestOffset = std::distance(RequestedRecordsData.data(), &ChunkRequest); + size_t RequestIndex = RemoteRequestIndexes[RequestOffset]; + RequestData& Request = Requests[RequestIndex]; + Request.RawHash = Params.RawHash; + Request.RawSize = Params.RawSize; + const bool HasData = IsCompressedBinary(Params.Value.GetContentType()); + const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData); + const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal); + const bool IsHit = SkipData || HasData; + if (IsHit) + { + if (HasData && !SkipData) + { + Request.Result = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value)); + } + + if (HasData && StoreData) + { + m_CacheStore.Put(*Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash}); + } + + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", + *Namespace, + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + NiceBytes(Request.Result.GetCompressed().GetSize()), + Params.Source ? Params.Source->Url : "UPSTREAM", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + return; + } + } + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}", + *Namespace, + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + Params.Source ? 
Params.Source->Url : "UPSTREAM", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + }); + } + + if (Requests.empty()) + { + return CbPackage{}; + } + + CbPackage RpcResponse; + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (const RequestData& Request : Requests) + { + ResponseObject.BeginObject(); + { + const CompressedBuffer& Result = Request.Result; + if (Result) + { + ResponseObject.AddHash("RawHash"sv, Request.RawHash); + if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Result, Request.RawHash)); + } + else + { + ResponseObject.AddInteger("RawSize"sv, Request.RawSize); + } + } + else if (Request.RawHash != IoHash::Zero) + { + ResponseObject.AddHash("RawHash"sv, Request.RawHash); + ResponseObject.AddInteger("RawSize"sv, Request.RawSize); + } + } + ResponseObject.EndObject(); + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + return RpcResponse; +} + +namespace cache::detail { + + struct RecordValue + { + Oid ValueId; + IoHash ContentId; + uint64_t RawSize; + }; + struct RecordBody + { + IoBuffer CacheValue; + std::vector<RecordValue> Values; + const UpstreamEndpointInfo* Source = nullptr; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool HasRequest = false; + bool ValuesRead = false; + }; + struct ChunkRequest + { + CacheChunkRequest* Key = nullptr; + RecordBody* Record = nullptr; + CompressedBuffer Value; + const UpstreamEndpointInfo* Source = nullptr; + uint64_t RawSize = 0; + uint64_t RequestedSize = 0; + uint64_t RequestedOffset = 0; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool RawSizeKnown = false; + bool IsRecordRequest = false; + uint64_t ElapsedTimeUs = 0; + }; + +} // namespace cache::detail + +CbPackage +HttpStructuredCacheService::HandleRpcGetCacheChunks(CbObjectView RpcRequest) +{ + using namespace cache::detail; + + std::string Namespace; + std::vector<CacheKeyRequest> 
RecordKeys; // Data about a Record necessary to identify it to the upstream + std::vector<RecordBody> Records; // Scratch-space data about a Record when fulfilling RecordRequests + std::vector<CacheChunkRequest> RequestKeys; // Data about a ChunkRequest necessary to identify it to the upstream + std::vector<ChunkRequest> Requests; // Intermediate and result data about a ChunkRequest + std::vector<ChunkRequest*> RecordRequests; // The ChunkRequests that are requesting a subvalue from a Record Key + std::vector<ChunkRequest*> ValueRequests; // The ChunkRequests that are requesting a Value Key + std::vector<CacheChunkRequest*> UpstreamChunks; // ChunkRequests that we need to send to the upstream + + // Parse requests from the CompactBinary body of the RpcRequest and divide it into RecordRequests and ValueRequests + if (!ParseGetCacheChunksRequest(Namespace, RecordKeys, Records, RequestKeys, Requests, RecordRequests, ValueRequests, RpcRequest)) + { + return CbPackage{}; + } + + // For each Record request, load the Record if necessary to find the Chunk's ContentId, load its Payloads if we + // have it locally, and otherwise append a request for the payload to UpstreamChunks + GetLocalCacheRecords(Namespace, RecordKeys, Records, RecordRequests, UpstreamChunks); + + // For each Value request, load the Value if we have it locally and otherwise append a request for the payload to UpstreamChunks + GetLocalCacheValues(Namespace, ValueRequests, UpstreamChunks); + + // Call GetCacheChunks on the upstream for any payloads we do not have locally + GetUpstreamCacheChunks(Namespace, UpstreamChunks, RequestKeys, Requests); + + // Send the payload and descriptive data about each chunk to the client + return WriteGetCacheChunksResponse(Namespace, Requests); +} + +bool +HttpStructuredCacheService::ParseGetCacheChunksRequest(std::string& Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<CacheChunkRequest>& 
RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + CbObjectView RpcRequest) +{ + using namespace cache::detail; + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default; + + std::optional<std::string> NamespaceText = GetRpcRequestNamespace(Params); + if (!NamespaceText) + { + ZEN_WARN("GetCacheChunks: Invalid namespace in ChunkRequest."); + return false; + } + Namespace = *NamespaceText; + + CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView(); + size_t NumRequests = static_cast<size_t>(ChunkRequestsArray.Num()); + + // Note that these reservations allow us to take pointers to the elements while populating them. If the reservation is removed, + // we will need to change the pointers to indexes to handle reallocations. 
+ RecordKeys.reserve(NumRequests); + Records.reserve(NumRequests); + RequestKeys.reserve(NumRequests); + Requests.reserve(NumRequests); + RecordRequests.reserve(NumRequests); + ValueRequests.reserve(NumRequests); + + CacheKeyRequest* PreviousRecordKey = nullptr; + RecordBody* PreviousRecord = nullptr; + + for (CbFieldView RequestView : ChunkRequestsArray) + { + CbObjectView RequestObject = RequestView.AsObjectView(); + CacheChunkRequest& RequestKey = RequestKeys.emplace_back(); + ChunkRequest& Request = Requests.emplace_back(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + Request.Key = &RequestKey; + if (!GetRpcRequestCacheKey(KeyObject, Request.Key->Key)) + { + ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest."); + return false; + } + + RequestKey.ChunkId = RequestObject["ChunkId"sv].AsHash(); + RequestKey.ValueId = RequestObject["ValueId"sv].AsObjectId(); + RequestKey.RawOffset = RequestObject["RawOffset"sv].AsUInt64(); + RequestKey.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX); + Request.RequestedSize = RequestKey.RawSize; + Request.RequestedOffset = RequestKey.RawOffset; + std::string_view PolicyText = RequestObject["Policy"sv].AsString(); + Request.DownstreamPolicy = !PolicyText.empty() ? 
ParseCachePolicy(PolicyText) : DefaultPolicy; + Request.IsRecordRequest = (bool)RequestKey.ValueId; + + if (!Request.IsRecordRequest) + { + ValueRequests.push_back(&Request); + } + else + { + RecordRequests.push_back(&Request); + CacheKeyRequest* RecordKey = nullptr; + RecordBody* Record = nullptr; + + if (!PreviousRecordKey || PreviousRecordKey->Key < RequestKey.Key) + { + RecordKey = &RecordKeys.emplace_back(); + PreviousRecordKey = RecordKey; + Record = &Records.emplace_back(); + PreviousRecord = Record; + RecordKey->Key = RequestKey.Key; + } + else if (RequestKey.Key == PreviousRecordKey->Key) + { + RecordKey = PreviousRecordKey; + Record = PreviousRecord; + } + else + { + ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.", + RequestKey.Key.Bucket, + RequestKey.Key.Hash, + PreviousRecordKey->Key.Bucket, + PreviousRecordKey->Key.Hash); + return false; + } + Request.Record = Record; + if (RequestKey.ChunkId == RequestKey.ChunkId.Zero) + { + Record->DownstreamPolicy = + Record->HasRequest ? 
Union(Record->DownstreamPolicy, Request.DownstreamPolicy) : Request.DownstreamPolicy; + Record->HasRequest = true; + } + } + } + if (Requests.empty()) + { + return false; + } + return true; +} + +void +HttpStructuredCacheService::GetLocalCacheRecords(std::string_view Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks) +{ + using namespace cache::detail; + + std::vector<CacheKeyRequest*> UpstreamRecordRequests; + for (size_t RecordIndex = 0; RecordIndex < Records.size(); ++RecordIndex) + { + Stopwatch Timer; + CacheKeyRequest& RecordKey = RecordKeys[RecordIndex]; + RecordBody& Record = Records[RecordIndex]; + if (Record.HasRequest) + { + Record.DownstreamPolicy |= CachePolicy::SkipData | CachePolicy::SkipMeta; + + if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Namespace, RecordKey.Key.Bucket, RecordKey.Key.Hash, CacheValue)) + { + Record.Exists = true; + Record.CacheValue = std::move(CacheValue.Value); + } + } + if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryRemote)) + { + RecordKey.Policy = CacheRecordPolicy(ConvertToUpstream(Record.DownstreamPolicy)); + UpstreamRecordRequests.push_back(&RecordKey); + } + RecordRequests[RecordIndex]->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } + } + + if (!UpstreamRecordRequests.empty()) + { + const auto OnCacheRecordGetComplete = + [this, Namespace, &RecordKeys, &Records, &RecordRequests](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + CacheKeyRequest& RecordKey = Params.Request; + size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey); + RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + RecordBody& Record = Records[RecordIndex]; + + 
const CacheKey& Key = RecordKey.Key; + Record.Exists = true; + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Record.CacheValue.SetContentType(ZenContentType::kCbObject); + Record.Source = Params.Source; + + if (EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal)) + { + m_CacheStore.Put(Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue}); + } + }; + m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete)); + } + + std::vector<CacheChunkRequest*> UpstreamPayloadRequests; + for (ChunkRequest* Request : RecordRequests) + { + Stopwatch Timer; + if (Request->Key->ChunkId == IoHash::Zero) + { + // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId) + // is missing, parse the cache record and try to find the raw hash from the ValueId. + RecordBody& Record = *Request->Record; + if (!Record.ValuesRead) + { + Record.ValuesRead = true; + if (Record.CacheValue && Record.CacheValue.GetContentType() == ZenContentType::kCbObject) + { + CbObjectView RecordObject = CbObjectView(Record.CacheValue.GetData()); + CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView(); + Record.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Record.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()}); + } + } + } + } + + for (const RecordValue& Value : Record.Values) + { + if (Value.ValueId == Request->Key->ValueId) + { + Request->Key->ChunkId = Value.ContentId; + Request->RawSize = Value.RawSize; + Request->RawSizeKnown = true; + break; + } + } + } + + // Now load the ContentId from the local 
ContentIdStore or from the upstream + if (Request->Key->ChunkId != IoHash::Zero) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal)) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData) && Request->RawSizeKnown) + { + if (m_CidStore.ContainsChunk(Request->Key->ChunkId)) + { + Request->Exists = true; + } + } + else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Request->Key->ChunkId)) + { + if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData)) + { + Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(Payload)); + if (Request->Value) + { + Request->Exists = true; + Request->RawSizeKnown = false; + } + } + else + { + IoHash _; + if (CompressedBuffer::ValidateCompressedHeader(Payload, _, Request->RawSize)) + { + Request->Exists = true; + Request->RawSizeKnown = true; + } + } + } + } + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote)) + { + Request->Key->Policy = ConvertToUpstream(Request->DownstreamPolicy); + OutUpstreamChunks.push_back(Request->Key); + } + } + Request->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } +} + +void +HttpStructuredCacheService::GetLocalCacheValues(std::string_view Namespace, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks) +{ + using namespace cache::detail; + + for (ChunkRequest* Request : ValueRequests) + { + Stopwatch Timer; + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Namespace, Request->Key->Key.Bucket, Request->Key->Key.Hash, CacheValue)) + { + if (IsCompressedBinary(CacheValue.Value.GetContentType())) + { + Request->Key->ChunkId = CacheValue.RawHash; + Request->Exists = true; + Request->RawSize = CacheValue.RawSize; + Request->RawSizeKnown = true; + if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData)) + { + 
Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value)); + } + } + } + } + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote)) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::StoreLocal)) + { + // Convert the Offset,Size request into a request for the entire value; we will need it all to be able to store it locally + Request->Key->RawOffset = 0; + Request->Key->RawSize = UINT64_MAX; + } + OutUpstreamChunks.push_back(Request->Key); + } + Request->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } +} + +void +HttpStructuredCacheService::GetUpstreamCacheChunks(std::string_view Namespace, + std::vector<CacheChunkRequest*>& UpstreamChunks, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests) +{ + using namespace cache::detail; + + if (!UpstreamChunks.empty()) + { + const auto OnCacheChunksGetComplete = [this, Namespace, &RequestKeys, &Requests](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash == Params.RawHash.Zero) + { + return; + } + + CacheChunkRequest& Key = Params.Request; + size_t RequestIndex = std::distance(RequestKeys.data(), &Key); + ChunkRequest& Request = Requests[RequestIndex]; + Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) || + !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value)); + if (!Compressed) + { + return; + } + + if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal)) + { + if (Request.IsRecordRequest) + { + m_CidStore.AddChunk(Params.Value, Params.RawHash); + } + else + { + m_CacheStore.Put(Namespace, + Key.Key.Bucket, + Key.Key.Hash, + {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash}); + } + } + if 
(!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + Request.Value = std::move(Compressed); + } + } + Key.ChunkId = Params.RawHash; + Request.Exists = true; + Request.RawSize = Params.RawSize; + Request.RawSizeKnown = true; + Request.Source = Params.Source; + + m_CacheStats.UpstreamHitCount++; + }; + + m_UpstreamCache.GetCacheChunks(Namespace, UpstreamChunks, std::move(OnCacheChunksGetComplete)); + } +} + +CbPackage +HttpStructuredCacheService::WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests) +{ + using namespace cache::detail; + + CbPackage RpcResponse; + CbObjectWriter Writer; + + Writer.BeginArray("Result"sv); + for (ChunkRequest& Request : Requests) + { + Writer.BeginObject(); + { + if (Request.Exists) + { + Writer.AddHash("RawHash"sv, Request.Key->ChunkId); + if (Request.Value && !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Request.Value, Request.Key->ChunkId)); + } + else + { + Writer.AddInteger("RawSize"sv, Request.RawSize); + } + + ZEN_DEBUG("GETCACHECHUNKS HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceBytes(Request.RawSize), + Request.IsRecordRequest ? "Record"sv : "Value"sv, + Request.Source ? 
Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.HitCount++; + } + else if (!EnumHasAnyFlags(Request.DownstreamPolicy, CachePolicy::Query)) + { + ZEN_DEBUG("GETCACHECHUNKS DISABLEDQUERY - '{}/{}/{}/{}' in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + } + else + { + ZEN_DEBUG("GETCACHECHUNKS MISS - '{}/{}/{}/{}' in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.MissCount++; + } + } + Writer.EndObject(); + } + Writer.EndArray(); + + RpcResponse.SetObject(Writer.Save()); + return RpcResponse; +} + +void +HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + + EmitSnapshot("requests", m_HttpRequests, Cbo); + EmitSnapshot("upstream_gets", m_UpstreamGetRequestTiming, Cbo); + + const uint64_t HitCount = m_CacheStats.HitCount; + const uint64_t UpstreamHitCount = m_CacheStats.UpstreamHitCount; + const uint64_t MissCount = m_CacheStats.MissCount; + const uint64_t TotalCount = HitCount + MissCount; + + const CidStoreSize CidSize = m_CidStore.TotalSize(); + const GcStorageSize CacheSize = m_CacheStore.StorageSize(); + + Cbo.BeginObject("cache"); + { + Cbo.BeginObject("size"); + { + Cbo << "disk" << CacheSize.DiskSize; + Cbo << "memory" << CacheSize.MemorySize; + } + Cbo.EndObject(); + + Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); + Cbo << "hits" << HitCount << "misses" << MissCount; + Cbo << "hit_ratio" << (TotalCount > 0 ? (double(HitCount) / double(TotalCount)) : 0.0); + Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount; + Cbo << "upstream_ratio" << (HitCount > 0 ? 
(double(UpstreamHitCount) / double(HitCount)) : 0.0); + } + Cbo.EndObject(); + + Cbo.BeginObject("upstream"); + { + m_UpstreamCache.GetStatus(Cbo); + } + Cbo.EndObject(); + + Cbo.BeginObject("cid"); + { + Cbo.BeginObject("size"); + { + Cbo << "tiny" << CidSize.TinySize; + Cbo << "small" << CidSize.SmallSize; + Cbo << "large" << CidSize.LargeSize; + Cbo << "total" << CidSize.TotalSize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpStructuredCacheService::HandleStatusRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + Cbo << "ok" << true; + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +#if ZEN_WITH_TESTS + +TEST_CASE("z$service.parse.relative.Uri") +{ + HttpRequestData RootRequest; + CHECK(HttpRequestParseRelativeUri("", RootRequest)); + CHECK(!RootRequest.Namespace.has_value()); + CHECK(!RootRequest.Bucket.has_value()); + CHECK(!RootRequest.HashKey.has_value()); + CHECK(!RootRequest.ValueContentId.has_value()); + + RootRequest = {}; + CHECK(HttpRequestParseRelativeUri("/", RootRequest)); + CHECK(!RootRequest.Namespace.has_value()); + CHECK(!RootRequest.Bucket.has_value()); + CHECK(!RootRequest.HashKey.has_value()); + CHECK(!RootRequest.ValueContentId.has_value()); + + HttpRequestData LegacyBucketRequestBecomesNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest)); + CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test"sv); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData LegacyHashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest)); + CHECK(LegacyHashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace); + 
CHECK(LegacyHashKeyRequest.Bucket == "test"sv); + CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(!LegacyHashKeyRequest.ValueContentId.has_value()); + + HttpRequestData LegacyValueContentIdRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + LegacyValueContentIdRequest)); + CHECK(LegacyValueContentIdRequest.Namespace == ZenCacheStore::DefaultNamespace); + CHECK(LegacyValueContentIdRequest.Bucket == "test"sv); + CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv)); + + HttpRequestData V2DefaultNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest)); + CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc"); + CHECK(!V2DefaultNamespaceRequest.Bucket.has_value()); + CHECK(!V2DefaultNamespaceRequest.HashKey.has_value()); + CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2NamespaceRequest; + CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest)); + CHECK(V2NamespaceRequest.Namespace == "nicenamespace"sv); + CHECK(!V2NamespaceRequest.Bucket.has_value()); + CHECK(!V2NamespaceRequest.HashKey.has_value()); + CHECK(!V2NamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithDefaultNamespace; + CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace)); + CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc"); + CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test"sv); + CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithNamespace; + 
CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace)); + CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace"sv); + CHECK(V2BucketRequestWithNamespace.Bucket == "test"sv); + CHECK(!V2BucketRequestWithNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value()); + + HttpRequestData V2HashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest)); + CHECK(V2HashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace); + CHECK(V2HashKeyRequest.Bucket == "test"); + CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(!V2HashKeyRequest.ValueContentId.has_value()); + + HttpRequestData V2ValueContentIdRequest; + CHECK( + HttpRequestParseRelativeUri("nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + V2ValueContentIdRequest)); + CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace"sv); + CHECK(V2ValueContentIdRequest.Bucket == "test"sv); + CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv)); + + HttpRequestData Invalid; + CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid)); + CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid)); + 
CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789", + Invalid)); +} + +#endif + +void +z$service_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenserver/cache/structuredcache.h b/src/zenserver/cache/structuredcache.h new file mode 100644 index 000000000..4e7b98ac9 --- /dev/null +++ b/src/zenserver/cache/structuredcache.h @@ -0,0 +1,187 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/stats.h> +#include <zenhttp/httpserver.h> + +#include "monitoring/httpstats.h" +#include "monitoring/httpstatus.h" + +#include <memory> +#include <vector> + +namespace spdlog { +class logger; +} + +namespace zen { + +struct CacheChunkRequest; +struct CacheKeyRequest; +class CidStore; +class CbObjectView; +struct PutRequestData; +class ScrubContext; +class UpstreamCache; +class ZenCacheStore; +enum class CachePolicy : uint32_t; +enum class RpcAcceptOptions : uint16_t; + +namespace cache { + class IRpcRequestReplayer; + class IRpcRequestRecorder; + namespace detail { + struct RecordBody; + struct ChunkRequest; + } // namespace detail +} // namespace cache + +/** + * Structured cache service. Imposes constraints on keys, supports blobs and + * structured values + * + * Keys are structured as: + * + * {BucketId}/{KeyHash} + * + * Where BucketId is a lower-case alphanumeric string, and KeyHash is a 40-character + * hexadecimal sequence. The hash value may be derived in any number of ways, it's + * up to the application to pick an approach. + * + * Values may be structured or unstructured. 
Structured values are encoded using Unreal + * Engine's compact binary encoding (see CbObject) + * + * Additionally, attachments may be addressed as: + * + * {BucketId}/{KeyHash}/{ValueHash} + * + * Where the two initial components are the same as for the main endpoint + * + * The storage strategy is as follows: + * + * - Structured values are stored in a dedicated backing store per bucket + * - Unstructured values and attachments are stored in the CAS pool + * + */ + +class HttpStructuredCacheService : public HttpService, public IHttpStatsProvider, public IHttpStatusProvider +{ +public: + HttpStructuredCacheService(ZenCacheStore& InCacheStore, + CidStore& InCidStore, + HttpStatsService& StatsService, + HttpStatusService& StatusService, + UpstreamCache& UpstreamCache); + ~HttpStructuredCacheService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + void Flush(); + void Scrub(ScrubContext& Ctx); + +private: + struct CacheRef + { + std::string Namespace; + std::string BucketSegment; + IoHash HashKey; + IoHash ValueContentId; + }; + + struct CacheStats + { + std::atomic_uint64_t HitCount{}; + std::atomic_uint64_t UpstreamHitCount{}; + std::atomic_uint64_t MissCount{}; + }; + enum class PutResult + { + Success, + Fail, + Invalid, + }; + + void HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void 
HandleRpcRequest(HttpServerRequest& Request); + void HandleDetailsRequest(HttpServerRequest& Request); + + CbPackage HandleRpcPutCacheRecords(const CbPackage& BatchRequest); + CbPackage HandleRpcGetCacheRecords(CbObjectView BatchRequest); + CbPackage HandleRpcPutCacheValues(const CbPackage& BatchRequest); + CbPackage HandleRpcGetCacheValues(CbObjectView BatchRequest); + CbPackage HandleRpcGetCacheChunks(CbObjectView BatchRequest); + CbPackage HandleRpcRequest(const ZenContentType ContentType, + IoBuffer&& Body, + uint32_t& OutAcceptMagic, + RpcAcceptOptions& OutAcceptFlags, + int& OutTargetProcessId); + + void HandleCacheRequest(HttpServerRequest& Request); + void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace); + void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket); + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; + PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package); + + /** HandleRpcGetCacheChunks Helper: Parse the Body object into RecordValue Requests and Value Requests. */ + bool ParseGetCacheChunksRequest(std::string& Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + CbObjectView RpcRequest); + /** HandleRpcGetCacheChunks Helper: Load records to get ContentId for RecordRequests, and load their payloads if they exist locally. 
*/ + void GetLocalCacheRecords(std::string_view Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks); + /** HandleRpcGetCacheChunks Helper: For ValueRequests, load their payloads if they exist locally. */ + void GetLocalCacheValues(std::string_view Namespace, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks); + /** HandleRpcGetCacheChunks Helper: Load payloads from upstream that did not exist locally. */ + void GetUpstreamCacheChunks(std::string_view Namespace, + std::vector<CacheChunkRequest*>& UpstreamChunks, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests); + /** HandleRpcGetCacheChunks Helper: Send response message containing all chunk results. */ + CbPackage WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests); + + spdlog::logger& Log() { return m_Log; } + spdlog::logger& m_Log; + ZenCacheStore& m_CacheStore; + HttpStatsService& m_StatsService; + HttpStatusService& m_StatusService; + CidStore& m_CidStore; + UpstreamCache& m_UpstreamCache; + uint64_t m_LastScrubTime = 0; + metrics::OperationTiming m_HttpRequests; + metrics::OperationTiming m_UpstreamGetRequestTiming; + CacheStats m_CacheStats; + + void ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount); + + std::unique_ptr<cache::IRpcRequestRecorder> m_RequestRecorder; +}; + +/** Recognize both kBinary and kCompressedBinary as kCompressedBinary for structured cache value keys. + * We need this until the content type is preserved for kCompressedBinary when passing to and from upstream servers. 
*/ +inline bool +IsCompressedBinary(ZenContentType Type) +{ + return Type == ZenContentType::kBinary || Type == ZenContentType::kCompressedBinary; +} + +void z$service_forcelink(); + +} // namespace zen diff --git a/src/zenserver/cache/structuredcachestore.cpp b/src/zenserver/cache/structuredcachestore.cpp new file mode 100644 index 000000000..26e970073 --- /dev/null +++ b/src/zenserver/cache/structuredcachestore.cpp @@ -0,0 +1,3648 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "structuredcachestore.h" + +#include <zencore/except.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenstore/cidstore.h> +#include <zenstore/scrubcontext.h> + +#include <xxhash.h> + +#include <limits> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/core.h> +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zencore/workthreadpool.h> +# include <random> +#endif + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +namespace { + +#pragma pack(push) +#pragma pack(1) + + struct CacheBucketIndexHeader + { + static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx'; + static constexpr uint32_t Version2 = 2; + static constexpr uint32_t CurrentVersion = Version2; + + uint32_t Magic = ExpectedMagic; + uint32_t Version = CurrentVersion; + uint64_t EntryCount = 0; + uint64_t LogPosition = 0; + uint32_t PayloadAlignment = 0; + uint32_t Checksum = 0; + + static 
uint32_t ComputeChecksum(const CacheBucketIndexHeader& Header) + { + return XXH32(&Header.Magic, sizeof(CacheBucketIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA); + } + }; + + static_assert(sizeof(CacheBucketIndexHeader) == 32); + +#pragma pack(pop) + + const char* IndexExtension = ".uidx"; + const char* LogExtension = ".slog"; + + std::filesystem::path GetIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName) + { + return BucketDir / (BucketName + IndexExtension); + } + + std::filesystem::path GetTempIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName) + { + return BucketDir / (BucketName + ".tmp"); + } + + std::filesystem::path GetLogPath(const std::filesystem::path& BucketDir, const std::string& BucketName) + { + return BucketDir / (BucketName + LogExtension); + } + + bool ValidateEntry(const DiskIndexEntry& Entry, std::string& OutReason) + { + if (Entry.Key == IoHash::Zero) + { + OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString()); + return false; + } + if (Entry.Location.GetFlags() & + ~(DiskLocation::kStandaloneFile | DiskLocation::kStructured | DiskLocation::kTombStone | DiskLocation::kCompressed)) + { + OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Location.GetFlags(), Entry.Key.ToHexString()); + return false; + } + if (Entry.Location.IsFlagSet(DiskLocation::kTombStone)) + { + return true; + } + if (Entry.Location.Reserved != 0) + { + OutReason = fmt::format("Invalid reserved field {} for entry {}", Entry.Location.Reserved, Entry.Key.ToHexString()); + return false; + } + uint64_t Size = Entry.Location.Size(); + if (Size == 0) + { + OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString()); + return false; + } + return true; + } + + bool MoveAndDeleteDirectory(const std::filesystem::path& Dir) + { + int DropIndex = 0; + do + { + if (!std::filesystem::exists(Dir)) + { + return false; + } + + std::string DroppedName = 
fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex); + std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName; + if (std::filesystem::exists(DroppedBucketPath)) + { + DropIndex++; + continue; + } + + std::error_code Ec; + std::filesystem::rename(Dir, DroppedBucketPath, Ec); + if (!Ec) + { + DeleteDirectories(DroppedBucketPath); + return true; + } + // TODO: Do we need to bail at some point? + zen::Sleep(100); + } while (true); + } + +} // namespace + +namespace fs = std::filesystem; + +static CbObject +LoadCompactBinaryObject(const fs::path& Path) +{ + FileContents Result = ReadFile(Path); + + if (!Result.ErrorCode) + { + IoBuffer Buffer = Result.Flatten(); + if (CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); Error == CbValidateError::None) + { + return LoadCompactBinaryObject(Buffer); + } + } + + return CbObject(); +} + +static void +SaveCompactBinaryObject(const fs::path& Path, const CbObject& Object) +{ + WriteFile(Path, Object.GetBuffer().AsIoBuffer()); +} + +ZenCacheNamespace::ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir) +: GcStorage(Gc) +, GcContributor(Gc) +, m_RootDir(RootDir) +, m_DiskLayer(RootDir) +{ + ZEN_INFO("initializing structured cache at '{}'", RootDir); + CreateDirectories(RootDir); + + m_DiskLayer.DiscoverBuckets(); + +#if ZEN_USE_CACHE_TRACKER + m_AccessTracker.reset(new ZenCacheTracker(RootDir)); +#endif +} + +ZenCacheNamespace::~ZenCacheNamespace() +{ +} + +bool +ZenCacheNamespace::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + ZEN_TRACE_CPU("Z$::Get"); + + bool Ok = m_MemLayer.Get(InBucket, HashKey, OutValue); + +#if ZEN_USE_CACHE_TRACKER + auto _ = MakeGuard([&] { + if (!Ok) + return; + + m_AccessTracker->TrackAccess(InBucket, HashKey); + }); +#endif + + if (Ok) + { + ZEN_ASSERT(OutValue.Value.Size()); + + return true; + } + + Ok = m_DiskLayer.Get(InBucket, HashKey, OutValue); + + if (Ok) + { + 
ZEN_ASSERT(OutValue.Value.Size()); + + if (OutValue.Value.Size() <= m_DiskLayerSizeThreshold) + { + m_MemLayer.Put(InBucket, HashKey, OutValue); + } + } + + return Ok; +} + +void +ZenCacheNamespace::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + ZEN_TRACE_CPU("Z$::Put"); + + // Store value and index + + ZEN_ASSERT(Value.Value.Size()); + + m_DiskLayer.Put(InBucket, HashKey, Value); + +#if ZEN_USE_REF_TRACKING + if (Value.Value.GetContentType() == ZenContentType::kCbObject) + { + if (ValidateCompactBinary(Value.Value, CbValidateMode::All) == CbValidateError::None) + { + CbObject Object{SharedBuffer(Value.Value)}; + + uint8_t TempBuffer[8 * sizeof(IoHash)]; + std::pmr::monotonic_buffer_resource Linear{TempBuffer, sizeof TempBuffer}; + std::pmr::polymorphic_allocator Allocator{&Linear}; + std::pmr::vector<IoHash> CidReferences{Allocator}; + + Object.IterateAttachments([&](CbFieldView Field) { CidReferences.push_back(Field.AsAttachment()); }); + + m_Gc.OnNewCidReferences(CidReferences); + } + } +#endif + + if (Value.Value.Size() <= m_DiskLayerSizeThreshold) + { + m_MemLayer.Put(InBucket, HashKey, Value); + } +} + +bool +ZenCacheNamespace::DropBucket(std::string_view Bucket) +{ + ZEN_INFO("dropping bucket '{}'", Bucket); + + // TODO: should ensure this is done atomically across all layers + + const bool MemDropped = m_MemLayer.DropBucket(Bucket); + const bool DiskDropped = m_DiskLayer.DropBucket(Bucket); + const bool AnyDropped = MemDropped || DiskDropped; + + ZEN_INFO("bucket '{}' was {}", Bucket, AnyDropped ? 
"dropped" : "not found"); + + return AnyDropped; +} + +bool +ZenCacheNamespace::Drop() +{ + m_MemLayer.Drop(); + return m_DiskLayer.Drop(); +} + +void +ZenCacheNamespace::Flush() +{ + m_DiskLayer.Flush(); +} + +void +ZenCacheNamespace::Scrub(ScrubContext& Ctx) +{ + if (m_LastScrubTime == Ctx.ScrubTimestamp()) + { + return; + } + + m_LastScrubTime = Ctx.ScrubTimestamp(); + + m_DiskLayer.Scrub(Ctx); + m_MemLayer.Scrub(Ctx); +} + +void +ZenCacheNamespace::GatherReferences(GcContext& GcCtx) +{ + Stopwatch Timer; + const auto Guard = + MakeGuard([&] { ZEN_DEBUG("cache gathered all references from '{}' in {}", m_RootDir, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); }); + + access_tracking::AccessTimes AccessTimes; + m_MemLayer.GatherAccessTimes(AccessTimes); + + m_DiskLayer.UpdateAccessTimes(AccessTimes); + m_DiskLayer.GatherReferences(GcCtx); +} + +void +ZenCacheNamespace::CollectGarbage(GcContext& GcCtx) +{ + m_MemLayer.Reset(); + m_DiskLayer.CollectGarbage(GcCtx); +} + +GcStorageSize +ZenCacheNamespace::StorageSize() const +{ + return {.DiskSize = m_DiskLayer.TotalSize(), .MemorySize = m_MemLayer.TotalSize()}; +} + +ZenCacheNamespace::Info +ZenCacheNamespace::GetInfo() const +{ + ZenCacheNamespace::Info Info = {.Config = {.RootDir = m_RootDir, .DiskLayerThreshold = m_DiskLayerSizeThreshold}, + .DiskLayerInfo = m_DiskLayer.GetInfo(), + .MemoryLayerInfo = m_MemLayer.GetInfo()}; + std::unordered_set<std::string> BucketNames; + for (const std::string& BucketName : Info.DiskLayerInfo.BucketNames) + { + BucketNames.insert(BucketName); + } + for (const std::string& BucketName : Info.MemoryLayerInfo.BucketNames) + { + BucketNames.insert(BucketName); + } + Info.BucketNames.insert(Info.BucketNames.end(), BucketNames.begin(), BucketNames.end()); + return Info; +} + +std::optional<ZenCacheNamespace::BucketInfo> +ZenCacheNamespace::GetBucketInfo(std::string_view Bucket) const +{ + std::optional<ZenCacheDiskLayer::BucketInfo> DiskBucketInfo = m_DiskLayer.GetBucketInfo(Bucket); + if 
(!DiskBucketInfo.has_value()) + { + return {}; + } + ZenCacheNamespace::BucketInfo Info = {.DiskLayerInfo = *DiskBucketInfo, + .MemoryLayerInfo = m_MemLayer.GetBucketInfo(Bucket).value_or(ZenCacheMemoryLayer::BucketInfo{})}; + return Info; +} + +CacheValueDetails::NamespaceDetails +ZenCacheNamespace::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const +{ + return m_DiskLayer.GetValueDetails(BucketFilter, ValueFilter); +} + +////////////////////////////////////////////////////////////////////////// + +ZenCacheMemoryLayer::ZenCacheMemoryLayer() +{ +} + +ZenCacheMemoryLayer::~ZenCacheMemoryLayer() +{ +} + +bool +ZenCacheMemoryLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + RwLock::SharedLockScope _(m_Lock); + + auto It = m_Buckets.find(std::string(InBucket)); + + if (It == m_Buckets.end()) + { + return false; + } + + CacheBucket* Bucket = It->second.get(); + + _.ReleaseNow(); + + // There's a race here. Since the lock is released early to allow + // inserts, the bucket delete path could end up deleting the + // underlying data structure + + return Bucket->Get(HashKey, OutValue); +} + +void +ZenCacheMemoryLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + const auto BucketName = std::string(InBucket); + CacheBucket* Bucket = nullptr; + + { + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + } + + if (Bucket == nullptr) + { + // New bucket + + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + else + { + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>()); + Bucket = InsertResult.first->second.get(); + } + } + + // Note that since the underlying IoBuffer is retained, the content type is also + + 
Bucket->Put(HashKey, Value); +} + +bool +ZenCacheMemoryLayer::DropBucket(std::string_view InBucket) +{ + RwLock::ExclusiveLockScope _(m_Lock); + + auto It = m_Buckets.find(std::string(InBucket)); + + if (It != m_Buckets.end()) + { + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It); + Bucket.Drop(); + return true; + } + return false; +} + +void +ZenCacheMemoryLayer::Drop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + std::vector<std::unique_ptr<CacheBucket>> Buckets; + Buckets.reserve(m_Buckets.size()); + while (!m_Buckets.empty()) + { + const auto& It = m_Buckets.begin(); + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It->first); + Bucket.Drop(); + } +} + +void +ZenCacheMemoryLayer::Scrub(ScrubContext& Ctx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + Kv.second->Scrub(Ctx); + } +} + +void +ZenCacheMemoryLayer::GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes) +{ + using namespace zen::access_tracking; + + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + std::vector<KeyAccessTime>& Bucket = AccessTimes.Buckets[Kv.first]; + Kv.second->GatherAccessTimes(Bucket); + } +} + +void +ZenCacheMemoryLayer::Reset() +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Buckets.clear(); +} + +uint64_t +ZenCacheMemoryLayer::TotalSize() const +{ + uint64_t TotalSize{}; + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + TotalSize += Kv.second->TotalSize(); + } + + return TotalSize; +} + +ZenCacheMemoryLayer::Info +ZenCacheMemoryLayer::GetInfo() const +{ + ZenCacheMemoryLayer::Info Info = {.Config = m_Configuration, .TotalSize = TotalSize()}; + + RwLock::SharedLockScope _(m_Lock); + Info.BucketNames.reserve(m_Buckets.size()); + for (auto& Kv : m_Buckets) + { + Info.BucketNames.push_back(Kv.first); + Info.EntryCount += Kv.second->EntryCount(); + } + return Info; +} + 
+std::optional<ZenCacheMemoryLayer::BucketInfo> +ZenCacheMemoryLayer::GetBucketInfo(std::string_view Bucket) const +{ + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end()) + { + return ZenCacheMemoryLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()}; + } + return {}; +} + +void +ZenCacheMemoryLayer::CacheBucket::Scrub(ScrubContext& Ctx) +{ + RwLock::SharedLockScope _(m_BucketLock); + + std::vector<IoHash> BadHashes; + + auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) { + if (ContentType == ZenContentType::kCbObject) + { + CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); + return Error == CbValidateError::None; + } + if (ContentType == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize)) + { + return false; + } + if (Hash != RawHash) + { + return false; + } + } + return true; + }; + + for (auto& Kv : m_CacheMap) + { + const BucketPayload& Payload = m_Payloads[Kv.second]; + if (!ValidateEntry(Kv.first, Payload.Payload.GetContentType(), Payload.Payload)) + { + BadHashes.push_back(Kv.first); + } + } + + if (!BadHashes.empty()) + { + Ctx.ReportBadCidChunks(BadHashes); + } +} + +void +ZenCacheMemoryLayer::CacheBucket::GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes) +{ + RwLock::SharedLockScope _(m_BucketLock); + std::transform(m_CacheMap.begin(), m_CacheMap.end(), std::back_inserter(AccessTimes), [this](const auto& Kv) { + return access_tracking::KeyAccessTime{.Key = Kv.first, .LastAccess = m_AccessTimes[Kv.second]}; + }); +} + +bool +ZenCacheMemoryLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue) +{ + RwLock::SharedLockScope _(m_BucketLock); + + if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end()) + { + uint32_t EntryIndex = 
It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size()); + ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size()); + + const BucketPayload& Payload = m_Payloads[EntryIndex]; + OutValue = {.Value = Payload.Payload, .RawSize = Payload.RawSize, .RawHash = Payload.RawHash}; + m_AccessTimes[EntryIndex] = GcClock::TickCount(); + + return true; + } + + return false; +} + +void +ZenCacheMemoryLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value) +{ + size_t PayloadSize = Value.Value.GetSize(); + { + GcClock::Tick AccessTime = GcClock::TickCount(); + RwLock::ExclusiveLockScope _(m_BucketLock); + if (m_CacheMap.size() == std::numeric_limits<uint32_t>::max()) + { + // No more space in our memory cache! + return; + } + if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end()) + { + uint32_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size()); + + m_TotalSize.fetch_sub(PayloadSize, std::memory_order::relaxed); + BucketPayload& Payload = m_Payloads[EntryIndex]; + Payload.Payload = Value.Value; + Payload.RawHash = Value.RawHash; + Payload.RawSize = gsl::narrow<uint32_t>(Value.RawSize); + m_AccessTimes[EntryIndex] = AccessTime; + } + else + { + uint32_t EntryIndex = gsl::narrow<uint32_t>(m_Payloads.size()); + m_Payloads.emplace_back( + BucketPayload{.Payload = Value.Value, .RawSize = gsl::narrow<uint32_t>(Value.RawSize), .RawHash = Value.RawHash}); + m_AccessTimes.emplace_back(AccessTime); + m_CacheMap.insert_or_assign(HashKey, EntryIndex); + } + ZEN_ASSERT_SLOW(m_Payloads.size() == m_CacheMap.size()); + ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size()); + } + + m_TotalSize.fetch_add(PayloadSize, std::memory_order::relaxed); +} + +void +ZenCacheMemoryLayer::CacheBucket::Drop() +{ + RwLock::ExclusiveLockScope _(m_BucketLock); + m_CacheMap.clear(); + m_AccessTimes.clear(); + m_Payloads.clear(); + m_TotalSize.store(0); +} + +uint64_t +ZenCacheMemoryLayer::CacheBucket::EntryCount() const +{ + RwLock::SharedLockScope 
_(m_BucketLock); + return static_cast<uint64_t>(m_CacheMap.size()); +} + +////////////////////////////////////////////////////////////////////////// + +ZenCacheDiskLayer::CacheBucket::CacheBucket(std::string BucketName) : m_BucketName(std::move(BucketName)), m_BucketId(Oid::Zero) +{ +} + +ZenCacheDiskLayer::CacheBucket::~CacheBucket() +{ +} + +bool +ZenCacheDiskLayer::CacheBucket::OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate) +{ + using namespace std::literals; + + m_BlocksBasePath = BucketDir / "blocks"; + m_BucketDir = BucketDir; + + CreateDirectories(m_BucketDir); + + std::filesystem::path ManifestPath{m_BucketDir / "zen_manifest"}; + + bool IsNew = false; + + CbObject Manifest = LoadCompactBinaryObject(ManifestPath); + + if (Manifest) + { + m_BucketId = Manifest["BucketId"sv].AsObjectId(); + if (m_BucketId == Oid::Zero) + { + return false; + } + } + else if (AllowCreate) + { + m_BucketId.Generate(); + + CbObjectWriter Writer; + Writer << "BucketId"sv << m_BucketId; + Manifest = Writer.Save(); + SaveCompactBinaryObject(ManifestPath, Manifest); + IsNew = true; + } + else + { + return false; + } + + OpenLog(IsNew); + + if (!IsNew) + { + Stopwatch Timer; + const auto _ = + MakeGuard([&] { ZEN_INFO("read store manifest '{}' in {}", ManifestPath, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); }); + + for (CbFieldView Entry : Manifest["Timestamps"sv]) + { + const CbObjectView Obj = Entry.AsObjectView(); + const IoHash Key = Obj["Key"sv].AsHash(); + + if (auto It = m_Index.find(Key); It != m_Index.end()) + { + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_AccessTimes[EntryIndex] = Obj["LastAccess"sv].AsInt64(); + } + } + for (CbFieldView Entry : Manifest["RawInfo"sv]) + { + const CbObjectView Obj = Entry.AsObjectView(); + const IoHash Key = Obj["Key"sv].AsHash(); + if (auto It = m_Index.find(Key); It != m_Index.end()) + { + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size()); + 
m_Payloads[EntryIndex].RawHash = Obj["RawHash"sv].AsHash(); + m_Payloads[EntryIndex].RawSize = Obj["RawSize"sv].AsUInt64(); + } + } + } + + return true; +} + +void +ZenCacheDiskLayer::CacheBucket::MakeIndexSnapshot() +{ + uint64_t LogCount = m_SlogFile.GetLogCount(); + if (m_LogFlushPosition == LogCount) + { + return; + } + + ZEN_DEBUG("write store snapshot for '{}'", m_BucketDir / m_BucketName); + uint64_t EntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}", + m_BucketDir / m_BucketName, + EntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + namespace fs = std::filesystem; + + fs::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName); + fs::path STmpIndexPath = GetTempIndexPath(m_BucketDir, m_BucketName); + + // Move index away, we keep it if something goes wrong + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(STmpIndexPath); + } + if (fs::is_regular_file(IndexPath)) + { + fs::rename(IndexPath, STmpIndexPath); + } + + try + { + // Write the current state of the location map to a new index state + std::vector<DiskIndexEntry> Entries; + + { + Entries.resize(m_Index.size()); + + uint64_t EntryIndex = 0; + for (auto& Entry : m_Index) + { + DiskIndexEntry& IndexEntry = Entries[EntryIndex++]; + IndexEntry.Key = Entry.first; + IndexEntry.Location = m_Payloads[Entry.second].Location; + } + } + + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate); + CacheBucketIndexHeader Header = {.EntryCount = Entries.size(), + .LogPosition = LogCount, + .PayloadAlignment = gsl::narrow<uint32_t>(m_PayloadAlignment)}; + + Header.Checksum = CacheBucketIndexHeader::ComputeChecksum(Header); + + ObjectIndexFile.Write(&Header, sizeof(CacheBucketIndexHeader), 0); + ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(DiskIndexEntry), sizeof(CacheBucketIndexHeader)); + ObjectIndexFile.Flush(); + ObjectIndexFile.Close(); + EntryCount = 
Entries.size(); + m_LogFlushPosition = LogCount; + } + catch (std::exception& Err) + { + ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what()); + + // Restore any previous snapshot + + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(IndexPath); + fs::rename(STmpIndexPath, IndexPath); + } + } + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(STmpIndexPath); + } +} + +uint64_t +ZenCacheDiskLayer::CacheBucket::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion) +{ + if (std::filesystem::is_regular_file(IndexPath)) + { + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead); + uint64_t Size = ObjectIndexFile.FileSize(); + if (Size >= sizeof(CacheBucketIndexHeader)) + { + CacheBucketIndexHeader Header; + ObjectIndexFile.Read(&Header, sizeof(Header), 0); + if ((Header.Magic == CacheBucketIndexHeader::ExpectedMagic) && + (Header.Checksum == CacheBucketIndexHeader::ComputeChecksum(Header)) && (Header.PayloadAlignment > 0)) + { + switch (Header.Version) + { + case CacheBucketIndexHeader::Version2: + { + uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CacheBucketIndexHeader))) / sizeof(DiskIndexEntry); + if (Header.EntryCount > ExpectedEntryCount) + { + break; + } + size_t EntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' index containing {} entries in {}", + IndexPath, + EntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + m_PayloadAlignment = Header.PayloadAlignment; + + std::vector<DiskIndexEntry> Entries; + Entries.resize(Header.EntryCount); + ObjectIndexFile.Read(Entries.data(), + Header.EntryCount * sizeof(DiskIndexEntry), + sizeof(CacheBucketIndexHeader)); + + m_Payloads.reserve(Header.EntryCount); + m_AccessTimes.reserve(Header.EntryCount); + m_Index.reserve(Header.EntryCount); + + std::string InvalidEntryReason; + for (const DiskIndexEntry& Entry : Entries) + { + if (!ValidateEntry(Entry, InvalidEntryReason)) + { + 
ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason); + continue; + } + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Entry.Location, .RawSize = 0, .RawHash = IoHash::Zero}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(Entry.Key, EntryIndex); + EntryCount++; + } + OutVersion = CacheBucketIndexHeader::Version2; + return Header.LogPosition; + } + break; + default: + break; + } + } + } + ZEN_WARN("skipping invalid index file '{}'", IndexPath); + } + return 0; +} + +uint64_t +ZenCacheDiskLayer::CacheBucket::ReadLog(const std::filesystem::path& LogPath, uint64_t SkipEntryCount) +{ + if (std::filesystem::is_regular_file(LogPath)) + { + uint64_t LogEntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' log containing {} entries in {}", LogPath, LogEntryCount, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + TCasLogFile<DiskIndexEntry> CasLog; + CasLog.Open(LogPath, CasLogFile::Mode::kRead); + if (CasLog.Initialize()) + { + uint64_t EntryCount = CasLog.GetLogCount(); + if (EntryCount < SkipEntryCount) + { + ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath); + SkipEntryCount = 0; + } + LogEntryCount = EntryCount - SkipEntryCount; + m_Index.reserve(LogEntryCount); + uint64_t InvalidEntryCount = 0; + CasLog.Replay( + [&](const DiskIndexEntry& Record) { + std::string InvalidEntryReason; + if (Record.Location.Flags & DiskLocation::kTombStone) + { + m_Index.erase(Record.Key); + return; + } + if (!ValidateEntry(Record, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason); + ++InvalidEntryCount; + return; + } + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Record.Location, .RawSize = 0u, .RawHash = IoHash::Zero}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + 
m_Index.insert_or_assign(Record.Key, EntryIndex); + }, + SkipEntryCount); + if (InvalidEntryCount) + { + ZEN_WARN("found {} invalid entries in '{}'", InvalidEntryCount, m_BucketDir / m_BucketName); + } + return LogEntryCount; + } + } + return 0; +}; + +void +ZenCacheDiskLayer::CacheBucket::OpenLog(const bool IsNew) +{ + m_TotalStandaloneSize = 0; + + m_Index.clear(); + m_Payloads.clear(); + m_AccessTimes.clear(); + + std::filesystem::path LogPath = GetLogPath(m_BucketDir, m_BucketName); + std::filesystem::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName); + + if (IsNew) + { + fs::remove(LogPath); + fs::remove(IndexPath); + fs::remove_all(m_BlocksBasePath); + } + + uint64_t LogEntryCount = 0; + { + uint32_t IndexVersion = 0; + m_LogFlushPosition = ReadIndexFile(IndexPath, IndexVersion); + if (IndexVersion == 0 && std::filesystem::is_regular_file(IndexPath)) + { + ZEN_WARN("removing invalid index file at '{}'", IndexPath); + fs::remove(IndexPath); + } + + if (TCasLogFile<DiskIndexEntry>::IsValid(LogPath)) + { + LogEntryCount = ReadLog(LogPath, m_LogFlushPosition); + } + else + { + ZEN_WARN("removing invalid cas log at '{}'", LogPath); + fs::remove(LogPath); + } + } + + CreateDirectories(m_BucketDir); + + m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite); + + std::vector<BlockStoreLocation> KnownLocations; + KnownLocations.reserve(m_Index.size()); + for (const auto& Entry : m_Index) + { + size_t EntryIndex = Entry.second; + const BucketPayload& Payload = m_Payloads[EntryIndex]; + const DiskLocation& Location = Payload.Location; + + if (Location.IsFlagSet(DiskLocation::kStandaloneFile)) + { + m_TotalStandaloneSize.fetch_add(Location.Size(), std::memory_order::relaxed); + continue; + } + const BlockStoreLocation& BlockLocation = Location.GetBlockLocation(m_PayloadAlignment); + KnownLocations.push_back(BlockLocation); + } + + m_BlockStore.Initialize(m_BlocksBasePath, MaxBlockSize, BlockStoreDiskLocation::MaxBlockIndex + 1, KnownLocations); + if (IsNew || 
LogEntryCount > 0) + { + MakeIndexSnapshot(); + } + // TODO: should validate integrity of container files here +} + +void +ZenCacheDiskLayer::CacheBucket::BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const +{ + char HexString[sizeof(HashKey.Hash) * 2]; + ToHexBytes(HashKey.Hash, sizeof HashKey.Hash, HexString); + + Path.Append(m_BucketDir); + Path.AppendSeparator(); + Path.Append(L"blob"); + Path.AppendSeparator(); + Path.AppendAsciiRange(HexString, HexString + 3); + Path.AppendSeparator(); + Path.AppendAsciiRange(HexString + 3, HexString + 5); + Path.AppendSeparator(); + Path.AppendAsciiRange(HexString + 5, HexString + sizeof(HexString)); +} + +IoBuffer +ZenCacheDiskLayer::CacheBucket::GetInlineCacheValue(const DiskLocation& Loc) const +{ + BlockStoreLocation Location = Loc.GetBlockLocation(m_PayloadAlignment); + + IoBuffer Value = m_BlockStore.TryGetChunk(Location); + if (Value) + { + Value.SetContentType(Loc.GetContentType()); + } + + return Value; +} + +IoBuffer +ZenCacheDiskLayer::CacheBucket::GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const +{ + ExtendablePathBuilder<256> DataFilePath; + BuildPath(DataFilePath, HashKey); + + RwLock::SharedLockScope ValueLock(LockForHash(HashKey)); + + if (IoBuffer Data = IoBufferBuilder::MakeFromFile(DataFilePath.ToPath())) + { + Data.SetContentType(Loc.GetContentType()); + + return Data; + } + + return {}; +} + +bool +ZenCacheDiskLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue) +{ + RwLock::SharedLockScope _(m_IndexLock); + auto It = m_Index.find(HashKey); + if (It == m_Index.end()) + { + return false; + } + size_t EntryIndex = It.value(); + const BucketPayload& Payload = m_Payloads[EntryIndex]; + m_AccessTimes[EntryIndex] = GcClock::TickCount(); + DiskLocation Location = Payload.Location; + OutValue.RawSize = Payload.RawSize; + OutValue.RawHash = Payload.RawHash; + if (Location.IsFlagSet(DiskLocation::kStandaloneFile)) + { + // We don't need to hold the index 
lock when we read a standalone file + _.ReleaseNow(); + OutValue.Value = GetStandaloneCacheValue(Location, HashKey); + } + else + { + OutValue.Value = GetInlineCacheValue(Location); + } + _.ReleaseNow(); + + if (!Location.IsFlagSet(DiskLocation::kStructured)) + { + if (OutValue.RawHash == IoHash::Zero && OutValue.RawSize == 0 && OutValue.Value.GetSize() > 0) + { + if (Location.IsFlagSet(DiskLocation::kCompressed)) + { + (void)CompressedBuffer::FromCompressed(SharedBuffer(OutValue.Value), OutValue.RawHash, OutValue.RawSize); + } + else + { + OutValue.RawHash = IoHash::HashBuffer(OutValue.Value); + OutValue.RawSize = OutValue.Value.GetSize(); + } + RwLock::ExclusiveLockScope __(m_IndexLock); + if (auto WriteIt = m_Index.find(HashKey); WriteIt != m_Index.end()) + { + BucketPayload& WritePayload = m_Payloads[WriteIt.value()]; + WritePayload.RawHash = OutValue.RawHash; + WritePayload.RawSize = OutValue.RawSize; + + m_LogFlushPosition = 0; // Force resave of index on exit + } + } + } + + return (bool)OutValue.Value; +} + +void +ZenCacheDiskLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value) +{ + if (Value.Value.Size() >= m_LargeObjectThreshold) + { + return PutStandaloneCacheValue(HashKey, Value); + } + PutInlineCacheValue(HashKey, Value); +} + +bool +ZenCacheDiskLayer::CacheBucket::Drop() +{ + RwLock::ExclusiveLockScope _(m_IndexLock); + + std::vector<std::unique_ptr<RwLock::ExclusiveLockScope>> ShardLocks; + ShardLocks.reserve(256); + for (RwLock& Lock : m_ShardedLocks) + { + ShardLocks.push_back(std::make_unique<RwLock::ExclusiveLockScope>(Lock)); + } + m_BlockStore.Close(); + m_SlogFile.Close(); + + bool Deleted = MoveAndDeleteDirectory(m_BucketDir); + + m_Index.clear(); + m_Payloads.clear(); + m_AccessTimes.clear(); + return Deleted; +} + +void +ZenCacheDiskLayer::CacheBucket::Flush() +{ + m_BlockStore.Flush(); + + RwLock::SharedLockScope _(m_IndexLock); + m_SlogFile.Flush(); + MakeIndexSnapshot(); + SaveManifest(); +} + +void 
+ZenCacheDiskLayer::CacheBucket::SaveManifest() +{ + using namespace std::literals; + + CbObjectWriter Writer; + Writer << "BucketId"sv << m_BucketId; + + if (!m_Index.empty()) + { + Writer.BeginArray("Timestamps"sv); + for (auto& Kv : m_Index) + { + const IoHash& Key = Kv.first; + GcClock::Tick AccessTime = m_AccessTimes[Kv.second]; + + Writer.BeginObject(); + Writer << "Key"sv << Key; + Writer << "LastAccess"sv << AccessTime; + Writer.EndObject(); + } + Writer.EndArray(); + + Writer.BeginArray("RawInfo"sv); + { + for (auto& Kv : m_Index) + { + const IoHash& Key = Kv.first; + const BucketPayload& Payload = m_Payloads[Kv.second]; + if (Payload.RawHash != IoHash::Zero) + { + Writer.BeginObject(); + Writer << "Key"sv << Key; + Writer << "RawHash"sv << Payload.RawHash; + Writer << "RawSize"sv << Payload.RawSize; + Writer.EndObject(); + } + } + } + Writer.EndArray(); + } + + SaveCompactBinaryObject(m_BucketDir / "zen_manifest", Writer.Save()); +} + +void +ZenCacheDiskLayer::CacheBucket::Scrub(ScrubContext& Ctx) +{ + std::vector<IoHash> BadKeys; + uint64_t ChunkCount{0}, ChunkBytes{0}; + std::vector<BlockStoreLocation> ChunkLocations; + std::vector<IoHash> ChunkIndexToChunkHash; + + auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) { + if (ContentType == ZenContentType::kCbObject) + { + CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); + return Error == CbValidateError::None; + } + if (ContentType == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize)) + { + return false; + } + if (RawHash != Hash) + { + return false; + } + } + return true; + }; + + RwLock::SharedLockScope _(m_IndexLock); + + const size_t BlockChunkInitialCount = m_Index.size() / 4; + ChunkLocations.reserve(BlockChunkInitialCount); + ChunkIndexToChunkHash.reserve(BlockChunkInitialCount); + + for (auto& Kv : m_Index) + { + const IoHash& 
HashKey = Kv.first; + const BucketPayload& Payload = m_Payloads[Kv.second]; + const DiskLocation& Loc = Payload.Location; + + if (Loc.IsFlagSet(DiskLocation::kStandaloneFile)) + { + ++ChunkCount; + ChunkBytes += Loc.Size(); + if (Loc.GetContentType() == ZenContentType::kBinary) + { + ExtendablePathBuilder<256> DataFilePath; + BuildPath(DataFilePath, HashKey); + + RwLock::SharedLockScope ValueLock(LockForHash(HashKey)); + + std::error_code Ec; + uintmax_t size = std::filesystem::file_size(DataFilePath.ToPath(), Ec); + if (Ec) + { + BadKeys.push_back(HashKey); + } + if (size != Loc.Size()) + { + BadKeys.push_back(HashKey); + } + continue; + } + IoBuffer Buffer = GetStandaloneCacheValue(Loc, HashKey); + if (!Buffer) + { + BadKeys.push_back(HashKey); + continue; + } + if (!ValidateEntry(HashKey, Loc.GetContentType(), Buffer)) + { + BadKeys.push_back(HashKey); + continue; + } + } + else + { + ChunkLocations.emplace_back(Loc.GetBlockLocation(m_PayloadAlignment)); + ChunkIndexToChunkHash.push_back(HashKey); + continue; + } + } + + const auto ValidateSmallChunk = [&](size_t ChunkIndex, const void* Data, uint64_t Size) { + ++ChunkCount; + ChunkBytes += Size; + const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex]; + if (!Data) + { + // ChunkLocation out of range of stored blocks + BadKeys.push_back(Hash); + return; + } + IoBuffer Buffer(IoBuffer::Wrap, Data, Size); + if (!Buffer) + { + BadKeys.push_back(Hash); + return; + } + const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)]; + ZenContentType ContentType = Payload.Location.GetContentType(); + if (!ValidateEntry(Hash, ContentType, Buffer)) + { + BadKeys.push_back(Hash); + return; + } + }; + + const auto ValidateLargeChunk = [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) { + ++ChunkCount; + ChunkBytes += Size; + const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex]; + // TODO: Add API to verify compressed buffer and possible structure data without having to memorymap the whole file + 
IoBuffer Buffer(IoBuffer::BorrowedFile, File.GetBasicFile().Handle(), Offset, Size); + if (!Buffer) + { + BadKeys.push_back(Hash); + return; + } + const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)]; + ZenContentType ContentType = Payload.Location.GetContentType(); + if (!ValidateEntry(Hash, ContentType, Buffer)) + { + BadKeys.push_back(Hash); + return; + } + }; + + m_BlockStore.IterateChunks(ChunkLocations, ValidateSmallChunk, ValidateLargeChunk); + + _.ReleaseNow(); + + Ctx.ReportScrubbed(ChunkCount, ChunkBytes); + + if (!BadKeys.empty()) + { + ZEN_WARN("Scrubbing found {} bad chunks in '{}'", BadKeys.size(), m_BucketDir / m_BucketName); + + if (Ctx.RunRecovery()) + { + // Deal with bad chunks by removing them from our lookup map + + std::vector<DiskIndexEntry> LogEntries; + LogEntries.reserve(BadKeys.size()); + + { + RwLock::ExclusiveLockScope __(m_IndexLock); + for (const IoHash& BadKey : BadKeys) + { + // Log a tombstone and delete the in-memory index for the bad entry + const auto It = m_Index.find(BadKey); + const BucketPayload& Payload = m_Payloads[It->second]; + DiskLocation Location = Payload.Location; + Location.Flags |= DiskLocation::kTombStone; + LogEntries.push_back(DiskIndexEntry{.Key = BadKey, .Location = Location}); + m_Index.erase(BadKey); + } + } + for (const DiskIndexEntry& Entry : LogEntries) + { + if (Entry.Location.IsFlagSet(DiskLocation::kStandaloneFile)) + { + ExtendablePathBuilder<256> Path; + BuildPath(Path, Entry.Key); + fs::path FilePath = Path.ToPath(); + RwLock::ExclusiveLockScope ValueLock(LockForHash(Entry.Key)); + if (fs::is_regular_file(FilePath)) + { + ZEN_DEBUG("deleting bad standalone cache file '{}'", Path.ToUtf8()); + std::error_code Ec; + fs::remove(FilePath, Ec); // We don't care if we fail, we are no longer tracking this file... 
+ } + m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed); + } + } + m_SlogFile.Append(LogEntries); + + // Clean up m_AccessTimes and m_Payloads vectors + { + std::vector<BucketPayload> Payloads; + std::vector<AccessTime> AccessTimes; + IndexMap Index; + + { + RwLock::ExclusiveLockScope __(m_IndexLock); + size_t EntryCount = m_Index.size(); + Payloads.reserve(EntryCount); + AccessTimes.reserve(EntryCount); + Index.reserve(EntryCount); + for (auto It : m_Index) + { + size_t EntryIndex = Payloads.size(); + Payloads.push_back(m_Payloads[EntryIndex]); + AccessTimes.push_back(m_AccessTimes[EntryIndex]); + Index.insert({It.first, EntryIndex}); + } + m_Index.swap(Index); + m_Payloads.swap(Payloads); + m_AccessTimes.swap(AccessTimes); + } + } + } + } + + // Let whomever it concerns know about the bad chunks. This could + // be used to invalidate higher level data structures more efficiently + // than a full validation pass might be able to do + Ctx.ReportBadCidChunks(BadKeys); + + ZEN_INFO("cache bucket scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes)); +} + +void +ZenCacheDiskLayer::CacheBucket::GatherReferences(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::GatherReferences"); + + uint64_t WriteBlockTimeUs = 0; + uint64_t WriteBlockLongestTimeUs = 0; + uint64_t ReadBlockTimeUs = 0; + uint64_t ReadBlockLongestTimeUs = 0; + + Stopwatch TotalTimer; + const auto _ = MakeGuard([&] { + ZEN_DEBUG("gathered references from '{}' in {} write lock: {} ({}), read lock: {} ({})", + m_BucketDir / m_BucketName, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + NiceLatencyNs(WriteBlockTimeUs), + NiceLatencyNs(WriteBlockLongestTimeUs), + NiceLatencyNs(ReadBlockTimeUs), + NiceLatencyNs(ReadBlockLongestTimeUs)); + }); + + const GcClock::TimePoint ExpireTime = GcCtx.ExpireTime(); + + const GcClock::Tick ExpireTicks = ExpireTime.time_since_epoch().count(); + + IndexMap Index; + std::vector<AccessTime> AccessTimes; + 
std::vector<BucketPayload> Payloads; + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + Index = m_Index; + AccessTimes = m_AccessTimes; + Payloads = m_Payloads; + } + + std::vector<IoHash> ExpiredKeys; + ExpiredKeys.reserve(1024); + + std::vector<IoHash> Cids; + Cids.reserve(1024); + + for (const auto& Entry : Index) + { + const IoHash& Key = Entry.first; + GcClock::Tick AccessTime = AccessTimes[Entry.second]; + if (AccessTime < ExpireTicks) + { + ExpiredKeys.push_back(Key); + continue; + } + + const DiskLocation& Loc = Payloads[Entry.second].Location; + + if (Loc.IsFlagSet(DiskLocation::kStructured)) + { + if (Cids.size() > 1024) + { + GcCtx.AddRetainedCids(Cids); + Cids.clear(); + } + + IoBuffer Buffer; + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + if (Loc.IsFlagSet(DiskLocation::kStandaloneFile)) + { + // We don't need to hold the index lock when we read a standalone file + __.ReleaseNow(); + if (Buffer = GetStandaloneCacheValue(Loc, Key); !Buffer) + { + continue; + } + } + else if (Buffer = GetInlineCacheValue(Loc); !Buffer) + { + continue; + } + } + + ZEN_ASSERT(Buffer); + ZEN_ASSERT(Buffer.GetContentType() == ZenContentType::kCbObject); + CbObject Obj(SharedBuffer{Buffer}); + Obj.IterateAttachments([&Cids](CbFieldView Field) { Cids.push_back(Field.AsAttachment()); }); + } + } + + GcCtx.AddRetainedCids(Cids); + GcCtx.SetExpiredCacheKeys(m_BucketDir.string(), std::move(ExpiredKeys)); +} + +void +ZenCacheDiskLayer::CacheBucket::CollectGarbage(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::CollectGarbage"); + + 
ZEN_DEBUG("collecting garbage from '{}'", m_BucketDir / m_BucketName); + + Stopwatch TotalTimer; + uint64_t WriteBlockTimeUs = 0; + uint64_t WriteBlockLongestTimeUs = 0; + uint64_t ReadBlockTimeUs = 0; + uint64_t ReadBlockLongestTimeUs = 0; + uint64_t TotalChunkCount = 0; + uint64_t DeletedSize = 0; + uint64_t OldTotalSize = TotalSize(); + + std::unordered_set<IoHash> DeletedChunks; + uint64_t MovedCount = 0; + + const auto _ = MakeGuard([&] { + ZEN_DEBUG( + "garbage collect from '{}' DONE after {}, write lock: {} ({}), read lock: {} ({}), collected {} bytes, deleted {} and moved " + "{} " + "of {} " + "entires ({}).", + m_BucketDir / m_BucketName, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + NiceLatencyNs(WriteBlockTimeUs), + NiceLatencyNs(WriteBlockLongestTimeUs), + NiceLatencyNs(ReadBlockTimeUs), + NiceLatencyNs(ReadBlockLongestTimeUs), + NiceBytes(DeletedSize), + DeletedChunks.size(), + MovedCount, + TotalChunkCount, + NiceBytes(OldTotalSize)); + RwLock::SharedLockScope _(m_IndexLock); + SaveManifest(); + }); + + m_SlogFile.Flush(); + + std::span<const IoHash> ExpiredCacheKeys = GcCtx.ExpiredCacheKeys(m_BucketDir.string()); + std::vector<IoHash> DeleteCacheKeys; + DeleteCacheKeys.reserve(ExpiredCacheKeys.size()); + GcCtx.FilterCids(ExpiredCacheKeys, [&](const IoHash& ChunkHash, bool Keep) { + if (Keep) + { + return; + } + DeleteCacheKeys.push_back(ChunkHash); + }); + if (DeleteCacheKeys.empty()) + { + ZEN_DEBUG("garbage collect SKIPPED, for '{}', no expired cache keys found", m_BucketDir / m_BucketName); + return; + } + + auto __ = MakeGuard([&]() { + if (!DeletedChunks.empty()) + { + // Clean up m_AccessTimes and m_Payloads vectors + std::vector<BucketPayload> Payloads; + std::vector<AccessTime> AccessTimes; + IndexMap Index; + + { + RwLock::ExclusiveLockScope _(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = 
std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + size_t EntryCount = m_Index.size(); + Payloads.reserve(EntryCount); + AccessTimes.reserve(EntryCount); + Index.reserve(EntryCount); + for (auto It : m_Index) + { + size_t EntryIndex = Payloads.size(); + Payloads.push_back(m_Payloads[EntryIndex]); + AccessTimes.push_back(m_AccessTimes[EntryIndex]); + Index.insert({It.first, EntryIndex}); + } + m_Index.swap(Index); + m_Payloads.swap(Payloads); + m_AccessTimes.swap(AccessTimes); + } + GcCtx.AddDeletedCids(std::vector<IoHash>(DeletedChunks.begin(), DeletedChunks.end())); + } + }); + + std::vector<DiskIndexEntry> ExpiredStandaloneEntries; + IndexMap Index; + BlockStore::ReclaimSnapshotState BlockStoreState; + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + if (m_Index.empty()) + { + ZEN_DEBUG("garbage collect SKIPPED, for '{}', container is empty", m_BucketDir / m_BucketName); + return; + } + BlockStoreState = m_BlockStore.GetReclaimSnapshotState(); + + SaveManifest(); + Index = m_Index; + + for (const IoHash& Key : DeleteCacheKeys) + { + if (auto It = Index.find(Key); It != Index.end()) + { + const BucketPayload& Payload = m_Payloads[It->second]; + DiskIndexEntry Entry = {.Key = It->first, .Location = Payload.Location}; + if (Entry.Location.Flags & DiskLocation::kStandaloneFile) + { + Entry.Location.Flags |= DiskLocation::kTombStone; + ExpiredStandaloneEntries.push_back(Entry); + } + } + } + if (GcCtx.IsDeletionMode()) + { + for (const auto& Entry : ExpiredStandaloneEntries) + { + m_Index.erase(Entry.Key); + m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed); + DeletedChunks.insert(Entry.Key); + } + m_SlogFile.Append(ExpiredStandaloneEntries); + } + } + + if (GcCtx.IsDeletionMode()) + { + std::error_code Ec; + 
ExtendablePathBuilder<256> Path; + + for (const auto& Entry : ExpiredStandaloneEntries) + { + const IoHash& Key = Entry.Key; + const DiskLocation& Loc = Entry.Location; + + Path.Reset(); + BuildPath(Path, Key); + fs::path FilePath = Path.ToPath(); + + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + if (m_Index.contains(Key)) + { + // Someone added it back, let the file on disk be + ZEN_DEBUG("skipping z$ delete standalone of file '{}' FAILED, it has been added back", Path.ToUtf8()); + continue; + } + __.ReleaseNow(); + + RwLock::ExclusiveLockScope ValueLock(LockForHash(Key)); + if (fs::is_regular_file(FilePath)) + { + ZEN_DEBUG("deleting standalone cache file '{}'", Path.ToUtf8()); + fs::remove(FilePath, Ec); + } + } + + if (Ec) + { + ZEN_WARN("delete expired z$ standalone file '{}' FAILED, reason: '{}'", Path.ToUtf8(), Ec.message()); + Ec.clear(); + DiskLocation RestoreLocation = Loc; + RestoreLocation.Flags &= ~DiskLocation::kTombStone; + + RwLock::ExclusiveLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + if (m_Index.contains(Key)) + { + continue; + } + m_SlogFile.Append(DiskIndexEntry{.Key = Key, .Location = RestoreLocation}); + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = RestoreLocation}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert({Key, EntryIndex}); + m_TotalStandaloneSize.fetch_add(RestoreLocation.Size(), std::memory_order::relaxed); + DeletedChunks.erase(Key); + continue; + } + DeletedSize += Entry.Location.Size(); + } + } + + TotalChunkCount = Index.size(); + + std::vector<IoHash> 
TotalChunkHashes; + TotalChunkHashes.reserve(TotalChunkCount); + for (const auto& Entry : Index) + { + const DiskLocation& Location = m_Payloads[Entry.second].Location; + + if (Location.Flags & DiskLocation::kStandaloneFile) + { + continue; + } + TotalChunkHashes.push_back(Entry.first); + } + + if (TotalChunkHashes.empty()) + { + return; + } + TotalChunkCount = TotalChunkHashes.size(); + + std::vector<BlockStoreLocation> ChunkLocations; + BlockStore::ChunkIndexArray KeepChunkIndexes; + std::vector<IoHash> ChunkIndexToChunkHash; + ChunkLocations.reserve(TotalChunkCount); + ChunkLocations.reserve(TotalChunkCount); + ChunkIndexToChunkHash.reserve(TotalChunkCount); + + GcCtx.FilterCids(TotalChunkHashes, [&](const IoHash& ChunkHash, bool Keep) { + auto KeyIt = Index.find(ChunkHash); + const DiskLocation& DiskLocation = m_Payloads[KeyIt->second].Location; + BlockStoreLocation Location = DiskLocation.GetBlockLocation(m_PayloadAlignment); + size_t ChunkIndex = ChunkLocations.size(); + ChunkLocations.push_back(Location); + ChunkIndexToChunkHash[ChunkIndex] = ChunkHash; + if (Keep) + { + KeepChunkIndexes.push_back(ChunkIndex); + } + }); + + size_t DeleteCount = TotalChunkCount - KeepChunkIndexes.size(); + + const bool PerformDelete = GcCtx.IsDeletionMode() && GcCtx.CollectSmallObjects(); + if (!PerformDelete) + { + m_BlockStore.ReclaimSpace(BlockStoreState, ChunkLocations, KeepChunkIndexes, m_PayloadAlignment, true); + uint64_t CurrentTotalSize = TotalSize(); + ZEN_DEBUG("garbage collect from '{}' DISABLED, found {} chunks of total {} {}", + m_BucketDir / m_BucketName, + DeleteCount, + TotalChunkCount, + NiceBytes(CurrentTotalSize)); + return; + } + + m_BlockStore.ReclaimSpace( + BlockStoreState, + ChunkLocations, + KeepChunkIndexes, + m_PayloadAlignment, + false, + [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& RemovedChunks) { + std::vector<DiskIndexEntry> LogEntries; + LogEntries.reserve(MovedChunks.size() + RemovedChunks.size()); 
+ for (const auto& Entry : MovedChunks) + { + size_t ChunkIndex = Entry.first; + const BlockStoreLocation& NewLocation = Entry.second; + const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex]; + const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]]; + const DiskLocation& OldDiskLocation = OldPayload.Location; + LogEntries.push_back( + {.Key = ChunkHash, .Location = DiskLocation(NewLocation, m_PayloadAlignment, OldDiskLocation.GetFlags())}); + } + for (const size_t ChunkIndex : RemovedChunks) + { + const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex]; + const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]]; + const DiskLocation& OldDiskLocation = OldPayload.Location; + LogEntries.push_back({.Key = ChunkHash, + .Location = DiskLocation(OldDiskLocation.GetBlockLocation(m_PayloadAlignment), + m_PayloadAlignment, + OldDiskLocation.GetFlags() | DiskLocation::kTombStone)}); + DeletedChunks.insert(ChunkHash); + } + + m_SlogFile.Append(LogEntries); + m_SlogFile.Flush(); + { + RwLock::ExclusiveLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + for (const DiskIndexEntry& Entry : LogEntries) + { + if (Entry.Location.GetFlags() & DiskLocation::kTombStone) + { + m_Index.erase(Entry.Key); + continue; + } + m_Payloads[m_Index[Entry.Key]].Location = Entry.Location; + } + } + }, + [&]() { return GcCtx.CollectSmallObjects(); }); +} + +void +ZenCacheDiskLayer::CacheBucket::UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes) +{ + using namespace access_tracking; + + for (const KeyAccessTime& KeyTime : AccessTimes) + { + if (auto It = m_Index.find(KeyTime.Key); It != m_Index.end()) + { + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_AccessTimes[EntryIndex] = KeyTime.LastAccess; + } + } +} + 
+uint64_t +ZenCacheDiskLayer::CacheBucket::EntryCount() const +{ + RwLock::SharedLockScope _(m_IndexLock); + return static_cast<uint64_t>(m_Index.size()); +} + +CacheValueDetails::ValueDetails +ZenCacheDiskLayer::CacheBucket::GetValueDetails(const IoHash& Key, size_t Index) const +{ + std::vector<IoHash> Attachments; + const BucketPayload& Payload = m_Payloads[Index]; + if (Payload.Location.IsFlagSet(DiskLocation::kStructured)) + { + IoBuffer Value = Payload.Location.IsFlagSet(DiskLocation::kStandaloneFile) ? GetStandaloneCacheValue(Payload.Location, Key) + : GetInlineCacheValue(Payload.Location); + CbObject Obj(SharedBuffer{Value}); + Obj.IterateAttachments([&Attachments](CbFieldView Field) { Attachments.emplace_back(Field.AsAttachment()); }); + } + return CacheValueDetails::ValueDetails{.Size = Payload.Location.Size(), + .RawSize = Payload.RawSize, + .RawHash = Payload.RawHash, + .LastAccess = m_AccessTimes[Index], + .Attachments = std::move(Attachments), + .ContentType = Payload.Location.GetContentType()}; +} + +CacheValueDetails::BucketDetails +ZenCacheDiskLayer::CacheBucket::GetValueDetails(const std::string_view ValueFilter) const +{ + CacheValueDetails::BucketDetails Details; + RwLock::SharedLockScope _(m_IndexLock); + if (ValueFilter.empty()) + { + Details.Values.reserve(m_Index.size()); + for (const auto& It : m_Index) + { + Details.Values.insert_or_assign(It.first, GetValueDetails(It.first, It.second)); + } + } + else + { + IoHash Key = IoHash::FromHexString(ValueFilter); + if (auto It = m_Index.find(Key); It != m_Index.end()) + { + Details.Values.insert_or_assign(It->first, GetValueDetails(It->first, It->second)); + } + } + return Details; +} + +void +ZenCacheDiskLayer::CollectGarbage(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + CacheBucket& Bucket = *Kv.second; + Bucket.CollectGarbage(GcCtx); + } +} + +void +ZenCacheDiskLayer::UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes) +{ + 
RwLock::SharedLockScope _(m_Lock); + + for (const auto& Kv : AccessTimes.Buckets) + { + if (auto It = m_Buckets.find(Kv.first); It != m_Buckets.end()) + { + CacheBucket& Bucket = *It->second; + Bucket.UpdateAccessTimes(Kv.second); + } + } +} + +void +ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value) +{ + uint64_t NewFileSize = Value.Value.Size(); + + TemporaryFile DataFile; + + std::error_code Ec; + DataFile.CreateTemporary(m_BucketDir.c_str(), Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to open temporary file for put in '{}'", m_BucketDir)); + } + + bool CleanUpTempFile = false; + auto __ = MakeGuard([&] { + if (CleanUpTempFile) + { + std::error_code Ec; + std::filesystem::remove(DataFile.GetPath(), Ec); + if (Ec) + { + ZEN_WARN("Failed to clean up temporary file '{}' for put in '{}', reason '{}'", + DataFile.GetPath(), + m_BucketDir, + Ec.message()); + } + } + }); + + DataFile.WriteAll(Value.Value, Ec); + if (Ec) + { + throw std::system_error(Ec, + fmt::format("Failed to write payload ({} bytes) to temporary file '{}' for put in '{}'", + NiceBytes(NewFileSize), + DataFile.GetPath().string(), + m_BucketDir)); + } + + ExtendablePathBuilder<256> DataFilePath; + BuildPath(DataFilePath, HashKey); + std::filesystem::path FsPath{DataFilePath.ToPath()}; + + RwLock::ExclusiveLockScope ValueLock(LockForHash(HashKey)); + + // We do a speculative remove of the file instead of probing with a exists call and check the error code instead + std::filesystem::remove(FsPath, Ec); + if (Ec) + { + if (Ec.value() != ENOENT) + { + ZEN_WARN("Failed to remove file '{}' for put in '{}', reason: '{}', retrying.", FsPath, m_BucketDir, Ec.message()); + Sleep(100); + Ec.clear(); + std::filesystem::remove(FsPath, Ec); + if (Ec && Ec.value() != ENOENT) + { + throw std::system_error(Ec, fmt::format("Failed to remove file '{}' for put in '{}'", FsPath, m_BucketDir)); + } + } + } + + 
DataFile.MoveTemporaryIntoPlace(FsPath, Ec); + if (Ec) + { + CreateDirectories(FsPath.parent_path()); + Ec.clear(); + + // Try again + DataFile.MoveTemporaryIntoPlace(FsPath, Ec); + if (Ec) + { + ZEN_WARN("Failed to finalize file '{}', moving from '{}' for put in '{}', reason: '{}', retrying.", + FsPath, + DataFile.GetPath(), + m_BucketDir, + Ec.message()); + Sleep(100); + Ec.clear(); + DataFile.MoveTemporaryIntoPlace(FsPath, Ec); + if (Ec) + { + throw std::system_error( + Ec, + fmt::format("Failed to finalize file '{}', moving from '{}' for put in '{}'", FsPath, DataFile.GetPath(), m_BucketDir)); + } + } + } + + // Once we have called MoveTemporaryIntoPlace automatic clean up the temp file + // will be disabled as the file handle has already been closed + CleanUpTempFile = false; + + uint8_t EntryFlags = DiskLocation::kStandaloneFile; + + if (Value.Value.GetContentType() == ZenContentType::kCbObject) + { + EntryFlags |= DiskLocation::kStructured; + } + else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary) + { + EntryFlags |= DiskLocation::kCompressed; + } + + DiskLocation Loc(NewFileSize, EntryFlags); + + RwLock::ExclusiveLockScope _(m_IndexLock); + if (auto It = m_Index.find(HashKey); It == m_Index.end()) + { + // Previously unknown object + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(HashKey, EntryIndex); + } + else + { + // TODO: should check if write is idempotent and bail out if it is? 
+ size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_Payloads[EntryIndex] = BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash}; + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_TotalStandaloneSize.fetch_sub(Loc.Size(), std::memory_order::relaxed); + } + + m_SlogFile.Append({.Key = HashKey, .Location = Loc}); + m_TotalStandaloneSize.fetch_add(NewFileSize, std::memory_order::relaxed); +} + +void +ZenCacheDiskLayer::CacheBucket::PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value) +{ + uint8_t EntryFlags = 0; + + if (Value.Value.GetContentType() == ZenContentType::kCbObject) + { + EntryFlags |= DiskLocation::kStructured; + } + else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary) + { + EntryFlags |= DiskLocation::kCompressed; + } + + m_BlockStore.WriteChunk(Value.Value.Data(), Value.Value.Size(), m_PayloadAlignment, [&](const BlockStoreLocation& BlockStoreLocation) { + DiskLocation Location(BlockStoreLocation, m_PayloadAlignment, EntryFlags); + m_SlogFile.Append({.Key = HashKey, .Location = Location}); + + RwLock::ExclusiveLockScope _(m_IndexLock); + if (auto It = m_Index.find(HashKey); It != m_Index.end()) + { + // TODO: should check if write is idempotent and bail out if it is? 
+ // this would requiring comparing contents on disk unless we add a + // content hash to the index entry + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_Payloads[EntryIndex] = (BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash}); + m_AccessTimes[EntryIndex] = GcClock::TickCount(); + } + else + { + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(HashKey, EntryIndex); + } + }); +} + +////////////////////////////////////////////////////////////////////////// + +ZenCacheDiskLayer::ZenCacheDiskLayer(const std::filesystem::path& RootDir) : m_RootDir(RootDir) +{ +} + +ZenCacheDiskLayer::~ZenCacheDiskLayer() = default; + +bool +ZenCacheDiskLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + const auto BucketName = std::string(InBucket); + CacheBucket* Bucket = nullptr; + + { + RwLock::SharedLockScope _(m_Lock); + + auto It = m_Buckets.find(BucketName); + + if (It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + } + + if (Bucket == nullptr) + { + // Bucket needs to be opened/created + + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + else + { + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName)); + Bucket = InsertResult.first->second.get(); + + std::filesystem::path BucketPath = m_RootDir; + BucketPath /= BucketName; + + if (!Bucket->OpenOrCreate(BucketPath)) + { + m_Buckets.erase(InsertResult.first); + return false; + } + } + } + + ZEN_ASSERT(Bucket != nullptr); + return Bucket->Get(HashKey, OutValue); +} + +void +ZenCacheDiskLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + const 
auto BucketName = std::string(InBucket); + CacheBucket* Bucket = nullptr; + + { + // Fast path: the bucket is already registered, a shared lock is enough + RwLock::SharedLockScope _(m_Lock); + + auto It = m_Buckets.find(BucketName); + + if (It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + } + + if (Bucket == nullptr) + { + // New bucket needs to be created + + RwLock::ExclusiveLockScope _(m_Lock); + + // Re-check: the bucket may have been registered after the shared lock was released + if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + else + { + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName)); + Bucket = InsertResult.first->second.get(); + + std::filesystem::path BucketPath = m_RootDir; + BucketPath /= BucketName; + + try + { + if (!Bucket->OpenOrCreate(BucketPath)) + { + ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir); + m_Buckets.erase(InsertResult.first); + return; + } + } + catch (const std::exception& Err) + { + ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what()); + // Same cleanup as the failure path above: don't leave a bucket that failed to open registered in m_Buckets + m_Buckets.erase(InsertResult.first); + return; + } + } + } + + ZEN_ASSERT(Bucket != nullptr); + + Bucket->Put(HashKey, Value); +} + +void +ZenCacheDiskLayer::DiscoverBuckets() +{ + DirectoryContent DirContent; + GetDirectoryContent(m_RootDir, DirectoryContent::IncludeDirsFlag, DirContent); + + // Initialize buckets + + RwLock::ExclusiveLockScope _(m_Lock); + + for (const std::filesystem::path& BucketPath : DirContent.Directories) + { + const std::string BucketName = PathToUtf8(BucketPath.stem()); + // New bucket needs to be created + if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end()) + { + continue; + } + + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName)); + CacheBucket& Bucket = *InsertResult.first->second; + + try + { + if (!Bucket.OpenOrCreate(BucketPath, /* AllowCreate */ false)) + { + ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir); + + 
m_Buckets.erase(InsertResult.first); + continue; + } + } + catch (const std::exception& Err) + { + ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what()); + // Don't leave a bucket that failed to open registered in m_Buckets + m_Buckets.erase(InsertResult.first); + return; + } + ZEN_INFO("Discovered bucket '{}'", BucketName); + } +} + +// Drops a single bucket. A known bucket is moved to m_DroppedBuckets before being dropped; +// for an unknown bucket the directory of that name under m_RootDir is removed instead. +bool +ZenCacheDiskLayer::DropBucket(std::string_view InBucket) +{ + RwLock::ExclusiveLockScope _(m_Lock); + + auto It = m_Buckets.find(std::string(InBucket)); + + if (It != m_Buckets.end()) + { + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It); + + return Bucket.Drop(); + } + + // Make sure we remove the folder even if we don't know about the bucket + std::filesystem::path BucketPath = m_RootDir; + BucketPath /= std::string(InBucket); + return MoveAndDeleteDirectory(BucketPath); +} + +// Drops every bucket, then removes the root directory. +// Stops and returns false at the first bucket whose Drop() fails; the remaining buckets stay registered. +bool +ZenCacheDiskLayer::Drop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + while (!m_Buckets.empty()) + { + auto It = m_Buckets.begin(); + CacheBucket& Bucket = *It->second; + // Ownership moves to m_DroppedBuckets, so Bucket stays valid after the map entry is erased + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It); + if (!Bucket.Drop()) + { + return false; + } + } + return MoveAndDeleteDirectory(m_RootDir); +} + +// Flushes every bucket. The bucket pointers are collected under the shared lock and flushed after it is released. +void +ZenCacheDiskLayer::Flush() +{ + std::vector<CacheBucket*> Buckets; + + { + RwLock::SharedLockScope _(m_Lock); + Buckets.reserve(m_Buckets.size()); + for (auto& Kv : m_Buckets) + { + CacheBucket* Bucket = Kv.second.get(); + Buckets.push_back(Bucket); + } + } + + for (auto& Bucket : Buckets) + { + Bucket->Flush(); + } +} + +// Scrubs every bucket while holding the shared lock. +void +ZenCacheDiskLayer::Scrub(ScrubContext& Ctx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + CacheBucket& Bucket = *Kv.second; + Bucket.Scrub(Ctx); + } +} + +void +ZenCacheDiskLayer::GatherReferences(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + CacheBucket& Bucket = *Kv.second; + 
Bucket.GatherReferences(GcCtx); + } +} + +uint64_t +ZenCacheDiskLayer::TotalSize() const +{ + uint64_t TotalSize{}; + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + TotalSize += Kv.second->TotalSize(); + } + + return TotalSize; +} + +ZenCacheDiskLayer::Info +ZenCacheDiskLayer::GetInfo() const +{ + ZenCacheDiskLayer::Info Info = {.Config = {.RootDir = m_RootDir}, .TotalSize = TotalSize()}; + + RwLock::SharedLockScope _(m_Lock); + Info.BucketNames.reserve(m_Buckets.size()); + for (auto& Kv : m_Buckets) + { + Info.BucketNames.push_back(Kv.first); + Info.EntryCount += Kv.second->EntryCount(); + } + return Info; +} + +std::optional<ZenCacheDiskLayer::BucketInfo> +ZenCacheDiskLayer::GetBucketInfo(std::string_view Bucket) const +{ + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end()) + { + return ZenCacheDiskLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()}; + } + return {}; +} + +CacheValueDetails::NamespaceDetails +ZenCacheDiskLayer::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const +{ + RwLock::SharedLockScope _(m_Lock); + CacheValueDetails::NamespaceDetails Details; + if (BucketFilter.empty()) + { + Details.Buckets.reserve(BucketFilter.empty() ? 
m_Buckets.size() : 1); + for (auto& Kv : m_Buckets) + { + Details.Buckets[Kv.first] = Kv.second->GetValueDetails(ValueFilter); + } + } + else if (auto It = m_Buckets.find(std::string(BucketFilter)); It != m_Buckets.end()) + { + Details.Buckets[It->first] = It->second->GetValueDetails(ValueFilter); + } + return Details; +} + +//////////////////////////// ZenCacheStore + +static constexpr std::string_view UE4DDCNamespaceName = "ue4.ddc"; + +ZenCacheStore::ZenCacheStore(GcManager& Gc, const Configuration& Configuration) : m_Gc(Gc), m_Configuration(Configuration) +{ + CreateDirectories(m_Configuration.BasePath); + + DirectoryContent DirContent; + GetDirectoryContent(m_Configuration.BasePath, DirectoryContent::IncludeDirsFlag, DirContent); + + std::vector<std::string> Namespaces; + for (const std::filesystem::path& DirPath : DirContent.Directories) + { + std::string DirName = PathToUtf8(DirPath.filename()); + if (DirName.starts_with(NamespaceDiskPrefix)) + { + Namespaces.push_back(DirName.substr(NamespaceDiskPrefix.length())); + continue; + } + } + + ZEN_INFO("Found {} namespaces in '{}'", Namespaces.size(), m_Configuration.BasePath); + + if (std::find(Namespaces.begin(), Namespaces.end(), UE4DDCNamespaceName) == Namespaces.end()) + { + // default (unspecified) and ue4-ddc namespace points to the same namespace instance + + std::filesystem::path DefaultNamespaceFolder = + m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, UE4DDCNamespaceName); + CreateDirectories(DefaultNamespaceFolder); + Namespaces.push_back(std::string(UE4DDCNamespaceName)); + } + + for (const std::string& NamespaceName : Namespaces) + { + m_Namespaces[NamespaceName] = + std::make_unique<ZenCacheNamespace>(Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, NamespaceName)); + } +} + +ZenCacheStore::~ZenCacheStore() +{ + m_Namespaces.clear(); +} + +bool +ZenCacheStore::Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& 
OutValue) +{ + if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store) + { + return Store->Get(Bucket, HashKey, OutValue); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString()); + + return false; +} + +// Stores Value in the given namespace/bucket. If GetNamespace() yields no store the value is dropped and a warning is logged. +void +ZenCacheStore::Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store) + { + return Store->Put(Bucket, HashKey, Value); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString()); +} + +// Drops one bucket of the given namespace. Returns false (and logs a warning) when GetNamespace() yields no store. +bool +ZenCacheStore::DropBucket(std::string_view Namespace, std::string_view Bucket) +{ + if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store) + { + return Store->DropBucket(Bucket); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropBucket, bucket '{}'", Namespace, Bucket); + return false; +} + +// Removes the namespace from m_Namespaces, parks its owner in m_DroppedNamespaces and drops it. +// Returns the result of ZenCacheNamespace::Drop(), or false when no namespace of that name is registered. +bool +ZenCacheStore::DropNamespace(std::string_view InNamespace) +{ + // Exclusive lock: both m_Namespaces and m_DroppedNamespaces are mutated below + RwLock::ExclusiveLockScope _(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end()) + { + ZenCacheNamespace& Namespace = *It->second; + m_DroppedNamespaces.push_back(std::move(It->second)); + m_Namespaces.erase(It); + return Namespace.Drop(); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropNamespace", InNamespace); + return false; +} + +// Flushes every namespace visited by IterateNamespaces(). +void +ZenCacheStore::Flush() +{ + IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Flush(); }); +} + +// Scrubs every namespace visited by IterateNamespaces(). +void +ZenCacheStore::Scrub(ScrubContext& Ctx) +{ + IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Scrub(Ctx); }); +} + +CacheValueDetails +ZenCacheStore::GetValueDetails(const std::string_view NamespaceFilter, + const std::string_view BucketFilter, + const std::string_view ValueFilter) const +{ + CacheValueDetails Details; + if 
(NamespaceFilter.empty()) + { + IterateNamespaces([&](std::string_view Namespace, ZenCacheNamespace& Store) { + Details.Namespaces[std::string(Namespace)] = Store.GetValueDetails(BucketFilter, ValueFilter); + }); + } + else if (const ZenCacheNamespace* Store = FindNamespace(NamespaceFilter); Store != nullptr) + { + Details.Namespaces[std::string(NamespaceFilter)] = Store->GetValueDetails(BucketFilter, ValueFilter); + } + return Details; +} + +// Returns the store for Namespace, creating it when it is missing and AllowAutomaticCreationOfNamespaces is set. +// Returns nullptr when the namespace is missing and automatic creation is disabled. +// DefaultNamespace falls back to the UE4DDCNamespaceName entry when it has no entry of its own. +ZenCacheNamespace* +ZenCacheStore::GetNamespace(std::string_view Namespace) +{ + RwLock::SharedLockScope _(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end()) + { + return It->second.get(); + } + if (Namespace == DefaultNamespace) + { + if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end()) + { + return It->second.get(); + } + } + // Give up the shared lock before the exclusive lock is taken further down + _.ReleaseNow(); + + if (!m_Configuration.AllowAutomaticCreationOfNamespaces) + { + return nullptr; + } + + RwLock::ExclusiveLockScope __(m_NamespacesLock); + // Look again under the exclusive lock: no lock was held since ReleaseNow(), so the entry may exist by now + if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end()) + { + return It->second.get(); + } + + auto NewNamespace = m_Namespaces.insert_or_assign( + std::string(Namespace), + std::make_unique<ZenCacheNamespace>(m_Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, Namespace))); + return NewNamespace.first->second.get(); +} + +// Lookup only: never creates a namespace. Applies the same DefaultNamespace fallback as GetNamespace(). +// Returns nullptr when nothing matches. +const ZenCacheNamespace* +ZenCacheStore::FindNamespace(std::string_view Namespace) const +{ + RwLock::SharedLockScope _(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end()) + { + return It->second.get(); + } + if (Namespace == DefaultNamespace) + { + if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end()) + { + return It->second.get(); + } + } + return nullptr; +} + +void +ZenCacheStore::IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const +{ 
+ std::vector<std::pair<std::string, ZenCacheNamespace&>> Namespaces; + { + RwLock::SharedLockScope _(m_NamespacesLock); + Namespaces.reserve(m_Namespaces.size()); + for (const auto& Entry : m_Namespaces) + { + if (Entry.first == DefaultNamespace) + { + continue; + } + Namespaces.push_back({Entry.first, *Entry.second}); + } + } + for (auto& Entry : Namespaces) + { + Callback(Entry.first, Entry.second); + } +} + +GcStorageSize +ZenCacheStore::StorageSize() const +{ + GcStorageSize Size; + IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { + GcStorageSize StoreSize = Store.StorageSize(); + Size.MemorySize += StoreSize.MemorySize; + Size.DiskSize += StoreSize.DiskSize; + }); + return Size; +} + +ZenCacheStore::Info +ZenCacheStore::GetInfo() const +{ + ZenCacheStore::Info Info = {.Config = m_Configuration, .StorageSize = StorageSize()}; + + IterateNamespaces([&Info](std::string_view NamespaceName, ZenCacheNamespace& Namespace) { + Info.NamespaceNames.push_back(std::string(NamespaceName)); + ZenCacheNamespace::Info NamespaceInfo = Namespace.GetInfo(); + Info.DiskEntryCount += NamespaceInfo.DiskLayerInfo.EntryCount; + Info.MemoryEntryCount += NamespaceInfo.MemoryLayerInfo.EntryCount; + }); + + return Info; +} + +std::optional<ZenCacheNamespace::Info> +ZenCacheStore::GetNamespaceInfo(std::string_view NamespaceName) +{ + if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace) + { + return Namespace->GetInfo(); + } + return {}; +} + +std::optional<ZenCacheNamespace::BucketInfo> +ZenCacheStore::GetBucketInfo(std::string_view NamespaceName, std::string_view BucketName) +{ + if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace) + { + return Namespace->GetBucketInfo(BucketName); + } + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +using namespace std::literals; + +namespace testutils { + IoHash CreateKey(size_t KeyValue) { return 
IoHash::HashBuffer(&KeyValue, sizeof(size_t)); } + + IoBuffer CreateBinaryCacheValue(uint64_t Size) + { + static std::random_device rd; + static std::mt19937 g(rd()); + + std::vector<uint8_t> Values; + Values.resize(Size); + for (size_t Idx = 0; Idx < Size; ++Idx) + { + Values[Idx] = static_cast<uint8_t>(Idx); + } + std::shuffle(Values.begin(), Values.end(), g); + + IoBuffer Buf(IoBuffer::Clone, Values.data(), Values.size()); + Buf.SetContentType(ZenContentType::kBinary); + return Buf; + }; + +} // namespace testutils + +TEST_CASE("z$.store") +{ + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + const int kIterationCount = 100; + + for (int i = 0; i < kIterationCount; ++i) + { + const IoHash Key = IoHash::HashBuffer(&i, sizeof i); + + CbObjectWriter Cbo; + Cbo << "hey" << i; + CbObject Obj = Cbo.Save(); + + ZenCacheValue Value; + Value.Value = Obj.GetBuffer().AsIoBuffer(); + Value.Value.SetContentType(ZenContentType::kCbObject); + + Zcs.Put("test_bucket"sv, Key, Value); + } + + for (int i = 0; i < kIterationCount; ++i) + { + const IoHash Key = IoHash::HashBuffer(&i, sizeof i); + + ZenCacheValue Value; + Zcs.Get("test_bucket"sv, Key, /* out */ Value); + + REQUIRE(Value.Value); + CHECK(Value.Value.GetContentType() == ZenContentType::kCbObject); + CHECK_EQ(ValidateCompactBinary(Value.Value, CbValidateMode::All), CbValidateError::None); + CbObject Obj = LoadCompactBinaryObject(Value.Value); + CHECK_EQ(Obj["hey"].AsInt32(), i); + } +} + +TEST_CASE("z$.size") +{ + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + SUBCASE("mem/disklayer") + { + const size_t Count = 16; + ScopedTemporaryDirectory TempDir; + + GcStorageSize CacheSize; + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + CbObject CacheValue = 
CreateCacheValue(Zcs.DiskLayerThreshold() - 256); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + for (size_t Key = 0; Key < Count; ++Key) + { + const size_t Bucket = Key % 4; + Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), ZenCacheValue{.Value = Buffer}); + } + + CacheSize = Zcs.StorageSize(); + CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize); + CHECK_LE(CacheValue.GetSize() * Count, CacheSize.MemorySize); + } + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + const GcStorageSize SerializedSize = Zcs.StorageSize(); + CHECK_EQ(SerializedSize.MemorySize, 0); + CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize); + + for (size_t Bucket = 0; Bucket < 4; ++Bucket) + { + Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket)); + } + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } + + SUBCASE("disklayer") + { + const size_t Count = 16; + ScopedTemporaryDirectory TempDir; + + GcStorageSize CacheSize; + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() + 64); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + for (size_t Key = 0; Key < Count; ++Key) + { + const size_t Bucket = Key % 4; + Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), {.Value = Buffer}); + } + + CacheSize = Zcs.StorageSize(); + CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize); + CHECK_EQ(0, CacheSize.MemorySize); + } + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + const GcStorageSize SerializedSize = Zcs.StorageSize(); + CHECK_EQ(SerializedSize.MemorySize, 0); + CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize); + + for (size_t Bucket = 0; Bucket < 4; ++Bucket) + { + Zcs.DropBucket(fmt::format("test_bucket-{}", 
Bucket)); + } + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } +} + +TEST_CASE("z$.gc") +{ + using namespace testutils; + + SUBCASE("gather references does NOT add references for expired cache entries") + { + ScopedTemporaryDirectory TempDir; + std::vector<IoHash> Cids{CreateKey(1), CreateKey(2), CreateKey(3)}; + + const auto CollectAndFilter = [](GcManager& Gc, + GcClock::TimePoint Time, + GcClock::Duration MaxDuration, + std::span<const IoHash> Cids, + std::vector<IoHash>& OutKeep) { + GcContext GcCtx(Time - MaxDuration); + Gc.CollectGarbage(GcCtx); + OutKeep.clear(); + GcCtx.FilterCids(Cids, [&OutKeep](const IoHash& Hash) { OutKeep.push_back(Hash); }); + }; + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + const auto Bucket = "teardrinker"sv; + + // Create a cache record + const IoHash Key = CreateKey(42); + CbObjectWriter Record; + Record << "Key"sv + << "SomeRecord"sv; + + for (size_t Idx = 0; auto& Cid : Cids) + { + Record.AddBinaryAttachment(fmt::format("attachment-{}", Idx++), Cid); + } + + IoBuffer Buffer = Record.Save().GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + Zcs.Put(Bucket, Key, {.Value = Buffer}); + + std::vector<IoHash> Keep; + + // Collect garbage with 1 hour max cache duration + { + CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(Cids.size(), Keep.size()); + } + + // Move forward in time + { + CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(0, Keep.size()); + } + } + + // Expect timestamps to be serialized + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + std::vector<IoHash> Keep; + + // Collect garbage with 1 hour max cache duration + { + CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(3, Keep.size()); + } + + // Move forward in time + { + CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), 
Cids, Keep); + CHECK_EQ(0, Keep.size()); + } + } + } + + SUBCASE("gc removes standalone values") + { + ScopedTemporaryDirectory TempDir; + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + const auto Bucket = "fortysixandtwo"sv; + const GcClock::TimePoint CurrentTime = GcClock::Now(); + + std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)}; + + for (const auto& Key : Keys) + { + IoBuffer Value = testutils::CreateBinaryCacheValue(128 << 10); + Zcs.Put(Bucket, Key, {.Value = Value}); + } + + { + GcContext GcCtx(CurrentTime - std::chrono::hours(46)); + + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(Exists); + } + } + + // Move forward in time and collect again + { + GcContext GcCtx(CurrentTime + std::chrono::minutes(2)); + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(!Exists); + } + + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } + + SUBCASE("gc removes small objects") + { + ScopedTemporaryDirectory TempDir; + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + const auto Bucket = "rightintwo"sv; + + std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)}; + + for (const auto& Key : Keys) + { + IoBuffer Value = testutils::CreateBinaryCacheValue(128); + Zcs.Put(Bucket, Key, {.Value = Value}); + } + + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(2)); + GcCtx.CollectSmallObjects(true); + + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(Exists); + } + } + + // Move forward in time and collect again + { + GcContext GcCtx(GcClock::Now() + std::chrono::minutes(2)); + GcCtx.CollectSmallObjects(true); + + Zcs.Flush(); + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : 
Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(!Exists); + } + + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } +} + +TEST_CASE("z$.threadedinsert") // * doctest::skip(true)) +{ + // for (uint32_t i = 0; i < 100; ++i) + { + ScopedTemporaryDirectory TempDir; + + const uint64_t kChunkSize = 1048; + const int32_t kChunkCount = 8192; + + struct Chunk + { + std::string Bucket; + IoBuffer Buffer; + }; + std::unordered_map<IoHash, Chunk, IoHash::Hasher> Chunks; + Chunks.reserve(kChunkCount); + + const std::string Bucket1 = "rightinone"; + const std::string Bucket2 = "rightintwo"; + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + while (true) + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + if (Chunks.contains(Hash)) + { + continue; + } + Chunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk}; + break; + } + while (true) + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + if (Chunks.contains(Hash)) + { + continue; + } + Chunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk}; + break; + } + } + + CreateDirectories(TempDir.Path()); + + WorkerThreadPool ThreadPool(4); + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path()); + + { + std::atomic<size_t> WorkCompleted = 0; + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() { + Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer}); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + } + + const uint64_t TotalSize = Zcs.StorageSize().DiskSize; + CHECK_LE(kChunkSize * Chunks.size(), TotalSize); + + { + std::atomic<size_t> WorkCompleted = 0; + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() { + std::string Bucket = Chunk.second.Bucket; + IoHash ChunkHash = Chunk.first; + ZenCacheValue CacheValue; + 
+ CHECK(Zcs.Get(Bucket, ChunkHash, CacheValue)); + IoHash Hash = IoHash::HashBuffer(CacheValue.Value); + CHECK(ChunkHash == Hash); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + } + std::unordered_map<IoHash, std::string, IoHash::Hasher> GcChunkHashes; + GcChunkHashes.reserve(Chunks.size()); + for (const auto& Chunk : Chunks) + { + GcChunkHashes[Chunk.first] = Chunk.second.Bucket; + } + { + std::unordered_map<IoHash, Chunk, IoHash::Hasher> NewChunks; + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + NewChunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk}; + } + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + NewChunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk}; + } + } + + std::atomic<size_t> WorkCompleted = 0; + std::atomic_uint32_t AddedChunkCount = 0; + for (const auto& Chunk : NewChunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk, &AddedChunkCount]() { + Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer}); + AddedChunkCount.fetch_add(1); + WorkCompleted.fetch_add(1); + }); + } + + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() { + ZenCacheValue CacheValue; + if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue)) + { + CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value)); + } + WorkCompleted.fetch_add(1); + }); + } + while (AddedChunkCount.load() < NewChunks.size()) + { + // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope + for (const auto& Chunk : NewChunks) + { + ZenCacheValue CacheValue; + if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue)) + { + GcChunkHashes[Chunk.first] = Chunk.second.Bucket; + } + } + std::vector<IoHash> KeepHashes; + KeepHashes.reserve(GcChunkHashes.size()); + for 
(const auto& Entry : GcChunkHashes) + { + KeepHashes.push_back(Entry.first); + } + size_t C = 0; + while (C < KeepHashes.size()) + { + if (C % 155 == 0) + { + if (C < KeepHashes.size() - 1) + { + KeepHashes[C] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + if (C + 3 < KeepHashes.size() - 1) + { + KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + } + C++; + } + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + GcCtx.AddRetainedCids(KeepHashes); + Zcs.CollectGarbage(GcCtx); + const HashKeySet& Deleted = GcCtx.DeletedCids(); + Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); }); + } + + while (WorkCompleted < NewChunks.size() + Chunks.size()) + { + Sleep(1); + } + + { + // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope + for (const auto& Chunk : NewChunks) + { + ZenCacheValue CacheValue; + if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue)) + { + GcChunkHashes[Chunk.first] = Chunk.second.Bucket; + } + } + std::vector<IoHash> KeepHashes; + KeepHashes.reserve(GcChunkHashes.size()); + for (const auto& Entry : GcChunkHashes) + { + KeepHashes.push_back(Entry.first); + } + size_t C = 0; + while (C < KeepHashes.size()) + { + if (C % 155 == 0) + { + if (C < KeepHashes.size() - 1) + { + KeepHashes[C] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + if (C + 3 < KeepHashes.size() - 1) + { + KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + } + C++; + } + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + GcCtx.AddRetainedCids(KeepHashes); + Zcs.CollectGarbage(GcCtx); + const HashKeySet& Deleted = GcCtx.DeletedCids(); + Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); }); + } + } + { + std::atomic<size_t> WorkCompleted = 0; + 
for (const auto& Chunk : GcChunkHashes) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() { + ZenCacheValue CacheValue; + CHECK(Zcs.Get(Chunk.second, Chunk.first, CacheValue)); + CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value)); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < GcChunkHashes.size()) + { + Sleep(1); + } + } + } +} + +TEST_CASE("z$.namespaces") +{ + using namespace testutils; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path()); + + IoHash Key1; + IoHash Key2; + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = false}); + const auto Bucket = "teardrinker"sv; + const auto CustomNamespace = "mynamespace"sv; + + // Create a cache record + Key1 = CreateKey(42); + CbObject CacheValue = CreateCacheValue(4096); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + ZenCacheValue PutValue = {.Value = Buffer}; + Zcs.Put(ZenCacheStore::DefaultNamespace, Bucket, Key1, PutValue); + + ZenCacheValue GetValue; + CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue)); + CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue)); + + // This should just be dropped as we don't allow creating of namespaces on the fly + Zcs.Put(CustomNamespace, Bucket, Key1, PutValue); + CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue)); + } + + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true}); + const auto Bucket = "teardrinker"sv; + const auto CustomNamespace = "mynamespace"sv; + + Key2 = CreateKey(43); + CbObject CacheValue2 = CreateCacheValue(4096); + + IoBuffer Buffer2 = 
CacheValue2.GetBuffer().AsIoBuffer(); + Buffer2.SetContentType(ZenContentType::kCbObject); + ZenCacheValue PutValue2 = {.Value = Buffer2}; + Zcs.Put(CustomNamespace, Bucket, Key2, PutValue2); + + ZenCacheValue GetValue; + CHECK(!Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key2, GetValue)); + CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue)); + CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue)); + CHECK(Zcs.Get(CustomNamespace, Bucket, Key2, GetValue)); + } +} + +TEST_CASE("z$.drop.bucket") +{ + using namespace testutils; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path()); + + IoHash Key1; + IoHash Key2; + + auto PutValue = + [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) { + // Create a cache record + IoHash Key = CreateKey(KeyIndex); + CbObject CacheValue = CreateCacheValue(Size); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + ZenCacheValue PutValue = {.Value = Buffer}; + Zcs.Put(Namespace, Bucket, Key, PutValue); + return Key; + }; + auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) { + ZenCacheValue GetValue; + Zcs.Get(Namespace, Bucket, Key, GetValue); + return GetValue; + }; + WorkerThreadPool Workers(1); + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true}); + const auto Bucket = "teardrinker"sv; + const auto Namespace = "mynamespace"sv; + + Key1 = PutValue(Zcs, Namespace, Bucket, 42, 4096); + Key2 = PutValue(Zcs, Namespace, Bucket, 43, 2048); + + ZenCacheValue Value1 = GetValue(Zcs, Namespace, Bucket, Key1); + 
CHECK(Value1.Value); + + std::atomic_bool WorkComplete = false; + Workers.ScheduleWork([&]() { + zen::Sleep(100); + Value1.Value = IoBuffer{}; + WorkComplete = true; + }); + // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket + // Our DropBucket execution blocks any incoming request from completing until we are done with the drop + CHECK(Zcs.DropBucket(Namespace, Bucket)); + while (!WorkComplete) + { + zen::Sleep(1); + } + + // Entire bucket should be dropped, but doing a request should will re-create the namespace but it must still be empty + Value1 = GetValue(Zcs, Namespace, Bucket, Key1); + CHECK(!Value1.Value); + ZenCacheValue Value2 = GetValue(Zcs, Namespace, Bucket, Key2); + CHECK(!Value2.Value); + } +} + +TEST_CASE("z$.drop.namespace") +{ + using namespace testutils; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path()); + + auto PutValue = + [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) { + // Create a cache record + IoHash Key = CreateKey(KeyIndex); + CbObject CacheValue = CreateCacheValue(Size); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + ZenCacheValue PutValue = {.Value = Buffer}; + Zcs.Put(Namespace, Bucket, Key, PutValue); + return Key; + }; + auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) { + ZenCacheValue GetValue; + Zcs.Get(Namespace, Bucket, Key, GetValue); + return GetValue; + }; + WorkerThreadPool Workers(1); + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true}); + const auto 
Bucket1 = "teardrinker1"sv; + const auto Bucket2 = "teardrinker2"sv; + const auto Namespace1 = "mynamespace1"sv; + const auto Namespace2 = "mynamespace2"sv; + + IoHash Key1 = PutValue(Zcs, Namespace1, Bucket1, 42, 4096); + IoHash Key2 = PutValue(Zcs, Namespace1, Bucket2, 43, 2048); + IoHash Key3 = PutValue(Zcs, Namespace2, Bucket1, 44, 4096); + IoHash Key4 = PutValue(Zcs, Namespace2, Bucket2, 45, 2048); + + ZenCacheValue Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1); + CHECK(Value1.Value); + ZenCacheValue Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2); + CHECK(Value2.Value); + ZenCacheValue Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3); + CHECK(Value3.Value); + ZenCacheValue Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4); + CHECK(Value4.Value); + + std::atomic_bool WorkComplete = false; + Workers.ScheduleWork([&]() { + zen::Sleep(100); + Value1.Value = IoBuffer{}; + Value2.Value = IoBuffer{}; + Value3.Value = IoBuffer{}; + Value4.Value = IoBuffer{}; + WorkComplete = true; + }); + // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket + // Our DropBucket execution blocks any incoming request from completing until we are done with the drop + CHECK(Zcs.DropNamespace(Namespace1)); + while (!WorkComplete) + { + zen::Sleep(1); + } + + // Entire namespace should be dropped, but doing a request should will re-create the namespace but it must still be empty + Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1); + CHECK(!Value1.Value); + Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2); + CHECK(!Value2.Value); + Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3); + CHECK(Value3.Value); + Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4); + CHECK(Value4.Value); + } +} + +TEST_CASE("z$.blocked.disklayer.put") +{ + ScopedTemporaryDirectory TempDir; + + GcStorageSize CacheSize; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size, Size & 0xff); + + CbObjectWriter 
Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + CbObject CacheValue = CreateCacheValue(64 * 1024 + 64); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + size_t Key = Buffer.Size(); + IoHash HashKey = IoHash::HashBuffer(&Key, sizeof(uint32_t)); + Zcs.Put("test_bucket", HashKey, {.Value = Buffer}); + + ZenCacheValue BufferGet; + CHECK(Zcs.Get("test_bucket", HashKey, BufferGet)); + + CbObject CacheValue2 = CreateCacheValue(64 * 1024 + 64 + 1); + IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer(); + Buffer2.SetContentType(ZenContentType::kCbObject); + + // We should be able to overwrite even if the file is open for read + Zcs.Put("test_bucket", HashKey, {.Value = Buffer2}); + + MemoryView OldView = BufferGet.Value.GetView(); + + ZenCacheValue BufferGet2; + CHECK(Zcs.Get("test_bucket", HashKey, BufferGet2)); + MemoryView NewView = BufferGet2.Value.GetView(); + + // Make sure file openend for read before we wrote it still have old data + CHECK(OldView.GetSize() == Buffer.GetSize()); + CHECK(memcmp(OldView.GetData(), Buffer.GetData(), OldView.GetSize()) == 0); + + // Make sure we get the new data when reading after we write new data + CHECK(NewView.GetSize() == Buffer2.GetSize()); + CHECK(memcmp(NewView.GetData(), Buffer2.GetData(), NewView.GetSize()) == 0); +} + +TEST_CASE("z$.scrub") +{ + ScopedTemporaryDirectory TempDir; + + using namespace testutils; + + struct CacheRecord + { + IoBuffer Record; + std::vector<CompressedBuffer> Attachments; + }; + + auto CreateCacheRecord = [](bool Structured, std::string_view Bucket, const IoHash& Key, const std::vector<size_t>& AttachmentSizes) { + CacheRecord Result; + if (Structured) + { + Result.Attachments.resize(AttachmentSizes.size()); + CbObjectWriter Record; + Record.BeginObject("Key"sv); + { + Record << "Bucket"sv << Bucket; + Record 
<< "Hash"sv << Key; + } + Record.EndObject(); + for (size_t Index = 0; Index < AttachmentSizes.size(); Index++) + { + IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSizes[Index]); + CompressedBuffer CompressedAttachmentData = CompressedBuffer::Compress(SharedBuffer(AttachmentData)); + Record.AddBinaryAttachment(fmt::format("attachment-{}", Index), CompressedAttachmentData.DecodeRawHash()); + Result.Attachments[Index] = CompressedAttachmentData; + } + Result.Record = Record.Save().GetBuffer().AsIoBuffer(); + Result.Record.SetContentType(ZenContentType::kCbObject); + } + else + { + std::string RecordData = fmt::format("{}:{}", Bucket, Key.ToHexString()); + size_t TotalSize = RecordData.length() + 1; + for (size_t AttachmentSize : AttachmentSizes) + { + TotalSize += AttachmentSize; + } + Result.Record = IoBuffer(TotalSize); + char* DataPtr = (char*)Result.Record.MutableData(); + memcpy(DataPtr, RecordData.c_str(), RecordData.length() + 1); + DataPtr += RecordData.length() + 1; + for (size_t AttachmentSize : AttachmentSizes) + { + IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSize); + memcpy(DataPtr, AttachmentData.GetData(), AttachmentData.GetSize()); + DataPtr += AttachmentData.GetSize(); + } + } + return Result; + }; + + GcManager Gc; + CidStore CidStore(Gc); + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + auto CreateRecords = + [&](bool IsStructured, std::string_view BucketName, const std::vector<IoHash>& Cids, const std::vector<size_t>& AttachmentSizes) { + for (const IoHash& Cid : Cids) + { + CacheRecord Record = CreateCacheRecord(IsStructured, BucketName, Cid, AttachmentSizes); + Zcs.Put("mybucket", Cid, {.Value = Record.Record}); + for (const CompressedBuffer& Attachment : Record.Attachments) + { + 
CidStore.AddChunk(Attachment.GetCompressed().Flatten().AsIoBuffer(), Attachment.DecodeRawHash()); + } + } + }; + + std::vector<size_t> AttachmentSizes = {16, 1000, 2000, 4000, 8000, 64000, 80000}; + + std::vector<IoHash> UnstructuredCids{CreateKey(4), CreateKey(5), CreateKey(6)}; + CreateRecords(false, "mybucket"sv, UnstructuredCids, AttachmentSizes); + + std::vector<IoHash> StructuredCids{CreateKey(1), CreateKey(2), CreateKey(3)}; + CreateRecords(true, "mybucket"sv, StructuredCids, AttachmentSizes); + + ScrubContext ScrubCtx; + Zcs.Scrub(ScrubCtx); + CidStore.Scrub(ScrubCtx); + CHECK(ScrubCtx.ScrubbedChunks() == (StructuredCids.size() + StructuredCids.size() * AttachmentSizes.size()) + UnstructuredCids.size()); + CHECK(ScrubCtx.BadCids().GetSize() == 0); +} + +#endif + +void +z$_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenserver/cache/structuredcachestore.h b/src/zenserver/cache/structuredcachestore.h new file mode 100644 index 000000000..3fb4f035d --- /dev/null +++ b/src/zenserver/cache/structuredcachestore.h @@ -0,0 +1,535 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/thread.h> +#include <zencore/uid.h> +#include <zenstore/blockstore.h> +#include <zenstore/caslog.h> +#include <zenstore/gc.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <atomic> +#include <compare> +#include <filesystem> +#include <unordered_map> + +#define ZEN_USE_CACHE_TRACKER 0 + +namespace zen { + +class PathBuilderBase; +class GcManager; +class ZenCacheTracker; +class ScrubContext; + +/****************************************************************************** + + /$$$$$$$$ /$$$$$$ /$$ + |_____ $$ /$$__ $$ | $$ + /$$/ /$$$$$$ /$$$$$$$ | $$ \__/ /$$$$$$ /$$$$$$| $$$$$$$ /$$$$$$ + /$$/ /$$__ $| $$__ $$ | $$ |____ $$/$$_____| $$__ $$/$$__ $$ + /$$/ | $$$$$$$| $$ \ $$ | $$ /$$$$$$| $$ | $$ \ $| $$$$$$$$ + /$$/ | $$_____| $$ | $$ | $$ $$/$$__ $| $$ | $$ | $| $$_____/ + /$$$$$$$| $$$$$$| $$ | $$ | $$$$$$| $$$$$$| $$$$$$| $$ | $| $$$$$$$ + |________/\_______|__/ |__/ \______/ \_______/\_______|__/ |__/\_______/ + + Cache store for UE5. Restricts keys to "{bucket}/{hash}" pairs where the hash + is 40 (hex) chars in size. Values may be opaque blobs or structured objects + which can in turn contain references to other objects (or blobs). 
+ +******************************************************************************/ + +namespace access_tracking { + + struct KeyAccessTime + { + IoHash Key; + GcClock::Tick LastAccess{}; + }; + + struct AccessTimes + { + std::unordered_map<std::string, std::vector<KeyAccessTime>> Buckets; + }; +}; // namespace access_tracking + +struct ZenCacheValue +{ + IoBuffer Value; + uint64_t RawSize = 0; + IoHash RawHash = IoHash::Zero; +}; + +struct CacheValueDetails +{ + struct ValueDetails + { + uint64_t Size; + uint64_t RawSize; + IoHash RawHash; + GcClock::Tick LastAccess{}; + std::vector<IoHash> Attachments; + ZenContentType ContentType; + }; + + struct BucketDetails + { + std::unordered_map<IoHash, ValueDetails, IoHash::Hasher> Values; + }; + + struct NamespaceDetails + { + std::unordered_map<std::string, BucketDetails> Buckets; + }; + + std::unordered_map<std::string, NamespaceDetails> Namespaces; +}; + +////////////////////////////////////////////////////////////////////////// + +#pragma pack(push) +#pragma pack(1) + +struct DiskLocation +{ + inline DiskLocation() = default; + + inline DiskLocation(uint64_t ValueSize, uint8_t Flags) : Flags(Flags | kStandaloneFile) { Location.StandaloneSize = ValueSize; } + + inline DiskLocation(const BlockStoreLocation& Location, uint64_t PayloadAlignment, uint8_t Flags) : Flags(Flags & ~kStandaloneFile) + { + this->Location.BlockLocation = BlockStoreDiskLocation(Location, PayloadAlignment); + } + + inline BlockStoreLocation GetBlockLocation(uint64_t PayloadAlignment) const + { + ZEN_ASSERT(!(Flags & kStandaloneFile)); + return Location.BlockLocation.Get(PayloadAlignment); + } + + inline uint64_t Size() const { return (Flags & kStandaloneFile) ? 
Location.StandaloneSize : Location.BlockLocation.GetSize(); } + inline uint8_t IsFlagSet(uint64_t Flag) const { return Flags & Flag; } + inline uint8_t GetFlags() const { return Flags; } + inline ZenContentType GetContentType() const + { + ZenContentType ContentType = ZenContentType::kBinary; + + if (IsFlagSet(kStructured)) + { + ContentType = ZenContentType::kCbObject; + } + + if (IsFlagSet(kCompressed)) + { + ContentType = ZenContentType::kCompressedBinary; + } + + return ContentType; + } + + union + { + BlockStoreDiskLocation BlockLocation; // 10 bytes + uint64_t StandaloneSize = 0; // 8 bytes + } Location; + + static const uint8_t kStandaloneFile = 0x80u; // Stored as a separate file + static const uint8_t kStructured = 0x40u; // Serialized as compact binary + static const uint8_t kTombStone = 0x20u; // Represents a deleted key/value + static const uint8_t kCompressed = 0x10u; // Stored in compressed buffer format + + uint8_t Flags = 0; + uint8_t Reserved = 0; +}; + +struct DiskIndexEntry +{ + IoHash Key; // 20 bytes + DiskLocation Location; // 12 bytes +}; + +#pragma pack(pop) + +static_assert(sizeof(DiskIndexEntry) == 32); + +// This store the access time as seconds since epoch internally in a 32-bit value giving is a range of 136 years since epoch +struct AccessTime +{ + explicit AccessTime(GcClock::Tick Tick) noexcept : SecondsSinceEpoch(ToSeconds(Tick)) {} + AccessTime& operator=(GcClock::Tick Tick) noexcept + { + SecondsSinceEpoch.store(ToSeconds(Tick), std::memory_order_relaxed); + return *this; + } + operator GcClock::Tick() const noexcept + { + return std::chrono::duration_cast<GcClock::Duration>(std::chrono::seconds(SecondsSinceEpoch.load(std::memory_order_relaxed))) + .count(); + } + + AccessTime(AccessTime&& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {} + AccessTime(const AccessTime& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {} + AccessTime& 
operator=(AccessTime&& Rhs) noexcept + { + SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed); + return *this; + } + AccessTime& operator=(const AccessTime& Rhs) noexcept + { + SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed); + return *this; + } + +private: + static uint32_t ToSeconds(GcClock::Tick Tick) + { + return gsl::narrow<uint32_t>(std::chrono::duration_cast<std::chrono::seconds>(GcClock::Duration(Tick)).count()); + } + std::atomic_uint32_t SecondsSinceEpoch; +}; + +/** In-memory cache storage + + Intended for small values which are frequently accessed + + This should have a better memory management policy to maintain reasonable + footprint. + */ +class ZenCacheMemoryLayer +{ +public: + struct Configuration + { + uint64_t TargetFootprintBytes = 16 * 1024 * 1024; + uint64_t ScavengeThreshold = 4 * 1024 * 1024; + }; + + struct BucketInfo + { + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + struct Info + { + Configuration Config; + std::vector<std::string> BucketNames; + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + ZenCacheMemoryLayer(); + ~ZenCacheMemoryLayer(); + + bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + void Drop(); + bool DropBucket(std::string_view Bucket); + void Scrub(ScrubContext& Ctx); + void GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes); + void Reset(); + uint64_t TotalSize() const; + + Info GetInfo() const; + std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const; + + const Configuration& GetConfiguration() const { return m_Configuration; } + void SetConfiguration(const Configuration& NewConfig) { m_Configuration = NewConfig; } + +private: + struct CacheBucket + { +#pragma pack(push) +#pragma pack(1) + struct BucketPayload + { + IoBuffer Payload; 
// 8 + uint32_t RawSize; // 4 + IoHash RawHash; // 20 + }; +#pragma pack(pop) + static_assert(sizeof(BucketPayload) == 32u); + static_assert(sizeof(AccessTime) == 4u); + + mutable RwLock m_BucketLock; + std::vector<AccessTime> m_AccessTimes; + std::vector<BucketPayload> m_Payloads; + tsl::robin_map<IoHash, uint32_t> m_CacheMap; + + std::atomic_uint64_t m_TotalSize{}; + + bool Get(const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(const IoHash& HashKey, const ZenCacheValue& Value); + void Drop(); + void Scrub(ScrubContext& Ctx); + void GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes); + inline uint64_t TotalSize() const { return m_TotalSize; } + uint64_t EntryCount() const; + }; + + mutable RwLock m_Lock; + std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets; + std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets; + Configuration m_Configuration; + + ZenCacheMemoryLayer(const ZenCacheMemoryLayer&) = delete; + ZenCacheMemoryLayer& operator=(const ZenCacheMemoryLayer&) = delete; +}; + +class ZenCacheDiskLayer +{ +public: + struct Configuration + { + std::filesystem::path RootDir; + }; + + struct BucketInfo + { + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + struct Info + { + Configuration Config; + std::vector<std::string> BucketNames; + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + explicit ZenCacheDiskLayer(const std::filesystem::path& RootDir); + ~ZenCacheDiskLayer(); + + bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + bool Drop(); + bool DropBucket(std::string_view Bucket); + void Flush(); + void Scrub(ScrubContext& Ctx); + void GatherReferences(GcContext& GcCtx); + void CollectGarbage(GcContext& GcCtx); + void UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes); + + void DiscoverBuckets(); + uint64_t TotalSize() const; + + Info 
GetInfo() const; + std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const; + + CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const; + +private: + /** A cache bucket manages a single directory containing + metadata and data for that bucket + */ + struct CacheBucket + { + CacheBucket(std::string BucketName); + ~CacheBucket(); + + bool OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate = true); + bool Get(const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(const IoHash& HashKey, const ZenCacheValue& Value); + bool Drop(); + void Flush(); + void Scrub(ScrubContext& Ctx); + void GatherReferences(GcContext& GcCtx); + void CollectGarbage(GcContext& GcCtx); + void UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes); + + inline uint64_t TotalSize() const { return m_TotalStandaloneSize.load(std::memory_order::relaxed) + m_BlockStore.TotalSize(); } + uint64_t EntryCount() const; + + CacheValueDetails::BucketDetails GetValueDetails(const std::string_view ValueFilter) const; + + private: + const uint64_t MaxBlockSize = 1ull << 30; + uint64_t m_PayloadAlignment = 1ull << 4; + + std::string m_BucketName; + std::filesystem::path m_BucketDir; + std::filesystem::path m_BlocksBasePath; + BlockStore m_BlockStore; + Oid m_BucketId; + uint64_t m_LargeObjectThreshold = 128 * 1024; + + // These files are used to manage storage of small objects for this bucket + + TCasLogFile<DiskIndexEntry> m_SlogFile; + uint64_t m_LogFlushPosition = 0; + +#pragma pack(push) +#pragma pack(1) + struct BucketPayload + { + DiskLocation Location; // 12 + uint64_t RawSize; // 8 + IoHash RawHash; // 20 + }; +#pragma pack(pop) + static_assert(sizeof(BucketPayload) == 40u); + static_assert(sizeof(AccessTime) == 4u); + + using IndexMap = tsl::robin_map<IoHash, size_t, IoHash::Hasher>; + + mutable RwLock m_IndexLock; + std::vector<AccessTime> m_AccessTimes; + 
std::vector<BucketPayload> m_Payloads; + IndexMap m_Index; + + std::atomic_uint64_t m_TotalStandaloneSize{}; + + void BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const; + void PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value); + IoBuffer GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const; + void PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value); + IoBuffer GetInlineCacheValue(const DiskLocation& Loc) const; + void MakeIndexSnapshot(); + uint64_t ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion); + uint64_t ReadLog(const std::filesystem::path& LogPath, uint64_t LogPosition); + void OpenLog(const bool IsNew); + void SaveManifest(); + CacheValueDetails::ValueDetails GetValueDetails(const IoHash& Key, size_t Index) const; + // These locks are here to avoid contention on file creation, therefore it's sufficient + // that we take the same lock for the same hash + // + // These locks are small and should really be spaced out so they don't share cache lines, + // but we don't currently access them at particularly high frequency so it should not be + // an issue in practice + + mutable RwLock m_ShardedLocks[256]; + inline RwLock& LockForHash(const IoHash& Hash) const { return m_ShardedLocks[Hash.Hash[19]]; } + }; + + std::filesystem::path m_RootDir; + mutable RwLock m_Lock; + std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets; // TODO: make this case insensitive + std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets; + + ZenCacheDiskLayer(const ZenCacheDiskLayer&) = delete; + ZenCacheDiskLayer& operator=(const ZenCacheDiskLayer&) = delete; +}; + +class ZenCacheNamespace final : public RefCounted, public GcStorage, public GcContributor +{ +public: + struct Configuration + { + std::filesystem::path RootDir; + uint64_t DiskLayerThreshold = 0; + }; + struct BucketInfo + { + ZenCacheDiskLayer::BucketInfo DiskLayerInfo; + 
ZenCacheMemoryLayer::BucketInfo MemoryLayerInfo; + }; + struct Info + { + Configuration Config; + std::vector<std::string> BucketNames; + ZenCacheDiskLayer::Info DiskLayerInfo; + ZenCacheMemoryLayer::Info MemoryLayerInfo; + }; + + ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir); + ~ZenCacheNamespace(); + + bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + bool Drop(); + bool DropBucket(std::string_view Bucket); + void Flush(); + void Scrub(ScrubContext& Ctx); + uint64_t DiskLayerThreshold() const { return m_DiskLayerSizeThreshold; } + virtual void GatherReferences(GcContext& GcCtx) override; + virtual void CollectGarbage(GcContext& GcCtx) override; + virtual GcStorageSize StorageSize() const override; + Info GetInfo() const; + std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const; + + CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const; + +private: + std::filesystem::path m_RootDir; + ZenCacheMemoryLayer m_MemLayer; + ZenCacheDiskLayer m_DiskLayer; + uint64_t m_DiskLayerSizeThreshold = 1 * 1024; + uint64_t m_LastScrubTime = 0; + +#if ZEN_USE_CACHE_TRACKER + std::unique_ptr<ZenCacheTracker> m_AccessTracker; +#endif + + ZenCacheNamespace(const ZenCacheNamespace&) = delete; + ZenCacheNamespace& operator=(const ZenCacheNamespace&) = delete; +}; + +class ZenCacheStore final +{ +public: + static constexpr std::string_view DefaultNamespace = + "!default!"; // This is intentionally not a valid namespace name and will only be used for mapping when no namespace is given + static constexpr std::string_view NamespaceDiskPrefix = "ns_"; + + struct Configuration + { + std::filesystem::path BasePath; + bool AllowAutomaticCreationOfNamespaces = false; + }; + + struct Info + { + Configuration Config; + std::vector<std::string> NamespaceNames; + 
uint64_t DiskEntryCount = 0; + uint64_t MemoryEntryCount = 0; + GcStorageSize StorageSize; + }; + + ZenCacheStore(GcManager& Gc, const Configuration& Configuration); + ~ZenCacheStore(); + + bool Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + bool DropBucket(std::string_view Namespace, std::string_view Bucket); + bool DropNamespace(std::string_view Namespace); + void Flush(); + void Scrub(ScrubContext& Ctx); + + CacheValueDetails GetValueDetails(const std::string_view NamespaceFilter, + const std::string_view BucketFilter, + const std::string_view ValueFilter) const; + + GcStorageSize StorageSize() const; + // const Configuration& GetConfiguration() const { return m_Configuration; } + + Info GetInfo() const; + std::optional<ZenCacheNamespace::Info> GetNamespaceInfo(std::string_view Namespace); + std::optional<ZenCacheNamespace::BucketInfo> GetBucketInfo(std::string_view Namespace, std::string_view Bucket); + +private: + const ZenCacheNamespace* FindNamespace(std::string_view Namespace) const; + ZenCacheNamespace* GetNamespace(std::string_view Namespace); + void IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const; + + typedef std::unordered_map<std::string, std::unique_ptr<ZenCacheNamespace>> NamespaceMap; + + mutable RwLock m_NamespacesLock; + NamespaceMap m_Namespaces; + std::vector<std::unique_ptr<ZenCacheNamespace>> m_DroppedNamespaces; + + GcManager& m_Gc; + Configuration m_Configuration; +}; + +void z$_forcelink(); + +} // namespace zen diff --git a/src/zenserver/cidstore.cpp b/src/zenserver/cidstore.cpp new file mode 100644 index 000000000..bce4f1dfb --- /dev/null +++ b/src/zenserver/cidstore.cpp @@ -0,0 +1,124 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "cidstore.h" + +#include <zencore/compress.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zenstore/cidstore.h> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +HttpCidService::HttpCidService(CidStore& Store) : m_CidStore(Store) +{ + m_Router.AddPattern("cid", "([0-9A-Fa-f]{40})"); + + m_Router.RegisterRoute( + "{cid}", + [this](HttpRouterRequest& Req) { + IoHash Hash = IoHash::FromHexString(Req.GetCapture(1)); + ZEN_DEBUG("CID request for {}", Hash); + + HttpServerRequest& ServerRequest = Req.ServerRequest(); + + switch (ServerRequest.RequestVerb()) + { + case HttpVerb::kGet: + case HttpVerb::kHead: + { + if (IoBuffer Value = m_CidStore.FindChunkByCid(Hash)) + { + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); + } + + return ServerRequest.WriteResponse(HttpResponseCode::NotFound); + } + break; + + case HttpVerb::kPut: + { + IoBuffer Payload = ServerRequest.ReadPayload(); + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize)) + { + return ServerRequest.WriteResponse(HttpResponseCode::UnsupportedMediaType); + } + + // URI hash must match content hash + if (RawHash != Hash) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); + } + + m_CidStore.AddChunk(Payload, RawHash); + + return ServerRequest.WriteResponse(HttpResponseCode::OK); + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPut | HttpVerb::kHead); +} + +const char* +HttpCidService::BaseUri() const +{ + return "/cid/"; +} + +void +HttpCidService::HandleRequest(zen::HttpServerRequest& Request) +{ + if (Request.RelativeUri().empty()) + { + // Root URI request + + switch (Request.RequestVerb()) + { + case HttpVerb::kPut: + case HttpVerb::kPost: + { + IoBuffer Payload = Request.ReadPayload(); + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize)) + { + return 
Request.WriteResponse(HttpResponseCode::UnsupportedMediaType); + } + + ZEN_DEBUG("CID POST request for {} ({} bytes)", RawHash, Payload.Size()); + + auto InsertResult = m_CidStore.AddChunk(Payload, RawHash); + + if (InsertResult.New) + { + return Request.WriteResponse(HttpResponseCode::Created); + } + else + { + return Request.WriteResponse(HttpResponseCode::OK); + } + } + break; + + case HttpVerb::kGet: + case HttpVerb::kHead: + break; + + default: + break; + } + } + else + { + m_Router.HandleRequest(Request); + } +} + +} // namespace zen diff --git a/src/zenserver/cidstore.h b/src/zenserver/cidstore.h new file mode 100644 index 000000000..8e7832b35 --- /dev/null +++ b/src/zenserver/cidstore.h @@ -0,0 +1,35 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +/** + * Simple CID store HTTP endpoint + * + * Note that since this does not end up pinning any of the chunks it's only really useful for a small subset of use cases where you know a + * chunk exists in the underlying CID store. Thus it's mainly useful for internal use when communicating between Zen store instances + * + * Using this interface for adding CID chunks makes little sense except for testing purposes as garbage collection may reap anything you add + * before anything ever gets to access it + */ + +class CidStore; + +class HttpCidService : public HttpService +{ +public: + explicit HttpCidService(CidStore& Store); + ~HttpCidService() = default; + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + CidStore& m_CidStore; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/compute/function.cpp b/src/zenserver/compute/function.cpp new file mode 100644 index 000000000..493e2666e --- /dev/null +++ b/src/zenserver/compute/function.cpp @@ -0,0 +1,629 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "function.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <upstream/jupiter.h> +# include <upstream/upstreamapply.h> +# include <upstream/upstreamcache.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/scopeguard.h> +# include <zenstore/cidstore.h> + +# include <span> + +using namespace std::literals; + +namespace zen { + +HttpFunctionService::HttpFunctionService(CidStore& InCidStore, + const CloudCacheClientOptions& ComputeOptions, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const UpstreamAuthConfig& StorageAuthConfig, + AuthMgr& Mgr) +: m_Log(logging::Get("apply")) +, m_CidStore(InCidStore) +{ + m_UpstreamApply = UpstreamApply::Create({}, m_CidStore); + + InitializeThread = std::thread{[this, ComputeOptions, StorageOptions, ComputeAuthConfig, StorageAuthConfig, &Mgr] { + auto HordeUpstreamEndpoint = UpstreamApplyEndpoint::CreateHordeEndpoint(ComputeOptions, + ComputeAuthConfig, + StorageOptions, + StorageAuthConfig, + m_CidStore, + Mgr); + m_UpstreamApply->RegisterEndpoint(std::move(HordeUpstreamEndpoint)); + m_UpstreamApply->Initialize(); + }}; + + m_Router.AddPattern("job", "([[:digit:]]+)"); + m_Router.AddPattern("worker", "([[:xdigit:]]{40})"); + m_Router.AddPattern("action", "([[:xdigit:]]{40})"); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + return HttpReq.WriteResponse(m_UpstreamApply->IsHealthy() ? 
HttpResponseCode::OK : HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + else + { + const WorkerDesc& Desc = It->second; + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor); + } + } + break; + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject FunctionSpec = HttpReq.ReadPayloadObject(); + + // Determine which pieces are missing and need to be transmitted to populate CAS + + HashKeySet ChunkSet; + + FunctionSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + // Note that we store executables uncompressed to make it + // more straightforward and efficient to materialize them, hence + // the CAS lookup here instead of CID for the input payloads + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + RwLock::ExclusiveLockScope _(m_WorkerLock); + + m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{FunctionSpec}); + + ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + else + { + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + + ResponseWriter.AddHash(Hash); + }); + + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + + return 
HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage FunctionSpec = HttpReq.ReadPayloadPackage(); + + CbObject Obj = FunctionSpec.GetObject(); + + std::span<const CbAttachment> Attachments = FunctionSpec.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + RwLock::ExclusiveLockScope _(m_WorkerLock); + + m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{.Descriptor = Obj}); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{job}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + break; + + case HttpVerb::kPost: + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = 
IoHash::FromHexString(Req.GetCapture(1)); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, ActionId, Output); + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "simple/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker; + + { + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + else + { + Worker = It->second; + } + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, Output); + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + + { + RwLock::SharedLockScope _(m_WorkerLock); + m_WorkerMap.erase(WorkerId); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + + case HttpVerb::kPost: + { + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstream(Worker, Output); + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker; + + { + RwLock::SharedLockScope _(m_WorkerLock); + 
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + else + { + Worker = It->second; + } + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject RequestObject = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + RequestObject.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstream(Worker, RequestObject, Output); + + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = 
Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("new action: {} in {} attachments. {} new ({} attachments)", + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstream(Worker, ActionObj, Output); + + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); +} + +HttpFunctionService::~HttpFunctionService() +{ +} + +const char* +HttpFunctionService::BaseUri() const +{ + return "/apply/"; +} + +void +HttpFunctionService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object) +{ + const IoHash WorkerId = Worker.Descriptor.GetHash(); + + ZEN_INFO("Action {} being processed...", WorkerId.ToHexString()); + + auto EnqueueResult = m_UpstreamApply->EnqueueUpstream({.WorkerDescriptor = Worker.Descriptor, .Type = UpstreamApplyType::Simple}); + if (!EnqueueResult.Success) + { + ZEN_ERROR("Error enqueuing upstream Action {}", WorkerId.ToHexString()); + return HttpResponseCode::InternalServerError; + } + + CbObjectWriter Writer; + Writer.AddHash("worker", WorkerId); + + Object = Writer.Save(); + return 
HttpResponseCode::OK; +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object) +{ + const static IoHash Empty = CbObject().GetHash(); + auto Status = m_UpstreamApply->GetStatus(WorkerId, Empty); + if (!Status.Success) + { + return HttpResponseCode::NotFound; + } + + if (Status.Status.State != UpstreamApplyState::Complete) + { + return HttpResponseCode::Accepted; + } + + GetUpstreamApplyResult& Completed = Status.Status.Result; + + if (!Completed.Success) + { + ZEN_ERROR("Action {} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}", + WorkerId.ToHexString(), + Completed.StdOut, + Completed.StdErr, + Completed.Error.Reason, + Completed.Error.ErrorCode); + + if (Completed.Error.ErrorCode == 0) + { + Completed.Error.ErrorCode = -1; + } + if (Completed.StdErr.empty() && !Completed.Error.Reason.empty()) + { + Completed.StdErr = Completed.Error.Reason; + } + } + else + { + ZEN_INFO("Action {} completed with {} files ExitCode={}", + WorkerId.ToHexString(), + Completed.OutputFiles.size(), + Completed.Error.ErrorCode); + } + + CbObjectWriter ResultObject; + + ResultObject.AddString("agent"sv, Completed.Agent); + ResultObject.AddString("detail"sv, Completed.Detail); + ResultObject.AddString("stdout"sv, Completed.StdOut); + ResultObject.AddString("stderr"sv, Completed.StdErr); + ResultObject.AddInteger("exitcode"sv, Completed.Error.ErrorCode); + ResultObject.BeginArray("stats"sv); + for (const auto& Timepoint : Completed.Timepoints) + { + ResultObject.BeginObject(); + ResultObject.AddString("name"sv, Timepoint.first); + ResultObject.AddDateTimeTicks("time"sv, Timepoint.second); + ResultObject.EndObject(); + } + ResultObject.EndArray(); + + ResultObject.BeginArray("files"sv); + for (const auto& File : Completed.OutputFiles) + { + ResultObject.BeginObject(); + ResultObject.AddString("name"sv, File.first.string()); + ResultObject.AddBinary("data"sv, Completed.FileData[File.second]); + ResultObject.EndObject(); + 
} + ResultObject.EndArray(); + + Object = ResultObject.Save(); + return HttpResponseCode::OK; +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object) +{ + const IoHash WorkerId = Worker.Descriptor.GetHash(); + const IoHash ActionId = Action.GetHash(); + + Action.MakeOwned(); + + ZEN_INFO("Action {}/{} being processed...", WorkerId.ToHexString(), ActionId.ToHexString()); + + auto EnqueueResult = m_UpstreamApply->EnqueueUpstream( + {.WorkerDescriptor = Worker.Descriptor, .Action = std::move(Action), .Type = UpstreamApplyType::Asset}); + + if (!EnqueueResult.Success) + { + ZEN_ERROR("Error enqueuing upstream Action {}/{}", WorkerId.ToHexString(), ActionId.ToHexString()); + return HttpResponseCode::InternalServerError; + } + + CbObjectWriter Writer; + Writer.AddHash("worker", WorkerId); + Writer.AddHash("action", ActionId); + + Object = Writer.Save(); + return HttpResponseCode::OK; +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package) +{ + auto Status = m_UpstreamApply->GetStatus(WorkerId, ActionId); + if (!Status.Success) + { + return HttpResponseCode::NotFound; + } + + if (Status.Status.State != UpstreamApplyState::Complete) + { + return HttpResponseCode::Accepted; + } + + GetUpstreamApplyResult& Completed = Status.Status.Result; + if (!Completed.Success || Completed.Error.ErrorCode != 0) + { + ZEN_ERROR("Action {}/{} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}", + WorkerId.ToHexString(), + ActionId.ToHexString(), + Completed.StdOut, + Completed.StdErr, + Completed.Error.Reason, + Completed.Error.ErrorCode); + + return HttpResponseCode::InternalServerError; + } + + ZEN_INFO("Action {}/{} completed with {} attachments ({} compressed, {} uncompressed)", + WorkerId.ToHexString(), + ActionId.ToHexString(), + Completed.OutputPackage.GetAttachments().size(), + NiceBytes(Completed.TotalAttachmentBytes), + 
NiceBytes(Completed.TotalRawAttachmentBytes)); + + Package = std::move(Completed.OutputPackage); + return HttpResponseCode::OK; +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/function.h b/src/zenserver/compute/function.h new file mode 100644 index 000000000..650cee757 --- /dev/null +++ b/src/zenserver/compute/function.h @@ -0,0 +1,73 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zenhttp/httpserver.h> + +# include <filesystem> +# include <unordered_map> + +namespace zen { + +class CidStore; +class UpstreamApply; +class CloudCacheClient; +class AuthMgr; + +struct UpstreamAuthConfig; +struct CloudCacheClientOptions; + +/** + * Lambda style compute function service + */ +class HttpFunctionService : public HttpService +{ +public: + HttpFunctionService(CidStore& InCidStore, + const CloudCacheClientOptions& ComputeOptions, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const UpstreamAuthConfig& StorageAuthConfig, + AuthMgr& Mgr); + ~HttpFunctionService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + +private: + std::thread InitializeThread; + spdlog::logger& Log() { return m_Log; } + spdlog::logger& m_Log; + HttpRequestRouter m_Router; + CidStore& m_CidStore; + std::unique_ptr<UpstreamApply> m_UpstreamApply; + + struct WorkerDesc + { + CbObject Descriptor; + }; + + [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object); + [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object); + + [[nodiscard]] HttpResponseCode 
#if ZEN_PLATFORM_WINDOWS

// Used for resolving the ProgramData known folder for the default data directory
#	include <ShlObj.h>
#	pragma comment(lib, "shell32.lib")

/**
 * Pick the default state (data) directory used when none was configured.
 *
 * @return "<ProgramData>\Epic\Zen\Data", or an empty path if the known folder
 *         could not be resolved.
 */
std::filesystem::path
PickDefaultStateDirectory()
{
	// Pick sensible default
	PWSTR	programDataDir = nullptr;
	HRESULT hRes		   = SHGetKnownFolderPath(FOLDERID_ProgramData, 0, nullptr, &programDataDir);

	std::filesystem::path finalPath;

	if (SUCCEEDED(hRes))
	{
		finalPath = programDataDir;
		finalPath /= L"Epic\\Zen\\Data";
	}

	// The buffer must be released whether or not the call succeeded; freeing a
	// null pointer is a no-op
	::CoTaskMemFree(programDataDir);

	return finalPath;
}

#else

/**
 * Pick the default state (data) directory used when none was configured.
 *
 * @return "<home>/.zen" where <home> is the home directory recorded in the
 *         password database for the current user, or an empty path if no
 *         such entry (or no home directory) is available.
 */
std::filesystem::path
PickDefaultStateDirectory()
{
	// getpwuid() returns null when the user has no passwd entry (not unusual in
	// containers running under an arbitrary uid) or when the lookup fails.
	// Mirror the Windows variant and report "no default" instead of crashing.
	const passwd* Passwd = getpwuid(getuid());

	if (Passwd == nullptr || Passwd->pw_dir == nullptr)
	{
		return {};
	}

	return std::filesystem::path(Passwd->pw_dir) / ".zen";
}

#endif
(Key.IsValid() == false) + { + throw cxxopts::OptionParseException("Invalid AES encryption key"); + } + } + + if (ServerOptions.EncryptionIV.empty() == false) + { + const auto IV = zen::AesIV128Bit::FromString(ServerOptions.EncryptionIV); + + if (IV.IsValid() == false) + { + throw cxxopts::OptionParseException("Invalid AES initialization vector"); + } + } +} + +UpstreamCachePolicy +ParseUpstreamCachePolicy(std::string_view Options) +{ + if (Options == "readonly") + { + return UpstreamCachePolicy::Read; + } + else if (Options == "writeonly") + { + return UpstreamCachePolicy::Write; + } + else if (Options == "disabled") + { + return UpstreamCachePolicy::Disabled; + } + else + { + return UpstreamCachePolicy::ReadWrite; + } +} + +ZenObjectStoreConfig +ParseBucketConfigs(std::span<std::string> Buckets) +{ + using namespace std::literals; + + ZenObjectStoreConfig Cfg; + + // split bucket args in the form of "{BucketName};{LocalPath}" + for (std::string_view Bucket : Buckets) + { + ZenObjectStoreConfig::BucketConfig NewBucket; + + if (auto Idx = Bucket.find_first_of(";"); Idx != std::string_view::npos) + { + NewBucket.Name = Bucket.substr(0, Idx); + NewBucket.Directory = Bucket.substr(Idx + 1); + } + else + { + NewBucket.Name = Bucket; + } + + Cfg.Buckets.push_back(std::move(NewBucket)); + } + + return Cfg; +} + +void ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions); + +void +ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions) +{ +#if ZEN_WITH_HTTPSYS + const char* DefaultHttp = "httpsys"; +#else + const char* DefaultHttp = "asio"; +#endif + + // Note to those adding future options; std::filesystem::path-type options + // must be read into a std::string first. As of cxxopts-3.0.0 it uses a >> + // stream operator to convert argv value into the options type. std::fs::path + // expects paths in streams to be quoted but argv paths are unquoted. 
By + // going into a std::string first, paths with whitespace parse correctly. + std::string DataDir; + std::string ContentDir; + std::string AbsLogFile; + std::string ConfigFile; + + cxxopts::Options options("zenserver", "Zen Server"); + options.add_options()("dedicated", + "Enable dedicated server mode", + cxxopts::value<bool>(ServerOptions.IsDedicated)->default_value("false")); + options.add_options()("d, debug", "Enable debugging", cxxopts::value<bool>(ServerOptions.IsDebug)->default_value("false")); + options.add_options()("help", "Show command line help"); + options.add_options()("t, test", "Enable test mode", cxxopts::value<bool>(ServerOptions.IsTest)->default_value("false")); + options.add_options()("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(ServerOptions.LogId)); + options.add_options()("data-dir", "Specify persistence root", cxxopts::value<std::string>(DataDir)); + options.add_options()("content-dir", "Frontend content directory", cxxopts::value<std::string>(ContentDir)); + options.add_options()("abslog", "Path to log file", cxxopts::value<std::string>(AbsLogFile)); + options.add_options()("config", "Path to Lua config file", cxxopts::value<std::string>(ConfigFile)); + options.add_options()("no-sentry", + "Disable Sentry crash handler", + cxxopts::value<bool>(ServerOptions.NoSentry)->default_value("false")); + + options.add_option("security", + "", + "encryption-aes-key", + "256 bit AES encryption key", + cxxopts::value<std::string>(ServerOptions.EncryptionKey), + ""); + + options.add_option("security", + "", + "encryption-aes-iv", + "128 bit AES encryption initialization vector", + cxxopts::value<std::string>(ServerOptions.EncryptionIV), + ""); + + std::string OpenIdProviderName; + options.add_option("security", + "", + "openid-provider-name", + "Open ID provider name", + cxxopts::value<std::string>(OpenIdProviderName), + "Default"); + + std::string OpenIdProviderUrl; + options.add_option("security", "", 
"openid-provider-url", "Open ID provider URL", cxxopts::value<std::string>(OpenIdProviderUrl), ""); + + std::string OpenIdClientId; + options.add_option("security", "", "openid-client-id", "Open ID client ID", cxxopts::value<std::string>(OpenIdClientId), ""); + + options + .add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value<int>(ServerOptions.OwnerPid), "<identifier>"); + options.add_option("lifetime", + "", + "child-id", + "Specify id which can be used to signal parent", + cxxopts::value<std::string>(ServerOptions.ChildId), + "<identifier>"); + +#if ZEN_PLATFORM_WINDOWS + options.add_option("lifetime", + "", + "install", + "Install zenserver as a Windows service", + cxxopts::value<bool>(ServerOptions.InstallService), + ""); + options.add_option("lifetime", + "", + "uninstall", + "Uninstall zenserver as a Windows service", + cxxopts::value<bool>(ServerOptions.UninstallService), + ""); +#endif + + options.add_option("network", + "", + "http", + "Select HTTP server implementation (asio|httpsys|null)", + cxxopts::value<std::string>(ServerOptions.HttpServerClass)->default_value(DefaultHttp), + "<http class>"); + + options.add_option("network", + "p", + "port", + "Select HTTP port", + cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"), + "<port number>"); + + options.add_option("network", + "", + "websocket-port", + "Websocket server port", + cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"), + "<port number>"); + + options.add_option("network", + "", + "websocket-threads", + "Number of websocket I/O thread(s) (0 == hardware concurrency)", + cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"), + ""); + +#if ZEN_WITH_TRACE + options.add_option("ue-trace", + "", + "tracehost", + "Hostname to send the trace to", + cxxopts::value<std::string>(ServerOptions.TraceHost)->default_value(""), + ""); + + options.add_option("ue-trace", + "", + "tracefile", + "Path to write a trace to", + 
cxxopts::value<std::string>(ServerOptions.TraceFile)->default_value(""), + ""); +#endif // ZEN_WITH_TRACE + + options.add_option("diagnostics", + "", + "crash", + "Simulate a crash", + cxxopts::value<bool>(ServerOptions.ShouldCrash)->default_value("false"), + ""); + + std::string UpstreamCachePolicyOptions; + options.add_option("cache", + "", + "upstream-cache-policy", + "", + cxxopts::value<std::string>(UpstreamCachePolicyOptions)->default_value(""), + "Upstream cache policy (readwrite|readonly|writeonly|disabled)"); + + options.add_option("cache", + "", + "upstream-jupiter-url", + "URL to a Jupiter instance", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Url)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-oauth-url", + "URL to the OAuth provier", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-oauth-clientid", + "The OAuth client ID", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-oauth-clientsecret", + "The OAuth client secret", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-openid-provider", + "Name of a registered Open ID provider", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-token", + "A static authentication token", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-namespace", + "The Common Blob Store API namespace", + 
cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-namespace-ddc", + "The lecacy DDC namespace", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-zen-url", + "URL to remote Zen server. Use a comma separated list to choose the one with the best latency.", + cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Urls), + ""); + + options.add_option("cache", + "", + "upstream-zen-dns", + "DNS that resolves to one or more Zen server instance(s)", + cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Dns), + ""); + + options.add_option("cache", + "", + "upstream-thread-count", + "Number of threads used for upstream procsssing", + cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.UpstreamThreadCount)->default_value("4"), + ""); + + options.add_option("cache", + "", + "upstream-connect-timeout-ms", + "Connect timeout in millisecond(s). Default 5000 ms.", + cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.ConnectTimeoutMilliseconds)->default_value("5000"), + ""); + + options.add_option("cache", + "", + "upstream-timeout-ms", + "Timeout in millisecond(s). 
Default 0 ms", + cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.TimeoutMilliseconds)->default_value("0"), + ""); + + options.add_option("compute", + "", + "upstream-horde-url", + "URL to a Horde instance.", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Url)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-oauth-url", + "URL to the OAuth provier", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-oauth-clientid", + "The OAuth client ID", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-oauth-clientsecret", + "The OAuth client secret", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-openid-provider", + "Name of a registered Open ID provider", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-token", + "A static authentication token", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-url", + "URL to a Horde Storage instance.", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-oauth-url", + "URL to the OAuth provier", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-oauth-clientid", + "The OAuth client 
ID", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId)->default_value(""), + ""); + + options.add_option( + "compute", + "", + "upstream-horde-storage-oauth-clientsecret", + "The OAuth client secret", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-openid-provider", + "Name of a registered Open ID provider", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-token", + "A static authentication token", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-cluster", + "The Horde compute cluster id", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-namespace", + "The Jupiter namespace to use with Horde compute", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace)->default_value(""), + ""); + + options.add_option("gc", + "", + "gc-enabled", + "Whether garbage collection is enabled or not.", + cxxopts::value<bool>(ServerOptions.GcConfig.Enabled)->default_value("true"), + ""); + + options.add_option("gc", + "", + "gc-small-objects", + "Whether garbage collection of small objects is enabled or not.", + cxxopts::value<bool>(ServerOptions.GcConfig.CollectSmallObjects)->default_value("true"), + ""); + + options.add_option("gc", + "", + "gc-interval-seconds", + "Garbage collection interval in seconds. 
Default set to 3600 (1 hour).", + cxxopts::value<int32_t>(ServerOptions.GcConfig.IntervalSeconds)->default_value("3600"), + ""); + + options.add_option("gc", + "", + "gc-cache-duration-seconds", + "Max duration in seconds before Z$ entries get evicted. Default set to 1209600 (2 weeks)", + cxxopts::value<int32_t>(ServerOptions.GcConfig.Cache.MaxDurationSeconds)->default_value("1209600"), + ""); + + options.add_option("gc", + "", + "disk-reserve-size", + "Size of gc disk reserve in bytes. Default set to 268435456 (256 Mb).", + cxxopts::value<uint64_t>(ServerOptions.GcConfig.DiskReserveSize)->default_value("268435456"), + ""); + + options.add_option("gc", + "", + "gc-monitor-interval-seconds", + "Garbage collection monitoring interval in seconds. Default set to 30 (30 seconds)", + cxxopts::value<int32_t>(ServerOptions.GcConfig.MonitorIntervalSeconds)->default_value("30"), + ""); + + options.add_option("gc", + "", + "gc-disksize-softlimit", + "Garbage collection disk usage soft limit. Default set to 0 (Off).", + cxxopts::value<uint64_t>(ServerOptions.GcConfig.Cache.DiskSizeSoftLimit)->default_value("0"), + ""); + + options.add_option("objectstore", + "", + "objectstore-enabled", + "Whether the object store is enabled or not.", + cxxopts::value<bool>(ServerOptions.ObjectStoreEnabled)->default_value("false"), + ""); + + std::vector<std::string> BucketConfigs; + options.add_option("objectstore", + "", + "objectstore-bucket", + "Object store bucket mappings.", + cxxopts::value<std::vector<std::string>>(BucketConfigs), + ""); + + try + { + auto result = options.parse(argc, argv); + + if (result.count("help")) + { + zen::logging::ConsoleLog().info("{}", options.help()); +#if ZEN_PLATFORM_WINDOWS + zen::logging::ConsoleLog().info("Press any key to exit!"); + _getch(); +#else + // Assume the user's in a terminal on all other platforms and that + // they'll use less/more/etc. if need be. 
+#endif + exit(0); + } + + auto MakeSafePath = [](const std::string& Path) { +#if ZEN_PLATFORM_WINDOWS + if (Path.empty()) + { + return Path; + } + + std::string FixedPath = Path; + std::replace(FixedPath.begin(), FixedPath.end(), '/', '\\'); + if (!FixedPath.starts_with("\\\\?\\")) + { + FixedPath.insert(0, "\\\\?\\"); + } + return FixedPath; +#else + return Path; +#endif + }; + + ServerOptions.DataDir = MakeSafePath(DataDir); + ServerOptions.ContentDir = MakeSafePath(ContentDir); + ServerOptions.AbsLogFile = MakeSafePath(AbsLogFile); + ServerOptions.ConfigFile = MakeSafePath(ConfigFile); + ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(UpstreamCachePolicyOptions); + + if (OpenIdProviderUrl.empty() == false) + { + if (OpenIdClientId.empty()) + { + throw cxxopts::OptionParseException("Invalid OpenID client ID"); + } + + ServerOptions.AuthConfig.OpenIdProviders.push_back( + {.Name = OpenIdProviderName, .Url = OpenIdProviderUrl, .ClientId = OpenIdClientId}); + } + + ServerOptions.ObjectStoreConfig = ParseBucketConfigs(BucketConfigs); + + if (!ServerOptions.ConfigFile.empty()) + { + ParseConfigFile(ServerOptions.ConfigFile, ServerOptions); + } + else + { + ParseConfigFile(ServerOptions.DataDir / "zen_cfg.lua", ServerOptions); + } + + ValidateOptions(ServerOptions); + } + catch (cxxopts::OptionParseException& e) + { + zen::logging::ConsoleLog().error("Error parsing zenserver arguments: {}\n\n{}", e.what(), options.help()); + + throw; + } + + if (ServerOptions.DataDir.empty()) + { + ServerOptions.DataDir = PickDefaultStateDirectory(); + } + + if (ServerOptions.AbsLogFile.empty()) + { + ServerOptions.AbsLogFile = ServerOptions.DataDir / "logs" / "zenserver.log"; + } +} + +void +ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions) +{ + zen::IoBuffer LuaScript = zen::IoBufferBuilder::MakeFromFile(Path); + + if (LuaScript) + { + sol::state lua; + + lua.open_libraries(sol::lib::base); + + lua.set_function("getenv", 
[&](const std::string env) -> sol::object { +#if ZEN_PLATFORM_WINDOWS + std::wstring EnvVarValue; + size_t RequiredSize = 0; + std::wstring EnvWide = zen::Utf8ToWide(env); + _wgetenv_s(&RequiredSize, nullptr, 0, EnvWide.c_str()); + + if (RequiredSize == 0) + return sol::make_object(lua, sol::lua_nil); + + EnvVarValue.resize(RequiredSize); + _wgetenv_s(&RequiredSize, EnvVarValue.data(), RequiredSize, EnvWide.c_str()); + return sol::make_object(lua, zen::WideToUtf8(EnvVarValue.c_str())); +#else + ZEN_UNUSED(env); + return sol::make_object(lua, sol::lua_nil); +#endif + }); + + try + { + sol::load_result config = lua.load(std::string_view((const char*)LuaScript.Data(), LuaScript.Size()), "zen_cfg"); + + if (!config.valid()) + { + sol::error err = config; + + std::string ErrorString = sol::to_string(config.status()); + + throw std::runtime_error(fmt::format("{} error: {}", ErrorString, err.what())); + } + + config(); + } + catch (std::exception& e) + { + throw std::runtime_error(fmt::format("failed to load config script ('{}'): {}", Path, e.what()).c_str()); + } + + if (sol::optional<sol::table> ServerConfig = lua["server"]) + { + if (ServerOptions.DataDir.empty()) + { + if (sol::optional<std::string> Opt = ServerConfig.value()["datadir"]) + { + ServerOptions.DataDir = Opt.value(); + } + } + + if (ServerOptions.ContentDir.empty()) + { + if (sol::optional<std::string> Opt = ServerConfig.value()["contentdir"]) + { + ServerOptions.ContentDir = Opt.value(); + } + } + + if (ServerOptions.AbsLogFile.empty()) + { + if (sol::optional<std::string> Opt = ServerConfig.value()["abslog"]) + { + ServerOptions.AbsLogFile = Opt.value(); + } + } + + ServerOptions.IsDebug = ServerConfig->get_or("debug", ServerOptions.IsDebug); + } + + if (sol::optional<sol::table> NetworkConfig = lua["network"]) + { + if (sol::optional<std::string> Opt = NetworkConfig.value()["httpserverclass"]) + { + ServerOptions.HttpServerClass = Opt.value(); + } + + ServerOptions.BasePort = 
NetworkConfig->get_or<int>("port", ServerOptions.BasePort); + } + + auto UpdateStringValueFromConfig = [](const sol::table& Table, std::string_view Key, std::string& OutValue) { + // Update the specified config value unless it has been set, i.e. from command line + if (auto MaybeValue = Table.get<sol::optional<std::string>>(Key); MaybeValue.has_value() && OutValue.empty()) + { + OutValue = MaybeValue.value(); + } + }; + + if (sol::optional<sol::table> StructuredCacheConfig = lua["cache"]) + { + ServerOptions.StructuredCacheEnabled = StructuredCacheConfig->get_or("enable", ServerOptions.StructuredCacheEnabled); + + if (auto UpstreamConfig = StructuredCacheConfig->get<sol::optional<sol::table>>("upstream")) + { + std::string Policy = UpstreamConfig->get_or("policy", std::string()); + ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(Policy); + ServerOptions.UpstreamCacheConfig.UpstreamThreadCount = + UpstreamConfig->get_or("upstreamthreadcount", ServerOptions.UpstreamCacheConfig.UpstreamThreadCount); + + if (auto JupiterConfig = UpstreamConfig->get<sol::optional<sol::table>>("jupiter")) + { + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("name"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.Name); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("url"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.Url); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("oauthprovider"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("oauthclientid"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("oauthclientsecret"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("openidprovider"), + 
ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("token"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("namespace"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("ddcnamespace"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace); + }; + + if (auto ZenConfig = UpstreamConfig->get<sol::optional<sol::table>>("zen")) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Name = ZenConfig.value().get_or("name", std::string("Zen")); + + if (auto Url = ZenConfig.value().get<sol::optional<std::string>>("url")) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Url.value()); + } + else if (auto Urls = ZenConfig.value().get<sol::optional<sol::table>>("url")) + { + for (const auto& Kv : Urls.value()) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Kv.second.as<std::string>()); + } + } + + if (auto Dns = ZenConfig.value().get<sol::optional<std::string>>("dns")) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Dns.value()); + } + else if (auto DnsArray = ZenConfig.value().get<sol::optional<sol::table>>("dns")) + { + for (const auto& Kv : DnsArray.value()) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Kv.second.as<std::string>()); + } + } + } + } + } + + if (sol::optional<sol::table> ExecConfig = lua["exec"]) + { + ServerOptions.ExecServiceEnabled = ExecConfig->get_or("enable", ServerOptions.ExecServiceEnabled); + } + + if (sol::optional<sol::table> ComputeConfig = lua["compute"]) + { + ServerOptions.ComputeServiceEnabled = ComputeConfig->get_or("enable", ServerOptions.ComputeServiceEnabled); + + if (auto UpstreamConfig = ComputeConfig->get<sol::optional<sol::table>>("upstream")) + { + if (auto HordeConfig = 
UpstreamConfig->get<sol::optional<sol::table>>("horde")) + { + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("name"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Name); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("url"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Url); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("oauthprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("oauthclientid"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("oauthclientsecret"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("openidprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("token"), + ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("cluster"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("namespace"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace); + }; + + if (auto StorageConfig = UpstreamConfig->get<sol::optional<sol::table>>("storage")) + { + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("url"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("oauthprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("oauthclientid"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId); + UpdateStringValueFromConfig(StorageConfig.value(), + 
std::string_view("oauthclientsecret"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("openidprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("token"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken); + }; + } + } + + if (sol::optional<sol::table> GcConfig = lua["gc"]) + { + ServerOptions.GcConfig.MonitorIntervalSeconds = GcConfig.value().get_or("monitorintervalseconds", 30); + ServerOptions.GcConfig.IntervalSeconds = GcConfig.value().get_or("intervalseconds", 0); + ServerOptions.GcConfig.DiskReserveSize = GcConfig.value().get_or("diskreservesize", uint64_t(1u << 28)); + + if (sol::optional<sol::table> CacheGcConfig = GcConfig.value()["cache"]) + { + ServerOptions.GcConfig.Cache.MaxDurationSeconds = CacheGcConfig.value().get_or("maxdurationseconds", int32_t(0)); + ServerOptions.GcConfig.Cache.DiskSizeLimit = CacheGcConfig.value().get_or("disksizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cache.MemorySizeLimit = CacheGcConfig.value().get_or("memorysizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cache.DiskSizeSoftLimit = CacheGcConfig.value().get_or("disksizesoftlimit", 0); + } + + if (sol::optional<sol::table> CasGcConfig = GcConfig.value()["cas"]) + { + ServerOptions.GcConfig.Cas.LargeStrategySizeLimit = CasGcConfig.value().get_or("largestrategysizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cas.SmallStrategySizeLimit = CasGcConfig.value().get_or("smallstrategysizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cas.TinyStrategySizeLimit = CasGcConfig.value().get_or("tinystrategysizelimit", ~uint64_t(0)); + } + } + + if (sol::optional<sol::table> SecurityConfig = lua["security"]) + { + if (sol::optional<sol::table> OpenIdProviders = SecurityConfig.value()["openidproviders"]) + { + for (const auto& Kv : OpenIdProviders.value()) + { 
+ if (sol::optional<sol::table> OpenIdProvider = Kv.second.as<sol::table>()) + { + std::string Name = OpenIdProvider.value().get_or("name", std::string("Default")); + std::string Url = OpenIdProvider.value().get_or("url", std::string()); + std::string ClientId = OpenIdProvider.value().get_or("clientid", std::string()); + + ServerOptions.AuthConfig.OpenIdProviders.push_back( + {.Name = std::move(Name), .Url = std::move(Url), .ClientId = std::move(ClientId)}); + } + } + } + + ServerOptions.EncryptionKey = SecurityConfig.value().get_or("encryptionaeskey", std::string()); + ServerOptions.EncryptionIV = SecurityConfig.value().get_or("encryptionaesiv", std::string()); + } + } +} diff --git a/src/zenserver/config.h b/src/zenserver/config.h new file mode 100644 index 000000000..8a5c6de4e --- /dev/null +++ b/src/zenserver/config.h @@ -0,0 +1,158 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> +#include <filesystem> +#include <string> +#include <vector> + +struct ZenUpstreamJupiterConfig +{ + std::string Name; + std::string Url; + std::string OAuthUrl; + std::string OAuthClientId; + std::string OAuthClientSecret; + std::string OpenIdProvider; + std::string AccessToken; + std::string Namespace; + std::string DdcNamespace; +}; + +struct ZenUpstreamHordeConfig +{ + std::string Name; + std::string Url; + std::string OAuthUrl; + std::string OAuthClientId; + std::string OAuthClientSecret; + std::string OpenIdProvider; + std::string AccessToken; + + std::string StorageUrl; + std::string StorageOAuthUrl; + std::string StorageOAuthClientId; + std::string StorageOAuthClientSecret; + std::string StorageOpenIdProvider; + std::string StorageAccessToken; + + std::string Cluster; + std::string Namespace; +}; + +struct ZenUpstreamZenConfig +{ + std::string Name; + std::vector<std::string> Urls; + std::vector<std::string> Dns; +}; + +enum class UpstreamCachePolicy : uint8_t +{ + Disabled = 0, + Read = 1 << 0, + Write = 1 << 1, + 
ReadWrite = Read | Write +}; + +struct ZenUpstreamCacheConfig +{ + ZenUpstreamJupiterConfig JupiterConfig; + ZenUpstreamHordeConfig HordeConfig; + ZenUpstreamZenConfig ZenConfig; + int32_t UpstreamThreadCount = 4; + int32_t ConnectTimeoutMilliseconds = 5000; + int32_t TimeoutMilliseconds = 0; + UpstreamCachePolicy CachePolicy = UpstreamCachePolicy::ReadWrite; +}; + +struct ZenCacheEvictionPolicy +{ + uint64_t DiskSizeLimit = ~uint64_t(0); + uint64_t MemorySizeLimit = 1024 * 1024 * 1024; + int32_t MaxDurationSeconds = 24 * 60 * 60; + uint64_t DiskSizeSoftLimit = 0; + bool Enabled = true; +}; + +struct ZenCasEvictionPolicy +{ + uint64_t LargeStrategySizeLimit = ~uint64_t(0); + uint64_t SmallStrategySizeLimit = ~uint64_t(0); + uint64_t TinyStrategySizeLimit = ~uint64_t(0); + bool Enabled = true; +}; + +struct ZenGcConfig +{ + ZenCasEvictionPolicy Cas; + ZenCacheEvictionPolicy Cache; + int32_t MonitorIntervalSeconds = 30; + int32_t IntervalSeconds = 0; + bool CollectSmallObjects = true; + bool Enabled = true; + uint64_t DiskReserveSize = 1ul << 28; +}; + +struct ZenOpenIdProviderConfig +{ + std::string Name; + std::string Url; + std::string ClientId; +}; + +struct ZenAuthConfig +{ + std::vector<ZenOpenIdProviderConfig> OpenIdProviders; +}; + +struct ZenObjectStoreConfig +{ + struct BucketConfig + { + std::string Name; + std::filesystem::path Directory; + }; + + std::vector<BucketConfig> Buckets; +}; + +struct ZenServerOptions +{ + ZenUpstreamCacheConfig UpstreamCacheConfig; + ZenGcConfig GcConfig; + ZenAuthConfig AuthConfig; + ZenObjectStoreConfig ObjectStoreConfig; + std::filesystem::path DataDir; // Root directory for state (used for testing) + std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) + std::filesystem::path AbsLogFile; // Absolute path to main log file + std::filesystem::path ConfigFile; // Path to Lua config file + std::string ChildId; // Id assigned by parent process (used for lifetime management) + 
std::string LogId; // Id for tagging log output + std::string HttpServerClass; // Choice of HTTP server implementation + std::string EncryptionKey; // 256 bit AES encryption key + std::string EncryptionIV; // 128 bit AES initialization vector + int BasePort = 1337; // Service listen port (used for both UDP and TCP) + int OwnerPid = 0; // Parent process id (zero for standalone) + int WebSocketPort = 0; // Web socket port (Zero = disabled) + int WebSocketThreads = 0; + bool InstallService = false; // Flag used to initiate service install (temporary) + bool UninstallService = false; // Flag used to initiate service uninstall (temporary) + bool IsDebug = false; + bool IsTest = false; + bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements + bool StructuredCacheEnabled = true; + bool ExecServiceEnabled = true; + bool ComputeServiceEnabled = true; + bool ShouldCrash = false; // Option for testing crash handling + bool IsFirstRun = false; + bool NoSentry = false; + bool ObjectStoreEnabled = false; +#if ZEN_WITH_TRACE + std::string TraceHost; // Host name or IP address to send trace data to + std::string TraceFile; // Path of a file to write a trace +#endif +}; + +void ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions); diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp new file mode 100644 index 000000000..29ad5c3dd --- /dev/null +++ b/src/zenserver/diag/diagsvcs.cpp @@ -0,0 +1,127 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "diagsvcs.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/config.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <fstream> +#include <sstream> + +#include <json11.hpp> + +namespace zen { + +using namespace std::literals; + +bool +ReadFile(const std::string& Path, StringBuilderBase& Out) +{ + try + { + constexpr auto ReadSize = std::size_t{4096}; + auto FileStream = std::ifstream{Path}; + + std::string Buf(ReadSize, '\0'); + while (FileStream.read(&Buf[0], ReadSize)) + { + Out.Append(std::string_view(&Buf[0], FileStream.gcount())); + } + Out.Append(std::string_view(&Buf[0], FileStream.gcount())); + + return true; + } + catch (std::exception&) + { + Out.Reset(); + return false; + } +} + +HttpHealthService::HttpHealthService() +{ + m_Router.RegisterRoute( + "", + [](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "info", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + + CbObjectWriter Writer; + Writer << "DataRoot"sv << m_HealthInfo.DataRoot.string(); + Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string(); + Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion; + Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass; + + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "log", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + + zen::Log().flush(); + + std::filesystem::path Path = + m_HealthInfo.AbsLogPath.empty() ? 
m_HealthInfo.DataRoot / "logs/zenserver.log" : m_HealthInfo.AbsLogPath; + + ExtendableStringBuilder<4096> Sb; + if (ReadFile(Path.string(), Sb) && Sb.Size() > 0) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Sb.ToView()); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kGet); + m_Router.RegisterRoute( + "version", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + if (HttpReq.GetQueryParams().GetValue("detailed") == "true") + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION_BUILD_STRING_FULL); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION); + } + }, + HttpVerb::kGet); +} + +void +HttpHealthService::SetHealthInfo(HealthServiceInfo&& Info) +{ + m_HealthInfo = std::move(Info); +} + +const char* +HttpHealthService::BaseUri() const +{ + return "/health/"; +} + +void +HttpHealthService::HandleRequest(HttpServerRequest& Request) +{ + if (!m_Router.HandleRequest(Request)) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); + } +} + +} // namespace zen diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h new file mode 100644 index 000000000..bd03f8023 --- /dev/null +++ b/src/zenserver/diag/diagsvcs.h @@ -0,0 +1,111 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/iobuffer.h> +#include <zenhttp/httpserver.h> + +#include <filesystem> + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +class HttpTestService : public HttpService +{ + uint32_t LogPoint = 0; + +public: + HttpTestService() {} + ~HttpTestService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + + virtual void HandleRequest(HttpServerRequest& Request) override + { + using namespace std::literals; + + auto Uri = Request.RelativeUri(); + + if (Uri == "hello"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"hello world!"sv); + + // OutputLogMessageInternal(&LogPoint, 0, 0); + } + else if (Uri == "1K"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1k); + } + else if (Uri == "1M"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1m); + } + else if (Uri == "1M_1k"sv) + { + std::vector<IoBuffer> Buffers; + Buffers.reserve(1024); + + for (int i = 0; i < 1024; ++i) + { + Buffers.push_back(m_1k); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + else if (Uri == "1G"sv) + { + std::vector<IoBuffer> Buffers; + Buffers.reserve(1024); + + for (int i = 0; i < 1024; ++i) + { + Buffers.push_back(m_1m); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + else if (Uri == "1G_1k"sv) + { + std::vector<IoBuffer> Buffers; + Buffers.reserve(1024 * 1024); + + for (int i = 0; i < 1024 * 1024; ++i) + { + Buffers.push_back(m_1k); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + } + +private: + IoBuffer m_1m{1024 * 1024}; + IoBuffer m_1k{m_1m, 0u, 1024}; +}; + +struct HealthServiceInfo +{ + std::filesystem::path DataRoot; + std::filesystem::path AbsLogPath; + std::string HttpServerClass; + std::string BuildVersion; +}; + +class HttpHealthService : public 
HttpService +{ +public: + HttpHealthService(); + ~HttpHealthService() = default; + + void SetHealthInfo(HealthServiceInfo&& Info); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override final; + +private: + HttpRequestRouter m_Router; + HealthServiceInfo m_HealthInfo; +}; + +} // namespace zen diff --git a/src/zenserver/diag/formatters.h b/src/zenserver/diag/formatters.h new file mode 100644 index 000000000..759df58d3 --- /dev/null +++ b/src/zenserver/diag/formatters.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/iobuffer.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +template<> +struct fmt::formatter<cpr::Response> +{ + constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); } + + template<typename FormatContext> + auto format(const cpr::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out()) + { + using namespace std::literals; + + if (Response.status_code == 200 || Response.status_code == 201) + { + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", + Response.url.str(), + Response.status_code, + Response.uploaded_bytes, + Response.downloaded_bytes, + Response.elapsed); + } + else + { + const auto It = Response.header.find("Content-Type"); + const std::string_view ContentType = It != Response.header.end() ? 
It->second : "<None>"sv; + + if (ContentType == "application/x-ue-cb"sv) + { + zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + zen::CbObjectView Obj(Body.Data()); + zen::ExtendableStringBuilder<256> Sb; + std::string_view Json = Obj.ToJson(Sb).ToView(); + + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Response: '{}', Reason: '{}'", + Response.url.str(), + Response.status_code, + Response.uploaded_bytes, + Response.downloaded_bytes, + Response.elapsed, + Json, + Response.reason); + } + else + { + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Reponse: '{}', Reason: '{}'", + Response.url.str(), + Response.status_code, + Response.uploaded_bytes, + Response.downloaded_bytes, + Response.elapsed, + Response.text, + Response.reason); + } + } + } +}; diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp new file mode 100644 index 000000000..24c7572f4 --- /dev/null +++ b/src/zenserver/diag/logging.cpp @@ -0,0 +1,467 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "logging.h" + +#include "config.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <spdlog/async.h> +#include <spdlog/async_logger.h> +#include <spdlog/pattern_formatter.h> +#include <spdlog/sinks/ansicolor_sink.h> +#include <spdlog/sinks/basic_file_sink.h> +#include <spdlog/sinks/daily_file_sink.h> +#include <spdlog/sinks/msvc_sink.h> +#include <spdlog/sinks/rotating_file_sink.h> +#include <spdlog/sinks/stdout_color_sinks.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <zencore/compactbinary.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> + +#include <chrono> +#include <memory> + +// Custom logging -- test code, this should be tweaked + +namespace logging { + +using namespace spdlog; +using namespace spdlog::details; +using namespace std::literals; + +class full_formatter final : public spdlog::formatter +{ +public: + full_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId) {} + + virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<full_formatter>(m_LogId, m_Epoch); } + + static constexpr bool UseDate = false; + + virtual void format(const details::log_msg& msg, memory_buf_t& dest) override + { + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + if constexpr (UseDate) + { + auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch()); + if (secs != m_LastLogSecs) + { + m_CachedTm = os::localtime(log_clock::to_time_t(msg.time)); + m_LastLogSecs = secs; + } + } + + const auto& tm_time = m_CachedTm; + + // cache the date/time part for the next second. 
+ auto duration = msg.time - m_Epoch; + auto secs = duration_cast<seconds>(duration); + + if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) + { + m_CachedDatetime.clear(); + m_CachedDatetime.push_back('['); + + if constexpr (UseDate) + { + fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + m_CachedDatetime.push_back(' '); + + fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + } + else + { + int Count = int(secs.count()); + + const int LogSecs = Count % 60; + Count /= 60; + + const int LogMins = Count % 60; + Count /= 60; + + const int LogHours = Count; + + fmt_helper::pad2(LogHours, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + fmt_helper::pad2(LogMins, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + fmt_helper::pad2(LogSecs, m_CachedDatetime); + } + + m_CachedDatetime.push_back('.'); + + m_CacheTimestamp = secs; + } + + dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + + auto millis = fmt_helper::time_fraction<milliseconds>(msg.time); + fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); + dest.push_back(']'); + dest.push_back(' '); + + if (!m_LogId.empty()) + { + dest.push_back('['); + fmt_helper::append_string_view(m_LogId, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + // append logger name if exists + if (msg.logger_name.size() > 0) + { + dest.push_back('['); + fmt_helper::append_string_view(msg.logger_name, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + dest.push_back('['); + // wrap the level name with color + msg.color_range_start = dest.size(); + 
fmt_helper::append_string_view(level::to_string_view(msg.level), dest); + msg.color_range_end = dest.size(); + dest.push_back(']'); + dest.push_back(' '); + + // add source location if present + if (!msg.source.empty()) + { + dest.push_back('['); + const char* filename = details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename); + fmt_helper::append_string_view(filename, dest); + dest.push_back(':'); + fmt_helper::append_int(msg.source.line, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + fmt_helper::append_string_view(msg.payload, dest); + fmt_helper::append_string_view("\n"sv, dest); + } + +private: + std::chrono::time_point<std::chrono::system_clock> m_Epoch; + std::tm m_CachedTm; + std::chrono::seconds m_LastLogSecs; + std::chrono::seconds m_CacheTimestamp{0}; + memory_buf_t m_CachedDatetime; + std::string m_LogId; +}; + +class json_formatter final : public spdlog::formatter +{ +public: + json_formatter(std::string_view LogId) : m_LogId(LogId) {} + + virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<json_formatter>(m_LogId); } + + virtual void format(const details::log_msg& msg, memory_buf_t& dest) override + { + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch()); + if (secs != m_LastLogSecs) + { + m_CachedTm = os::localtime(log_clock::to_time_t(msg.time)); + m_LastLogSecs = secs; + } + + const auto& tm_time = m_CachedTm; + + // cache the date/time part for the next second. 
+ + if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) + { + m_CachedDatetime.clear(); + + fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + m_CachedDatetime.push_back(' '); + + fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + + m_CachedDatetime.push_back('.'); + + m_CacheTimestamp = secs; + } + dest.append("{"sv); + dest.append("\"time\": \""sv); + dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + auto millis = fmt_helper::time_fraction<milliseconds>(msg.time); + fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); + dest.append("\", "sv); + + dest.append("\"status\": \""sv); + dest.append(level::to_string_view(msg.level)); + dest.append("\", "sv); + + dest.append("\"source\": \""sv); + dest.append("zenserver"sv); + dest.append("\", "sv); + + dest.append("\"service\": \""sv); + dest.append("zencache"sv); + dest.append("\", "sv); + + if (!m_LogId.empty()) + { + dest.append("\"id\": \""sv); + dest.append(m_LogId); + dest.append("\", "sv); + } + + if (msg.logger_name.size() > 0) + { + dest.append("\"logger.name\": \""sv); + dest.append(msg.logger_name); + dest.append("\", "sv); + } + + if (msg.thread_id != 0) + { + dest.append("\"logger.thread_name\": \""sv); + fmt_helper::pad_uint(msg.thread_id, 0, dest); + dest.append("\", "sv); + } + + if (!msg.source.empty()) + { + dest.append("\"file\": \""sv); + WriteEscapedString(dest, details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename)); + dest.append("\","sv); + + dest.append("\"line\": \""sv); + dest.append(fmt::format("{}", msg.source.line)); + 
dest.append("\","sv); + + dest.append("\"logger.method_name\": \""sv); + WriteEscapedString(dest, msg.source.funcname); + dest.append("\", "sv); + } + + dest.append("\"message\": \""sv); + WriteEscapedString(dest, msg.payload); + dest.append("\""sv); + + dest.append("}\n"sv); + } + +private: + static inline const std::unordered_map<char, std::string_view> SpecialCharacterMap{{'\b', "\\b"sv}, + {'\f', "\\f"sv}, + {'\n', "\\n"sv}, + {'\r', "\\r"sv}, + {'\t', "\\t"sv}, + {'"', "\\\""sv}, + {'\\', "\\\\"sv}}; + + static void WriteEscapedString(memory_buf_t& dest, const spdlog::string_view_t& payload) + { + const char* RangeStart = payload.begin(); + for (const char* It = RangeStart; It != payload.end(); ++It) + { + if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end()) + { + if (RangeStart != It) + { + dest.append(RangeStart, It); + } + dest.append(SpecialIt->second); + RangeStart = It + 1; + } + } + if (RangeStart != payload.end()) + { + dest.append(RangeStart, payload.end()); + } + }; + + std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::chrono::seconds m_LastLogSecs{0}; + std::chrono::seconds m_CacheTimestamp{0}; + memory_buf_t m_CachedDatetime; + std::string m_LogId; +}; + +bool +EnableVTMode() +{ +#if ZEN_PLATFORM_WINDOWS + // Set output mode to handle virtual terminal sequences + HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); + if (hOut == INVALID_HANDLE_VALUE) + { + return false; + } + + DWORD dwMode = 0; + if (!GetConsoleMode(hOut, &dwMode)) + { + return false; + } + + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if (!SetConsoleMode(hOut, dwMode)) + { + return false; + } +#endif + + return true; +} + +} // namespace logging + +void +InitializeLogging(const ZenServerOptions& GlobalOptions) +{ + zen::logging::InitializeLogging(); + logging::EnableVTMode(); + + bool IsAsync = true; + spdlog::level::level_enum LogLevel = spdlog::level::info; + + if (GlobalOptions.IsDebug) + { + LogLevel = spdlog::level::debug; + IsAsync = 
false; + } + + if (GlobalOptions.IsTest) + { + LogLevel = spdlog::level::trace; + IsAsync = false; + } + + if (IsAsync) + { + const int QueueSize = 8192; + const int ThreadCount = 1; + spdlog::init_thread_pool(QueueSize, ThreadCount); + + auto AsyncLogger = spdlog::create_async<spdlog::sinks::ansicolor_stdout_sink_mt>("main"); + zen::logging::SetDefault(AsyncLogger); + } + + // Sinks + + auto ConsoleSink = std::make_shared<spdlog::sinks::ansicolor_stdout_sink_mt>(); + + // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance + zen::CreateDirectories(GlobalOptions.AbsLogFile.parent_path()); + +#if 0 + auto FileSink = std::make_shared<spdlog::sinks::daily_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile), + 0, + 0, + /* truncate */ false, + uint16_t(/* max files */ 14)); +#else + auto FileSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile), + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true); +#endif + + std::set_terminate([]() { ZEN_CRITICAL("Program exited abnormally via std::terminate()"); }); + + // Default + + auto& DefaultLogger = zen::logging::Default(); + auto& Sinks = DefaultLogger.sinks(); + + Sinks.clear(); + Sinks.push_back(ConsoleSink); + Sinks.push_back(FileSink); + +#if ZEN_PLATFORM_WINDOWS + if (zen::IsDebuggerPresent() && GlobalOptions.IsDebug) + { + auto DebugSink = std::make_shared<spdlog::sinks::msvc_sink_mt>(); + DebugSink->set_level(spdlog::level::debug); + Sinks.push_back(DebugSink); + } +#endif + + // HTTP server request logging + + std::filesystem::path HttpLogPath = GlobalOptions.DataDir / "logs" / "http.log"; + + // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance + zen::CreateDirectories(HttpLogPath.parent_path()); + + auto HttpSink = 
std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(HttpLogPath), + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true); + + auto HttpLogger = std::make_shared<spdlog::logger>("http_requests", HttpSink); + spdlog::register_logger(HttpLogger); + + // Jupiter - only log upstream HTTP traffic to file + + auto JupiterLogger = std::make_shared<spdlog::logger>("jupiter", FileSink); + spdlog::register_logger(JupiterLogger); + + // Zen - only log upstream HTTP traffic to file + + auto ZenClientLogger = std::make_shared<spdlog::logger>("zenclient", FileSink); + spdlog::register_logger(ZenClientLogger); + + // Configure all registered loggers according to settings + + spdlog::set_level(LogLevel); + spdlog::flush_on(spdlog::level::err); + spdlog::flush_every(std::chrono::seconds{2}); + spdlog::set_formatter(std::make_unique<logging::full_formatter>(GlobalOptions.LogId, std::chrono::system_clock::now())); + + if (GlobalOptions.AbsLogFile.extension() == ".json") + { + FileSink->set_formatter(std::make_unique<logging::json_formatter>(GlobalOptions.LogId)); + } + else + { + FileSink->set_pattern("[%C-%m-%d.%e %T] [%n] [%l] %v"); + } + DefaultLogger.info("log starting at {}", zen::DateTime::Now().ToIso8601()); +} + +void +ShutdownLogging() +{ + auto& DefaultLogger = zen::logging::Default(); + DefaultLogger.info("log ending at {}", zen::DateTime::Now().ToIso8601()); + zen::logging::ShutdownLogging(); +} diff --git a/src/zenserver/diag/logging.h b/src/zenserver/diag/logging.h new file mode 100644 index 000000000..8df49f842 --- /dev/null +++ b/src/zenserver/diag/logging.h @@ -0,0 +1,10 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging.h> +struct ZenServerOptions; + +void InitializeLogging(const ZenServerOptions& GlobalOptions); + +void ShutdownLogging(); diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp new file mode 100644 index 000000000..149d97924 --- /dev/null +++ b/src/zenserver/frontend/frontend.cpp @@ -0,0 +1,128 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "frontend.h" + +#include <zencore/endian.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <Windows.h> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +HttpFrontendService::HttpFrontendService(std::filesystem::path Directory) : m_Directory(Directory) +{ + std::filesystem::path SelfPath = GetRunningExecutablePath(); + + // Locate a .zip file appended onto the end of this binary + IoBuffer SelfBuffer = IoBufferBuilder::MakeFromFile(SelfPath); + m_ZipFs = ZipFs(std::move(SelfBuffer)); + +#if ZEN_BUILD_DEBUG + if (!Directory.empty()) + { + return; + } + + std::error_code ErrorCode; + auto Path = SelfPath; + while (Path.has_parent_path()) + { + auto ParentPath = Path.parent_path(); + if (ParentPath == Path) + { + break; + } + if (std::filesystem::is_regular_file(ParentPath / "xmake.lua", ErrorCode)) + { + if (ErrorCode) + { + break; + } + + auto HtmlDir = ParentPath / "zenserver" / "frontend" / "html"; + if (std::filesystem::is_directory(HtmlDir, ErrorCode)) + { + m_Directory = HtmlDir; + } + break; + } + Path = ParentPath; + }; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +HttpFrontendService::~HttpFrontendService() +{ +} + +//////////////////////////////////////////////////////////////////////////////// +const char* +HttpFrontendService::BaseUri() const +{ + return "/dashboard"; // in order to use the root path 
we need to remove HttpAddUrlToUrlGroup in HttpSys.cpp +} + +//////////////////////////////////////////////////////////////////////////////// +void +HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) +{ + using namespace std::literals; + + std::string_view Uri = Request.RelativeUriWithExtension(); + for (; Uri[0] == '/'; Uri = Uri.substr(1)) + ; + if (Uri.empty()) + { + Uri = "index.html"sv; + } + + // Dismiss if the URI contains .. anywhere to prevent arbitrary file reads + if (Uri.find("..") != Uri.npos) + { + return Request.WriteResponse(HttpResponseCode::Forbidden); + } + + // Map the file extension to a MIME type. To keep things constrained, only a + // small subset of file extensions is allowed + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + + if (const size_t DotIndex = Uri.rfind("."); DotIndex != Uri.npos) + { + const std::string_view DotExt = Uri.substr(DotIndex + 1); + + ContentType = ParseContentType(DotExt); + } + + if (ContentType == HttpContentType::kUnknownContentType) + { + return Request.WriteResponse(HttpResponseCode::Forbidden); + } + + // The given content directory overrides any zip-fs discovered in the binary + if (!m_Directory.empty()) + { + FileContents File = ReadFile(m_Directory / Uri); + + if (!File.ErrorCode) + { + return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]); + } + } + + if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri)) + { + return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + } + + Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); +} + +} // namespace zen diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h new file mode 100644 index 000000000..6eac20620 --- /dev/null +++ b/src/zenserver/frontend/frontend.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpserver.h> +#include "zipfs.h" + +#include <filesystem> + +namespace zen { + +class HttpFrontendService final : public zen::HttpService +{ +public: + HttpFrontendService(std::filesystem::path Directory); + virtual ~HttpFrontendService(); + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + ZipFs m_ZipFs; + std::filesystem::path m_Directory; +}; + +} // namespace zen diff --git a/src/zenserver/frontend/html/index.html b/src/zenserver/frontend/html/index.html new file mode 100644 index 000000000..252ee621e --- /dev/null +++ b/src/zenserver/frontend/html/index.html @@ -0,0 +1,59 @@ +<!DOCTYPE html> +<html> +<head> + <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-F3w7mX95PdgyTmZZMECAngseQB83DfGTowi0iMjiWaeVhAn4FJkqJByhZMI3AhiU" crossorigin="anonymous"> + <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js" integrity="sha384-skAcpIdS7UcVUC05LJ9Dxay8AXcDYfBJqt1CJ85S/CFujBsIzCIv+l9liuYLaMQ/" crossorigin="anonymous"></script> + <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/font/bootstrap-icons.css"> + <style type="text/css"> + body { + background-color: #fafafa; + } + </style> + <script type="text/javascript"> + const getCacheStats = () => { + const opts = { headers: { "Accept": "application/json" } }; + fetch("/stats/z$", opts) + .then(response => { + if (!response.ok) { + throw Error(response.statusText); + } + return response.json(); + }) + .then(json => { + document.getElementById("status").innerHTML = "connected" + document.getElementById("stats").innerHTML = JSON.stringify(json, null, 4); + }) + .catch(error => { + document.getElementById("status").innerHTML = "disconnected" + document.getElementById("stats").innerHTML = "" + console.log(error); + }) + .finally(() => { + window.setTimeout(getCacheStats, 1000); + 
}); + }; + getCacheStats(); + </script> +</head> +<body> + <div class="container"> + <div class="row"> + <div class="text-center mt-5"> + <pre> +__________ _________ __ +\____ / ____ ____ / _____/_/ |_ ____ _______ ____ + / / _/ __ \ / \ \_____ \ \ __\ / _ \ \_ __ \_/ __ \ + / /_ \ ___/ | | \ / \ | | ( <_> ) | | \/\ ___/ +/_______ \ \___ >|___| //_______ / |__| \____/ |__| \___ > + \/ \/ \/ \/ \/ + </pre> + <pre id="status"/> + </div> + </div> + <div class="row"> + <pre class="mb-0">Z$:</pre> + <pre id="stats"></pre> + <div> + </div> +</body> +</html> diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp new file mode 100644 index 000000000..f9c2bc8ff --- /dev/null +++ b/src/zenserver/frontend/zipfs.cpp @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zipfs.h" + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +namespace { + +#if ZEN_COMPILER_MSC +# pragma warning(push) +# pragma warning(disable : 4200) +#endif + + using ZipInt16 = uint16_t; + + struct ZipInt32 + { + operator uint32_t() const { return *(uint32_t*)Parts; } + uint16_t Parts[2]; + }; + + struct EocdRecord + { + enum : uint32_t + { + Magic = 0x0605'4b50, + }; + ZipInt32 Signature; + ZipInt16 ThisDiskIndex; + ZipInt16 CdStartDiskIndex; + ZipInt16 CdRecordThisDiskCount; + ZipInt16 CdRecordCount; + ZipInt32 CdSize; + ZipInt32 CdOffset; + ZipInt16 CommentSize; + char Comment[]; + }; + + struct CentralDirectoryRecord + { + enum : uint32_t + { + Magic = 0x0201'4b50, + }; + + ZipInt32 Signature; + ZipInt16 VersionMadeBy; + ZipInt16 VersionRequired; + ZipInt16 Flags; + ZipInt16 CompressionMethod; + ZipInt16 LastModTime; + ZipInt16 LastModDate; + ZipInt32 Crc32; + ZipInt32 CompressedSize; + ZipInt32 OriginalSize; + ZipInt16 FileNameLength; + ZipInt16 ExtraFieldLength; + ZipInt16 CommentLength; + ZipInt16 DiskIndex; + ZipInt16 InternalFileAttr; + ZipInt32 ExternalFileAttr; + ZipInt32 Offset; + char 
FileName[]; + }; + + struct LocalFileHeader + { + enum : uint32_t + { + Magic = 0x0403'4b50, + }; + + ZipInt32 Signature; + ZipInt16 VersionRequired; + ZipInt16 Flags; + ZipInt16 CompressionMethod; + ZipInt16 LastModTime; + ZipInt16 LastModDate; + ZipInt32 Crc32; + ZipInt32 CompressedSize; + ZipInt32 OriginalSize; + ZipInt16 FileNameLength; + ZipInt16 ExtraFieldLength; + char FileName[]; + }; + +#if ZEN_COMPILER_MSC +# pragma warning(pop) +#endif + +} // namespace + +////////////////////////////////////////////////////////////////////////// +ZipFs::ZipFs(IoBuffer&& Buffer) +{ + MemoryView View = Buffer.GetView(); + + uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize(); + if (View.GetSize() < sizeof(EocdRecord)) + { + return; + } + + const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord)); + + // It is more correct to search backwards for EocdRecord::Magic as the + // comment can be of a variable length. But here we're not going to support + // zip files with comments. 
+ if (EocdCursor->Signature != EocdRecord::Magic) + { + return; + } + + // Zip64 isn't supported either + if (EocdCursor->ThisDiskIndex == 0xffff) + { + return; + } + + Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize); + + const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset); + for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i) + { + const CentralDirectoryRecord& Cd = *CdCursor; + + bool Acceptable = true; + Acceptable &= (Cd.OriginalSize > 0); // has some content + Acceptable &= (Cd.CompressionMethod == 0); // is stored uncomrpessed + if (Acceptable) + { + const uint8_t* Lfh = Cursor + Cd.Offset; + if (uintptr_t(Lfh - Cursor) < View.GetSize()) + { + std::string_view FileName(Cd.FileName, Cd.FileNameLength); + m_Files.insert(std::make_pair(FileName, FileItem{Lfh, size_t(0)})); + } + } + + uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength; + CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes); + } + + m_Buffer = std::move(Buffer); +} + +////////////////////////////////////////////////////////////////////////// +IoBuffer +ZipFs::GetFile(const std::string_view& FileName) const +{ + FileMap::iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) + { + return {}; + } + + FileItem& Item = Iter->second; + if (Item.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); + } + + const auto* Lfh = (LocalFileHeader*)(Item.GetData()); + Item = MemoryView(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength, Lfh->OriginalSize); + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); +} + +} // namespace zen diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h new file mode 100644 index 000000000..e1fa4457c --- /dev/null +++ b/src/zenserver/frontend/zipfs.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/iobuffer.h> + +#include <unordered_map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +class ZipFs +{ +public: + ZipFs() = default; + ZipFs(IoBuffer&& Buffer); + IoBuffer GetFile(const std::string_view& FileName) const; + +private: + using FileItem = MemoryView; + using FileMap = std::unordered_map<std::string_view, FileItem>; + FileMap mutable m_Files; + IoBuffer m_Buffer; +}; + +} // namespace zen diff --git a/src/zenserver/monitoring/httpstats.cpp b/src/zenserver/monitoring/httpstats.cpp new file mode 100644 index 000000000..4d985f8c2 --- /dev/null +++ b/src/zenserver/monitoring/httpstats.cpp @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpstats.h" + +namespace zen { + +HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats")) +{ +} + +HttpStatsService::~HttpStatsService() +{ +} + +const char* +HttpStatsService::BaseUri() const +{ + return "/stats/"; +} + +void +HttpStatsService::RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.insert_or_assign(std::string(Id), &Provider); +} + +void +HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) +{ + ZEN_UNUSED(Provider); + + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.erase(std::string(Id)); +} + +void +HttpStatsService::HandleRequest(HttpServerRequest& Request) +{ + using namespace std::literals; + + std::string_view Key = Request.RelativeUri(); + + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + RwLock::SharedLockScope _(m_Lock); + if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers)) + { + return It->second->HandleStatsRequest(Request); + } + } + + [[fallthrough]]; + default: + return; + } +} + +} // namespace zen diff --git a/src/zenserver/monitoring/httpstats.h b/src/zenserver/monitoring/httpstats.h new file 
mode 100644 index 000000000..732815a9a --- /dev/null +++ b/src/zenserver/monitoring/httpstats.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +#include <map> + +namespace zen { + +struct IHttpStatsProvider +{ + virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; +}; + +class HttpStatsService : public HttpService +{ +public: + HttpStatsService(); + ~HttpStatsService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider); + void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider); + +private: + spdlog::logger& m_Log; + HttpRequestRouter m_Router; + + inline spdlog::logger& Log() { return m_Log; } + + RwLock m_Lock; + std::map<std::string, IHttpStatsProvider*> m_Providers; +}; + +} // namespace zen
\ No newline at end of file diff --git a/src/zenserver/monitoring/httpstatus.cpp b/src/zenserver/monitoring/httpstatus.cpp new file mode 100644 index 000000000..8b10601dd --- /dev/null +++ b/src/zenserver/monitoring/httpstatus.cpp @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpstatus.h" + +namespace zen { + +HttpStatusService::HttpStatusService() : m_Log(logging::Get("status")) +{ +} + +HttpStatusService::~HttpStatusService() +{ +} + +const char* +HttpStatusService::BaseUri() const +{ + return "/status/"; +} + +void +HttpStatusService::RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.insert_or_assign(std::string(Id), &Provider); +} + +void +HttpStatusService::UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider) +{ + ZEN_UNUSED(Provider); + + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.erase(std::string(Id)); +} + +void +HttpStatusService::HandleRequest(HttpServerRequest& Request) +{ + using namespace std::literals; + + std::string_view Key = Request.RelativeUri(); + + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + RwLock::SharedLockScope _(m_Lock); + if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers)) + { + return It->second->HandleStatusRequest(Request); + } + } + + [[fallthrough]]; + default: + return; + } +} + +} // namespace zen diff --git a/src/zenserver/monitoring/httpstatus.h b/src/zenserver/monitoring/httpstatus.h new file mode 100644 index 000000000..b04e45324 --- /dev/null +++ b/src/zenserver/monitoring/httpstatus.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +#include <map> + +namespace zen { + +struct IHttpStatusProvider +{ + virtual void HandleStatusRequest(HttpServerRequest& Request) = 0; +}; + +class HttpStatusService : public HttpService +{ +public: + HttpStatusService(); + ~HttpStatusService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + void RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider); + void UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider); + +private: + spdlog::logger& m_Log; + HttpRequestRouter m_Router; + + RwLock m_Lock; + std::map<std::string, IHttpStatusProvider*> m_Providers; + + inline spdlog::logger& Log() { return m_Log; } +}; + +} // namespace zen
\ No newline at end of file diff --git a/src/zenserver/objectstore/objectstore.cpp b/src/zenserver/objectstore/objectstore.cpp new file mode 100644 index 000000000..e5739418e --- /dev/null +++ b/src/zenserver/objectstore/objectstore.cpp @@ -0,0 +1,232 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <objectstore/objectstore.h> + +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include "zencore/compactbinarybuilder.h" +#include "zenhttp/httpcommon.h" +#include "zenhttp/httpserver.h" + +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogObj, "obj"sv); + +HttpObjectStoreService::HttpObjectStoreService(ObjectStoreConfig Cfg) : m_Cfg(std::move(Cfg)) +{ + Inititalize(); +} + +HttpObjectStoreService::~HttpObjectStoreService() +{ +} + +const char* +HttpObjectStoreService::BaseUri() const +{ + return "/obj/"; +} + +void +HttpObjectStoreService::HandleRequest(zen::HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_LOG_WARN(LogObj, "No route found for {0}", Request.RelativeUri()); + return Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); + } +} + +void +HttpObjectStoreService::Inititalize() +{ + ZEN_LOG_INFO(LogObj, "Initialzing Object Store in '{}'", m_Cfg.RootDirectory); + for (const auto& Bucket : m_Cfg.Buckets) + { + ZEN_LOG_INFO(LogObj, " - bucket '{}' -> '{}'", Bucket.Name, Bucket.Directory); + } + + m_Router.RegisterRoute( + "distributionpoints/{bucket}", + [this](zen::HttpRouterRequest& Request) { + const std::string BucketName = Request.GetCapture(1); + + StringBuilder<1024> Json; + { + CbObjectWriter Writer; + Writer.BeginArray("distributions"); + Writer << fmt::format("http://localhost:{}/obj/{}", m_Cfg.ServerPort, BucketName); + 
Writer.EndArray(); + Writer.Save().ToJson(Json); + } + + Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, Json.ToString()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{bucket}/{path}", + [this](zen::HttpRouterRequest& Request) { GetBlob(Request); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{bucket}/{path}", + [this](zen::HttpRouterRequest& Request) { PutBlob(Request); }, + HttpVerb::kPost | HttpVerb::kPut); +} + +std::filesystem::path +HttpObjectStoreService::GetBucketDirectory(std::string_view BucketName) +{ + std::lock_guard _(BucketsMutex); + + if (const auto It = std::find_if(std::begin(m_Cfg.Buckets), + std::end(m_Cfg.Buckets), + [&BucketName](const auto& Bucket) -> bool { return Bucket.Name == BucketName; }); + It != std::end(m_Cfg.Buckets)) + { + return It->Directory; + } + + return std::filesystem::path(); +} + +void +HttpObjectStoreService::GetBlob(zen::HttpRouterRequest& Request) +{ + namespace fs = std::filesystem; + + const std::string& BucketName = Request.GetCapture(1); + const fs::path BucketDir = GetBucketDirectory(BucketName); + + if (BucketDir.empty()) + { + ZEN_LOG_DEBUG(LogObj, "GET - [FAILED], unknown bucket '{}'", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); + } + + const fs::path RelativeBucketPath = Request.GetCapture(2); + + if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with("..")) + { + ZEN_LOG_DEBUG(LogObj, "GET - from bucket '{}' [FAILED], invalid file path", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); + } + + fs::path FilePath = BucketDir / RelativeBucketPath; + if (fs::exists(FilePath) == false) + { + ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' [FAILED], doesn't exist", BucketName, FilePath); + return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); + } + + zen::HttpRanges Ranges; + if (Request.ServerRequest().TryGetRanges(Ranges); Ranges.size() > 1) + 
{ + // Only a single range is supported + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + FileContents File = ReadFile(FilePath); + if (File.ErrorCode) + { + ZEN_LOG_WARN(LogObj, + "GET - '{}/{}' [FAILED] ('{}': {})", + BucketName, + FilePath, + File.ErrorCode.category().name(), + File.ErrorCode.value()); + + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + const IoBuffer& FileBuf = File.Data[0]; + + if (Ranges.empty()) + { + const uint64_t TotalServed = TotalBytesServed.fetch_add(FileBuf.Size()) + FileBuf.Size(); + + ZEN_LOG_DEBUG(LogObj, + "GET - '{}/{}' ({}) [OK] (Served: {})", + BucketName, + RelativeBucketPath, + NiceBytes(FileBuf.Size()), + NiceBytes(TotalServed)); + + Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, FileBuf); + } + else + { + const auto Range = Ranges[0]; + const uint64_t RangeSize = Range.End - Range.Start; + const uint64_t TotalServed = TotalBytesServed.fetch_add(RangeSize) + RangeSize; + + ZEN_LOG_DEBUG(LogObj, + "GET - '{}/{}' (Range: {}-{}) ({}/{}) [OK] (Served: {})", + BucketName, + RelativeBucketPath, + Range.Start, + Range.End, + NiceBytes(RangeSize), + NiceBytes(FileBuf.Size()), + NiceBytes(TotalServed)); + + MemoryView RangeView = FileBuf.GetView().Mid(Range.Start, RangeSize); + if (RangeView.GetSize() != RangeSize) + { + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + IoBuffer RangeBuf = IoBuffer(IoBuffer::Wrap, RangeView.GetData(), RangeView.GetSize()); + Request.ServerRequest().WriteResponse(HttpResponseCode::PartialContent, HttpContentType::kBinary, RangeBuf); + } +} + +void +HttpObjectStoreService::PutBlob(zen::HttpRouterRequest& Request) +{ + namespace fs = std::filesystem; + + const std::string& BucketName = Request.GetCapture(1); + const fs::path BucketDir = GetBucketDirectory(BucketName); + + if (BucketDir.empty()) + { + ZEN_LOG_DEBUG(LogObj, "PUT - [FAILED], unknown bucket '{}'", 
BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); + } + + const fs::path RelativeBucketPath = Request.GetCapture(2); + + if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with("..")) + { + ZEN_LOG_DEBUG(LogObj, "PUT - bucket '{}' [FAILED], invalid file path", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); + } + + fs::path FilePath = BucketDir / RelativeBucketPath; + const IoBuffer FileBuf = Request.ServerRequest().ReadPayload(); + + if (FileBuf.Size() == 0) + { + ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [FAILED], empty file", BucketName, FilePath); + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + WriteFile(FilePath, FileBuf); + ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [OK] ({})", BucketName, RelativeBucketPath, NiceBytes(FileBuf.Size())); + Request.ServerRequest().WriteResponse(HttpResponseCode::OK); +} + +} // namespace zen diff --git a/src/zenserver/objectstore/objectstore.h b/src/zenserver/objectstore/objectstore.h new file mode 100644 index 000000000..eaab57794 --- /dev/null +++ b/src/zenserver/objectstore/objectstore.h @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpserver.h> +#include <atomic> +#include <filesystem> +#include <mutex> + +namespace zen { + +class HttpRouterRequest; + +struct ObjectStoreConfig +{ + struct BucketConfig + { + std::string Name; + std::filesystem::path Directory; + }; + + std::filesystem::path RootDirectory; + std::vector<BucketConfig> Buckets; + uint16_t ServerPort{1337}; +}; + +class HttpObjectStoreService final : public zen::HttpService +{ +public: + HttpObjectStoreService(ObjectStoreConfig Cfg); + virtual ~HttpObjectStoreService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + void Inititalize(); + std::filesystem::path GetBucketDirectory(std::string_view BucketName); + void GetBlob(zen::HttpRouterRequest& Request); + void PutBlob(zen::HttpRouterRequest& Request); + + ObjectStoreConfig m_Cfg; + std::mutex BucketsMutex; + HttpRequestRouter m_Router; + std::atomic_uint64_t TotalBytesServed{0}; +}; + +} // namespace zen diff --git a/src/zenserver/projectstore/fileremoteprojectstore.cpp b/src/zenserver/projectstore/fileremoteprojectstore.cpp new file mode 100644 index 000000000..d7a34a6c2 --- /dev/null +++ b/src/zenserver/projectstore/fileremoteprojectstore.cpp @@ -0,0 +1,235 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "fileremoteprojectstore.h" + +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/timer.h> + +namespace zen { + +using namespace std::literals; + +class LocalExportProjectStore : public RemoteProjectStore +{ +public: + LocalExportProjectStore(std::string_view Name, + const std::filesystem::path& FolderPath, + bool ForceDisableBlocks, + bool ForceEnableTempBlocks) + : m_Name(Name) + , m_OutputPath(FolderPath) + { + if (ForceDisableBlocks) + { + m_EnableBlocks = false; + } + if (ForceEnableTempBlocks) + { + m_UseTempBlocks = true; + } + } + + virtual RemoteStoreInfo GetInfo() const override + { + return {.CreateBlocks = m_EnableBlocks, + .UseTempBlockFiles = m_UseTempBlocks, + .Description = fmt::format("[file] {}"sv, m_OutputPath)}; + } + + virtual SaveResult SaveContainer(const IoBuffer& Payload) override + { + Stopwatch Timer; + SaveResult Result; + + { + CbObject ContainerObject = LoadCompactBinaryObject(Payload); + + ContainerObject.IterateAttachments([&](CbFieldView FieldView) { + IoHash AttachmentHash = FieldView.AsBinaryAttachment(); + std::filesystem::path AttachmentPath = GetAttachmentPath(AttachmentHash); + if (!std::filesystem::exists(AttachmentPath)) + { + Result.Needs.insert(AttachmentHash); + } + }); + } + + std::filesystem::path ContainerPath = m_OutputPath; + ContainerPath.append(m_Name); + + CreateDirectories(m_OutputPath); + BasicFile ContainerFile; + ContainerFile.Open(ContainerPath, BasicFile::Mode::kTruncate); + std::error_code Ec; + ContainerFile.WriteAll(Payload, Ec); + if (Ec) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = Ec.message(); + } + Result.RawHash = IoHash::HashBuffer(Payload); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override + { + 
Stopwatch Timer; + SaveAttachmentResult Result; + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!std::filesystem::exists(ChunkPath)) + { + try + { + CreateDirectories(ChunkPath.parent_path()); + + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kTruncate); + size_t Offset = 0; + for (const SharedBuffer& Segment : Payload.GetSegments()) + { + ChunkFile.Write(Segment.GetView(), Offset); + Offset += Segment.GetSize(); + } + } + catch (std::exception& Ex) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = Ex.what(); + } + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override + { + Stopwatch Timer; + + for (const SharedBuffer& Chunk : Chunks) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer()); + SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash()); + if (ChunkResult.ErrorCode) + { + ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return SaveAttachmentsResult{ChunkResult}; + } + } + SaveAttachmentsResult Result; + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual Result FinalizeContainer(const IoHash&) override { return {}; } + + virtual LoadContainerResult LoadContainer() override + { + Stopwatch Timer; + LoadContainerResult Result; + std::filesystem::path ContainerPath = m_OutputPath; + ContainerPath.append(m_Name); + if (!std::filesystem::is_regular_file(ContainerPath)) + { + Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); + Result.Reason = fmt::format("The file {} does not exist"sv, ContainerPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + IoBuffer ContainerPayload; + { + BasicFile ContainerFile; + 
ContainerFile.Open(ContainerPath, BasicFile::Mode::kRead); + ContainerPayload = ContainerFile.ReadAll(); + } + Result.ContainerObject = LoadCompactBinaryObject(ContainerPayload); + if (!Result.ContainerObject) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("The file {} is not formatted as a compact binary object"sv, ContainerPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + Stopwatch Timer; + LoadAttachmentResult Result; + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!std::filesystem::is_regular_file(ChunkPath)) + { + Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); + Result.Reason = fmt::format("The file {} does not exist"sv, ChunkPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + { + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); + Result.Bytes = ChunkFile.ReadAll(); + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override + { + Stopwatch Timer; + LoadAttachmentsResult Result; + for (const IoHash& Hash : RawHashes) + { + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + if (ChunkResult.ErrorCode) + { + ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return LoadAttachmentsResult{ChunkResult}; + } + ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000))); + Result.Chunks.emplace_back( + std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))}); + } + return Result; + } + +private: + 
std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const + { + ExtendablePathBuilder<128> ShardedPath; + ShardedPath.Append(m_OutputPath.c_str()); + ExtendableStringBuilder<64> HashString; + RawHash.ToHexString(HashString); + const char* str = HashString.c_str(); + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str, str + 3); + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str + 3, str + 5); + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str + 5, str + 40); + + return ShardedPath.ToPath(); + } + + const std::string m_Name; + const std::filesystem::path m_OutputPath; + bool m_EnableBlocks = true; + bool m_UseTempBlocks = false; +}; + +std::unique_ptr<RemoteProjectStore> +CreateFileRemoteStore(const FileRemoteStoreOptions& Options) +{ + std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<LocalExportProjectStore>(Options.Name, + std::filesystem::path(Options.FolderPath), + Options.ForceDisableBlocks, + Options.ForceEnableTempBlocks); + return RemoteStore; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/fileremoteprojectstore.h b/src/zenserver/projectstore/fileremoteprojectstore.h new file mode 100644 index 000000000..68d1eb71e --- /dev/null +++ b/src/zenserver/projectstore/fileremoteprojectstore.h @@ -0,0 +1,19 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "remoteprojectstore.h" + +namespace zen { + +struct FileRemoteStoreOptions : RemoteStoreOptions +{ + std::filesystem::path FolderPath; + std::string Name; + bool ForceDisableBlocks; + bool ForceEnableTempBlocks; +}; + +std::unique_ptr<RemoteProjectStore> CreateFileRemoteStore(const FileRemoteStoreOptions& Options); + +} // namespace zen diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.cpp b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp new file mode 100644 index 000000000..66cf3c4f8 --- /dev/null +++ b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp @@ -0,0 +1,244 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "jupiterremoteprojectstore.h" + +#include <zencore/compress.h> +#include <zencore/fmtutils.h> + +#include <auth/authmgr.h> +#include <upstream/jupiter.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +class JupiterRemoteStore : public RemoteProjectStore +{ +public: + JupiterRemoteStore(Ref<CloudCacheClient>&& CloudClient, + std::string_view Namespace, + std::string_view Bucket, + const IoHash& Key, + bool ForceDisableBlocks, + bool ForceDisableTempBlocks) + : m_CloudClient(CloudClient) + , m_Namespace(Namespace) + , m_Bucket(Bucket) + , m_Key(Key) + { + if (ForceDisableBlocks) + { + m_EnableBlocks = false; + } + if (ForceDisableTempBlocks) + { + m_UseTempBlocks = false; + } + } + + virtual RemoteStoreInfo GetInfo() const override + { + return {.CreateBlocks = m_EnableBlocks, + .UseTempBlockFiles = m_UseTempBlocks, + .Description = fmt::format("[cloud] {} as {}/{}/{}"sv, m_CloudClient->ServiceUrl(), m_Namespace, m_Bucket, m_Key)}; + } + + virtual SaveResult SaveContainer(const IoBuffer& Payload) override + { + const int32_t MaxAttempts = 3; + PutRefResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && 
!Result.Success; Attempt++) + { + Result = Session.PutRef(m_Namespace, m_Bucket, m_Key, Payload, ZenContentType::kCbObject); + } + } + + return SaveResult{ConvertResult(Result), {Result.Needs.begin(), Result.Needs.end()} /*, {}*/, IoHash::HashBuffer(Payload)}; + } + + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCompressedBlob(m_Namespace, RawHash, Payload); + } + } + + return SaveAttachmentResult{ConvertResult(Result)}; + } + + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override + { + SaveAttachmentsResult Result; + for (const SharedBuffer& Chunk : Chunks) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer()); + SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash()); + if (ChunkResult.ErrorCode) + { + return SaveAttachmentsResult{ChunkResult}; + } + } + return Result; + } + + virtual Result FinalizeContainer(const IoHash& RawHash) override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.FinalizeRef(m_Namespace, m_Bucket, m_Key, RawHash); + } + } + return ConvertResult(Result); + } + + virtual LoadContainerResult LoadContainer() override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.GetRef(m_Namespace, m_Bucket, m_Key, ZenContentType::kCbObject); + } + } + + if (Result.ErrorCode || !Result.Success) + { 
+ return LoadContainerResult{ConvertResult(Result)}; + } + + CbObject ContainerObject = LoadCompactBinaryObject(Result.Response); + if (!ContainerObject) + { + return LoadContainerResult{ + RemoteProjectStore::Result{ + .ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + .ElapsedSeconds = Result.ElapsedSeconds, + .Reason = fmt::format("The ref {}/{}/{} is not formatted as a compact binary object"sv, m_Namespace, m_Bucket, m_Key)}, + std::move(ContainerObject)}; + } + + return LoadContainerResult{ConvertResult(Result), std::move(ContainerObject)}; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.GetCompressedBlob(m_Namespace, RawHash); + } + } + return LoadAttachmentResult{ConvertResult(Result), std::move(Result.Response)}; + } + + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override + { + LoadAttachmentsResult Result; + for (const IoHash& Hash : RawHashes) + { + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + if (ChunkResult.ErrorCode) + { + return LoadAttachmentsResult{ChunkResult}; + } + ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000))); + Result.Chunks.emplace_back( + std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))}); + } + return Result; + } + +private: + static Result ConvertResult(const CloudCacheResult& Response) + { + std::string Text; + int32_t ErrorCode = 0; + if (Response.ErrorCode != 0) + { + ErrorCode = Response.ErrorCode; + } + else if (!Response.Success) + { + ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + if (Response.Response.GetContentType() == ZenContentType::kText) + 
{ + Text = + std::string(reinterpret_cast<const std::string::value_type*>(Response.Response.GetData()), Response.Response.GetSize()); + } + } + return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.ElapsedSeconds, .Reason = Response.Reason, .Text = Text}; + } + + Ref<CloudCacheClient> m_CloudClient; + const std::string m_Namespace; + const std::string m_Bucket; + const IoHash m_Key; + bool m_EnableBlocks = true; + bool m_UseTempBlocks = true; +}; + +std::unique_ptr<RemoteProjectStore> +CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options) +{ + std::string Url = Options.Url; + if (Url.find("://"sv) == std::string::npos) + { + // Assume https URL + Url = fmt::format("https://{}"sv, Url); + } + CloudCacheClientOptions ClientOptions{.Name = "Remote store"sv, + .ServiceUrl = Url, + .ConnectTimeout = std::chrono::milliseconds(2000), + .Timeout = std::chrono::milliseconds(60000)}; + // 1) Access token as parameter in request + // 2) Environment variable (different win vs linux/mac) + // 3) openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider + + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + if (!Options.AccessToken.empty()) + { + TokenProvider = CloudCacheTokenProvider::CreateFromCallback([AccessToken = Options.AccessToken]() { + return CloudCacheAccessToken{.Value = AccessToken, .ExpireTime = GcClock::TimePoint::max()}; + }); + } + else + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([&AuthManager = Options.AuthManager, OpenIdProvider = Options.OpenIdProvider]() { + AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider.empty() ? 
"Default" : OpenIdProvider); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + + Ref<CloudCacheClient> CloudClient(new CloudCacheClient(ClientOptions, std::move(TokenProvider))); + + std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<JupiterRemoteStore>(std::move(CloudClient), + Options.Namespace, + Options.Bucket, + Options.Key, + Options.ForceDisableBlocks, + Options.ForceDisableTempBlocks); + return RemoteStore; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.h b/src/zenserver/projectstore/jupiterremoteprojectstore.h new file mode 100644 index 000000000..31548af22 --- /dev/null +++ b/src/zenserver/projectstore/jupiterremoteprojectstore.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "remoteprojectstore.h" + +namespace zen { + +class AuthMgr; + +struct JupiterRemoteStoreOptions : RemoteStoreOptions +{ + std::string Url; + std::string Namespace; + std::string Bucket; + IoHash Key; + std::string OpenIdProvider; + std::string AccessToken; + AuthMgr& AuthManager; + bool ForceDisableBlocks; + bool ForceDisableTempBlocks; +}; + +std::unique_ptr<RemoteProjectStore> CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options); + +} // namespace zen diff --git a/src/zenserver/projectstore/projectstore.cpp b/src/zenserver/projectstore/projectstore.cpp new file mode 100644 index 000000000..847a79a1d --- /dev/null +++ b/src/zenserver/projectstore/projectstore.cpp @@ -0,0 +1,4082 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "projectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenhttp/httpshared.h> +#include <zenstore/caslog.h> +#include <zenstore/cidstore.h> +#include <zenstore/scrubcontext.h> +#include <zenutil/cache/rpcrecording.h> + +#include "fileremoteprojectstore.h" +#include "jupiterremoteprojectstore.h" +#include "remoteprojectstore.h" +#include "zenremoteprojectstore.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <xxh3.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +namespace { + bool PrepareDirectoryDelete(const std::filesystem::path& Dir, std::filesystem::path& OutDeleteDir) + { + int DropIndex = 0; + do + { + if (!std::filesystem::exists(Dir)) + { + return true; + } + + std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex); + std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName; + if (std::filesystem::exists(DroppedBucketPath)) + { + DropIndex++; + continue; + } + + std::error_code Ec; + std::filesystem::rename(Dir, DroppedBucketPath, Ec); + if (!Ec) + { + OutDeleteDir = DroppedBucketPath; + return true; + } + if (Ec && !std::filesystem::exists(DroppedBucketPath)) + { + // We can't move our folder, probably because it is busy, bail.. 
+ return false; + } + Sleep(100); + } while (true); + } + + std::pair<std::unique_ptr<RemoteProjectStore>, std::string> CreateRemoteStore(CbObjectView Params, + AuthMgr& AuthManager, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize) + { + using namespace std::literals; + + std::unique_ptr<RemoteProjectStore> RemoteStore; + + if (CbObjectView File = Params["file"sv].AsObjectView(); File) + { + std::filesystem::path FolderPath(File["path"sv].AsString()); + if (FolderPath.empty()) + { + return {nullptr, "Missing file path"}; + } + std::string_view Name(File["name"sv].AsString()); + if (Name.empty()) + { + return {nullptr, "Missing file name"}; + } + bool ForceDisableBlocks = File["disableblocks"sv].AsBool(false); + bool ForceEnableTempBlocks = File["enabletempblocks"sv].AsBool(false); + + FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + FolderPath, + std::string(Name), + ForceDisableBlocks, + ForceEnableTempBlocks}; + RemoteStore = CreateFileRemoteStore(Options); + } + + if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud) + { + std::string_view CloudServiceUrl = Cloud["url"sv].AsString(); + if (CloudServiceUrl.empty()) + { + return {nullptr, "Missing service url"}; + } + + std::string Url = cpr::util::urlDecode(std::string(CloudServiceUrl)); + std::string_view Namespace = Cloud["namespace"sv].AsString(); + if (Namespace.empty()) + { + return {nullptr, "Missing namespace"}; + } + std::string_view Bucket = Cloud["bucket"sv].AsString(); + if (Bucket.empty()) + { + return {nullptr, "Missing bucket"}; + } + std::string_view OpenIdProvider = Cloud["openid-provider"sv].AsString(); + std::string AccessToken = std::string(Cloud["access-token"sv].AsString()); + if (AccessToken.empty()) + { + std::string_view AccessTokenEnvVariable = Cloud["access-token-env"].AsString(); + if (!AccessTokenEnvVariable.empty()) + { + AccessToken = GetEnvVariable(AccessTokenEnvVariable); + } + } + 
std::string_view KeyParam = Cloud["key"sv].AsString(); + if (KeyParam.empty()) + { + return {nullptr, "Missing key"}; + } + if (KeyParam.length() != IoHash::StringLength) + { + return {nullptr, "Invalid key"}; + } + IoHash Key = IoHash::FromHexString(KeyParam); + if (Key == IoHash::Zero) + { + return {nullptr, "Invalid key string"}; + } + bool ForceDisableBlocks = Cloud["disableblocks"sv].AsBool(false); + bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false); + + JupiterRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + Url, + std::string(Namespace), + std::string(Bucket), + Key, + std::string(OpenIdProvider), + AccessToken, + AuthManager, + ForceDisableBlocks, + ForceDisableTempBlocks}; + RemoteStore = CreateJupiterRemoteStore(Options); + } + + if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen) + { + std::string_view Url = Zen["url"sv].AsString(); + std::string_view Project = Zen["project"sv].AsString(); + if (Project.empty()) + { + return {nullptr, "Missing project"}; + } + std::string_view Oplog = Zen["oplog"sv].AsString(); + if (Oplog.empty()) + { + return {nullptr, "Missing oplog"}; + } + ZenRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + std::string(Url), + std::string(Project), + std::string(Oplog)}; + RemoteStore = CreateZenRemoteStore(Options); + } + + if (!RemoteStore) + { + return {nullptr, "Unknown remote store type"}; + } + + return {std::move(RemoteStore), ""}; + } + + std::pair<HttpResponseCode, std::string> ConvertResult(const RemoteProjectStore::Result& Result) + { + if (Result.ErrorCode == 0) + { + return {HttpResponseCode::OK, Result.Text}; + } + return {static_cast<HttpResponseCode>(Result.ErrorCode), + Result.Reason.empty() ? Result.Text + : Result.Text.empty() ? Result.Reason + : fmt::format("{}. 
Reason: '{}'", Result.Text, Result.Reason)}; + } + + void CSVHeader(bool Details, bool AttachmentDetails, StringBuilderBase& CSVWriter) + { + if (AttachmentDetails) + { + CSVWriter << "Project, Oplog, LSN, Key, Cid, Size"; + } + else if (Details) + { + CSVWriter << "Project, Oplog, LSN, Key, Size, AttachmentCount, AttachmentsSize"; + } + else + { + CSVWriter << "Project, Oplog, Key"; + } + } + + void CSVWriteOp(CidStore& CidStore, + std::string_view ProjectId, + std::string_view OplogId, + bool Details, + bool AttachmentDetails, + int LSN, + const Oid& Key, + CbObject Op, + StringBuilderBase& CSVWriter) + { + StringBuilder<32> KeyStringBuilder; + Key.ToString(KeyStringBuilder); + const std::string_view KeyString = KeyStringBuilder.ToView(); + + SharedBuffer Buffer = Op.GetBuffer(); + if (AttachmentDetails) + { + Op.IterateAttachments([&CidStore, &CSVWriter, &ProjectId, &OplogId, LSN, &KeyString](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + CSVWriter << "\r\n" + << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << AttachmentHash.ToHexString() + << ", " << gsl::narrow<uint64_t>(Attachment.GetSize()); + }); + } + else if (Details) + { + uint64_t AttachmentCount = 0; + size_t AttachmentsSize = 0; + Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + AttachmentCount++; + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + AttachmentsSize += Attachment.GetSize(); + }); + CSVWriter << "\r\n" + << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << gsl::narrow<uint64_t>(Buffer.GetSize()) + << ", " << AttachmentCount << ", " << gsl::narrow<uint64_t>(AttachmentsSize); + } + else + { + CSVWriter << "\r\n" << ProjectId << ", " << OplogId << ", " << KeyString; + } + }; + + void CbWriteOp(CidStore& CidStore, + 
bool Details, + bool OpDetails, + bool AttachmentDetails, + int LSN, + const Oid& Key, + CbObject Op, + CbObjectWriter& CbWriter) + { + CbWriter.BeginObject(); + { + SharedBuffer Buffer = Op.GetBuffer(); + CbWriter.AddObjectId("key", Key); + if (Details) + { + CbWriter.AddInteger("lsn", LSN); + CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Buffer.GetSize())); + } + if (AttachmentDetails) + { + CbWriter.BeginArray("attachments"); + Op.IterateAttachments([&CidStore, &CbWriter](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + CbWriter.BeginObject(); + { + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + CbWriter.AddString("cid", AttachmentHash.ToHexString()); + CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Attachment.GetSize())); + } + CbWriter.EndObject(); + }); + CbWriter.EndArray(); + } + else if (Details) + { + uint64_t AttachmentCount = 0; + size_t AttachmentsSize = 0; + Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + AttachmentCount++; + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + AttachmentsSize += Attachment.GetSize(); + }); + if (AttachmentCount > 0) + { + CbWriter.AddInteger("attachments", AttachmentCount); + CbWriter.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize)); + } + } + if (OpDetails) + { + CbWriter.BeginObject("op"); + for (const CbFieldView& Field : Op) + { + if (!Field.HasName()) + { + CbWriter.AddField(Field); + continue; + } + std::string_view FieldName = Field.GetName(); + CbWriter.AddField(FieldName, Field); + } + CbWriter.EndObject(); + } + } + CbWriter.EndObject(); + }; + + void CbWriteOplogOps(CidStore& CidStore, + ProjectStore::Oplog& Oplog, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginArray("ops"); + { + Oplog.IterateOplogWithKey([&Cbo, &CidStore, Details, OpDetails, 
AttachmentDetails](int LSN, const Oid& Key, CbObject Op) { + CbWriteOp(CidStore, Details, OpDetails, AttachmentDetails, LSN, Key, Op, Cbo); + }); + } + Cbo.EndArray(); + } + + void CbWriteOplog(CidStore& CidStore, + ProjectStore::Oplog& Oplog, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginObject(); + { + Cbo.AddString("name", Oplog.OplogId()); + CbWriteOplogOps(CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndObject(); + } + + void CbWriteOplogs(CidStore& CidStore, + ProjectStore::Project& Project, + std::vector<std::string> OpLogs, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginArray("oplogs"); + { + for (const std::string& OpLogId : OpLogs) + { + ProjectStore::Oplog* Oplog = Project.OpenOplog(OpLogId); + if (Oplog != nullptr) + { + CbWriteOplog(CidStore, *Oplog, Details, OpDetails, AttachmentDetails, Cbo); + } + } + } + Cbo.EndArray(); + } + + void CbWriteProject(CidStore& CidStore, + ProjectStore::Project& Project, + std::vector<std::string> OpLogs, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginObject(); + { + Cbo.AddString("name", Project.Identifier); + CbWriteOplogs(CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndObject(); + } + +} // namespace + +////////////////////////////////////////////////////////////////////////// + +Oid +OpKeyStringAsOId(std::string_view OpKey) +{ + using namespace std::literals; + + CbObjectWriter Writer; + Writer << "key"sv << OpKey; + + XXH3_128Stream KeyHasher; + Writer.Save()["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); }); + XXH3_128 KeyHash = KeyHasher.GetHash(); + + Oid OpId; + memcpy(OpId.OidBits, &KeyHash, sizeof(OpId.OidBits)); + + return OpId; +} + +////////////////////////////////////////////////////////////////////////// + +struct ProjectStore::OplogStorage : 
// Reports whether oplog storage is present under BasePath, i.e. whether both the
// "ops.zlog" and "ops.zops" files exist there.
[[nodiscard]] static bool Exists(std::filesystem::path BasePath)
{
    if (!std::filesystem::exists(BasePath / "ops.zlog"))
    {
        return false;
    }
    return std::filesystem::exists(BasePath / "ops.zops");
}
// Replays every entry recorded in the op log and invokes Handler(Op, LogEntry) for each entry
// whose op data could be read back and verified.
//
// Entries with OpCoreSize == 0 are counted and skipped; entries whose stored checksum does not
// match the data read from the ops file are skipped with a warning. While replaying, the
// next-write offset and the highest seen LSN are raised to cover every accepted entry.
//
// NOTE(review): Op is a view onto OpBuffer, which is reused for subsequent entries, so a
// handler presumably must not keep Op beyond the call - confirm.
void ReplayLog(std::function<void(CbObject, const OplogEntry&)>&& Handler)
{
    ZEN_TRACE_CPU("ProjectStore::OplogStorage::ReplayLog");

    // This could use memory mapping or do something clever but for now it just reads the file sequentially

    ZEN_INFO("replaying log for '{}'", m_OplogStoragePath);

    Stopwatch Timer;

    uint64_t InvalidEntries = 0;

    // Scratch buffer shared by all entries; only reallocated when an entry needs more room
    IoBuffer OpBuffer;
    m_Oplog.Replay(
        [&](const OplogEntry& LogEntry) {
            if (LogEntry.OpCoreSize == 0)
            {
                ++InvalidEntries;

                return;
            }

            if (OpBuffer.GetSize() < LogEntry.OpCoreSize)
            {
                OpBuffer = IoBuffer(LogEntry.OpCoreSize);
            }

            // Offsets are stored in units of m_OpsAlign; convert back to a byte offset
            const uint64_t OpFileOffset = LogEntry.OpCoreOffset * m_OpsAlign;

            m_OpBlobs.Read((void*)OpBuffer.Data(), LogEntry.OpCoreSize, OpFileOffset);

            // Verify checksum, ignore op data if incorrect
            // (the low 32 bits of the 64-bit XXH3 hash of the op data are compared)
            const auto OpCoreHash = uint32_t(XXH3_64bits(OpBuffer.Data(), LogEntry.OpCoreSize) & 0xffffFFFF);

            if (OpCoreHash != LogEntry.OpCoreHash)
            {
                ZEN_WARN("skipping oplog entry with bad checksum!");
                return;
            }

            CbObject Op(SharedBuffer::MakeView(OpBuffer.Data(), LogEntry.OpCoreSize));

            // Keep the append position past the end of this op (rounded up to the alignment)
            // and track the highest LSN seen so far
            m_NextOpsOffset =
                Max(m_NextOpsOffset.load(std::memory_order_relaxed), RoundUp(OpFileOffset + LogEntry.OpCoreSize, m_OpsAlign));
            m_MaxLsn = Max(m_MaxLsn.load(std::memory_order_relaxed), LogEntry.OpLsn);

            Handler(Op, LogEntry);
        },
        0);  // NOTE(review): meaning of this argument is not visible here - confirm against the log file's Replay()

    if (InvalidEntries)
    {
        ZEN_WARN("ignored {} zero-sized oplog entries", InvalidEntries);
    }

    ZEN_INFO("Oplog replay completed in {} - Max LSN# {}, Next offset: {}",
             NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
             m_MaxLsn,
             m_NextOpsOffset);
}

// Loads the op at each address in Entries, in the order given, and passes it to Handler.
// Unlike the overload above, each op is read through GetOp() and no checksum is verified here.
void ReplayLog(const std::vector<OplogEntryAddress>& Entries, std::function<void(CbObject)>&& Handler)
{
    for (const OplogEntryAddress& Entry : Entries)
    {
        CbObject Op = GetOp(Entry);
        Handler(Op);
    }
}
// Appends one op: records an entry in the op log and writes the op data to the ops file.
//
// Buffer      - the op data to store
// OpCoreHash  - checksum stored in the entry (ReplayLog compares it against the data it reads back)
// KeyHash     - key hash stored in the entry
//
// Returns the log entry that was appended.
OplogEntry AppendOp(SharedBuffer Buffer, uint32_t OpCoreHash, XXH3_128 KeyHash)
{
    ZEN_TRACE_CPU("ProjectStore::OplogStorage::AppendOp");

    using namespace std::literals;

    uint64_t WriteSize = Buffer.GetSize();

    // Only the bookkeeping happens under the lock: claim the write offset and the next LSN,
    // then advance the append position to the next aligned offset. The file writes below run
    // after the lock has been released.
    RwLock::ExclusiveLockScope Lock(m_RwLock);
    const uint64_t             WriteOffset = m_NextOpsOffset;
    const uint32_t             OpLsn       = ++m_MaxLsn;
    m_NextOpsOffset                        = RoundUp(WriteOffset + WriteSize, m_OpsAlign);
    Lock.ReleaseNow();

    ZEN_ASSERT(IsMultipleOf(WriteOffset, m_OpsAlign));

    // The offset is stored in units of m_OpsAlign.
    // NOTE(review): OpCoreSize is an unchecked conversion of the buffer size to uint32_t.
    OplogEntry Entry = {.OpLsn        = OpLsn,
                        .OpCoreOffset = gsl::narrow_cast<uint32_t>(WriteOffset / m_OpsAlign),
                        .OpCoreSize   = uint32_t(Buffer.GetSize()),
                        .OpCoreHash   = OpCoreHash,
                        .OpKeyHash    = KeyHash};

    m_Oplog.Append(Entry);
    m_OpBlobs.Write(Buffer.GetData(), WriteSize, WriteOffset);

    return Entry;
}
m_Storage->Open(/* IsCreate */ !StoreExists); + + m_TempPath = m_BasePath / "temp"sv; + + CleanDirectory(m_TempPath); +} + +ProjectStore::Oplog::~Oplog() +{ + if (m_Storage) + { + Flush(); + } +} + +void +ProjectStore::Oplog::Flush() +{ + ZEN_ASSERT(m_Storage); + m_Storage->Flush(); +} + +void +ProjectStore::Oplog::Scrub(ScrubContext& Ctx) const +{ + ZEN_UNUSED(Ctx); +} + +void +ProjectStore::Oplog::GatherReferences(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_OplogLock); + + std::vector<IoHash> Hashes; + Hashes.reserve(Max(m_ChunkMap.size(), m_MetaMap.size())); + + for (const auto& Kv : m_ChunkMap) + { + Hashes.push_back(Kv.second); + } + + GcCtx.AddRetainedCids(Hashes); + + Hashes.clear(); + + for (const auto& Kv : m_MetaMap) + { + Hashes.push_back(Kv.second); + } + + GcCtx.AddRetainedCids(Hashes); +} + +uint64_t +ProjectStore::Oplog::TotalSize() const +{ + RwLock::SharedLockScope _(m_OplogLock); + if (m_Storage) + { + return m_Storage->OpBlobsSize(); + } + return 0; +} + +bool +ProjectStore::Oplog::IsExpired() const +{ + if (m_MarkerPath.empty()) + { + return false; + } + return !std::filesystem::exists(m_MarkerPath); +} + +std::filesystem::path +ProjectStore::Oplog::PrepareForDelete(bool MoveFolder) +{ + RwLock::ExclusiveLockScope _(m_OplogLock); + m_ChunkMap.clear(); + m_MetaMap.clear(); + m_FileMap.clear(); + m_OpAddressMap.clear(); + m_LatestOpMap.clear(); + m_Storage = {}; + if (!MoveFolder) + { + return {}; + } + std::filesystem::path MovedDir; + if (PrepareDirectoryDelete(m_BasePath, MovedDir)) + { + return MovedDir; + } + return {}; +} + +bool +ProjectStore::Oplog::ExistsAt(std::filesystem::path BasePath) +{ + using namespace std::literals; + + std::filesystem::path StateFilePath = BasePath / "oplog.zcb"sv; + return std::filesystem::is_regular_file(StateFilePath); +} + +void +ProjectStore::Oplog::Read() +{ + using namespace std::literals; + + std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv; + if 
(std::filesystem::is_regular_file(StateFilePath)) + { + ZEN_INFO("reading config for oplog '{}' in project '{}' from {}", m_OplogId, m_OuterProject->Identifier, StateFilePath); + + BasicFile Blob; + Blob.Open(StateFilePath, BasicFile::Mode::kRead); + + IoBuffer Obj = Blob.ReadAll(); + CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), StateFilePath); + return; + } + + CbObject Cfg = LoadCompactBinaryObject(Obj); + + m_MarkerPath = Cfg["gcpath"sv].AsString(); + } + else + { + ZEN_INFO("config for oplog '{}' in project '{}' not found at {}. Assuming legacy store", + m_OplogId, + m_OuterProject->Identifier, + StateFilePath); + } + ReplayLog(); +} + +void +ProjectStore::Oplog::Write() +{ + using namespace std::literals; + + BinaryWriter Mem; + + CbObjectWriter Cfg; + + Cfg << "gcpath"sv << PathToUtf8(m_MarkerPath); + + Cfg.Save(Mem); + + std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv; + + ZEN_INFO("persisting config for oplog '{}' in project '{}' to {}", m_OplogId, m_OuterProject->Identifier, StateFilePath); + + BasicFile Blob; + Blob.Open(StateFilePath, BasicFile::Mode::kTruncate); + Blob.Write(Mem.Data(), Mem.Size(), 0); + Blob.Flush(); +} + +void +ProjectStore::Oplog::ReplayLog() +{ + RwLock::ExclusiveLockScope OplogLock(m_OplogLock); + if (!m_Storage) + { + return; + } + m_Storage->ReplayLog( + [&](CbObject Op, const OplogEntry& OpEntry) { RegisterOplogEntry(OplogLock, GetMapping(Op), OpEntry, kUpdateReplay); }); +} + +IoBuffer +ProjectStore::Oplog::FindChunk(Oid ChunkId) +{ + RwLock::SharedLockScope OplogLock(m_OplogLock); + if (!m_Storage) + { + return IoBuffer{}; + } + + if (auto ChunkIt = m_ChunkMap.find(ChunkId); ChunkIt != m_ChunkMap.end()) + { + IoHash ChunkHash = ChunkIt->second; + OplogLock.ReleaseNow(); + + IoBuffer Chunk = 
m_CidStore.FindChunkByCid(ChunkHash); + Chunk.SetContentType(ZenContentType::kCompressedBinary); + + return Chunk; + } + + if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end()) + { + std::filesystem::path FilePath = m_OuterProject->RootDir / FileIt->second.ServerPath; + + OplogLock.ReleaseNow(); + + IoBuffer FileChunk = IoBufferBuilder::MakeFromFile(FilePath); + FileChunk.SetContentType(ZenContentType::kBinary); + + return FileChunk; + } + + if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end()) + { + IoHash ChunkHash = MetaIt->second; + OplogLock.ReleaseNow(); + + IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash); + Chunk.SetContentType(ZenContentType::kCompressedBinary); + + return Chunk; + } + + return {}; +} + +void +ProjectStore::Oplog::IterateFileMap( + std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return; + } + + for (const auto& Kv : m_FileMap) + { + Fn(Kv.first, Kv.second.ServerPath, Kv.second.ClientPath); + } +} + +void +ProjectStore::Oplog::IterateOplog(std::function<void(CbObject)>&& Handler) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return; + } + + std::vector<OplogEntryAddress> Entries; + Entries.reserve(m_LatestOpMap.size()); + + for (const auto& Kv : m_LatestOpMap) + { + const auto AddressEntry = m_OpAddressMap.find(Kv.second); + ZEN_ASSERT(AddressEntry != m_OpAddressMap.end()); + + Entries.push_back(AddressEntry->second); + } + + std::sort(Entries.begin(), Entries.end(), [](const OplogEntryAddress& Lhs, const OplogEntryAddress& Rhs) { + return Lhs.Offset < Rhs.Offset; + }); + + m_Storage->ReplayLog(Entries, [&](CbObject Op) { Handler(Op); }); +} + +void +ProjectStore::Oplog::IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Handler) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return; + } + + std::vector<size_t> 
EntryIndexes; + std::vector<OplogEntryAddress> Entries; + std::vector<Oid> Keys; + std::vector<int> LSNs; + Entries.reserve(m_LatestOpMap.size()); + EntryIndexes.reserve(m_LatestOpMap.size()); + Keys.reserve(m_LatestOpMap.size()); + LSNs.reserve(m_LatestOpMap.size()); + + for (const auto& Kv : m_LatestOpMap) + { + const auto AddressEntry = m_OpAddressMap.find(Kv.second); + ZEN_ASSERT(AddressEntry != m_OpAddressMap.end()); + + Entries.push_back(AddressEntry->second); + Keys.push_back(Kv.first); + LSNs.push_back(Kv.second); + EntryIndexes.push_back(EntryIndexes.size()); + } + + // Sort a permutation of indexes by storage offset rather than the entries themselves, so that Keys and LSNs + // (which stay in map iteration order) can still be reached through the permutation + std::sort(EntryIndexes.begin(), EntryIndexes.end(), [&Entries](const size_t& Lhs, const size_t& Rhs) { + const OplogEntryAddress& LhsEntry = Entries[Lhs]; + const OplogEntryAddress& RhsEntry = Entries[Rhs]; + return LhsEntry.Offset < RhsEntry.Offset; + }); + std::vector<OplogEntryAddress> SortedEntries; + SortedEntries.reserve(EntryIndexes.size()); + for (size_t Index : EntryIndexes) + { + SortedEntries.push_back(Entries[Index]); + } + + size_t EntryIndex = 0; + m_Storage->ReplayLog(SortedEntries, [&](CbObject Op) { + // SortedEntries[EntryIndex] was taken from slot EntryIndexes[EntryIndex] of the unsorted arrays, so translate + // back before reading Keys/LSNs; indexing them by replay position would pair ops with the wrong key and LSN. + // NOTE(review): assumes ReplayLog invokes the callback once per entry, in the order given - confirm + const size_t SourceIndex = EntryIndexes[EntryIndex]; + Handler(LSNs[SourceIndex], Keys[SourceIndex], Op); + EntryIndex++; + }); +} + +// Returns the index recorded in m_LatestOpMap for Key, or -1 when the key is unknown or the oplog has no storage +int +ProjectStore::Oplog::GetOpIndexByKey(const Oid& Key) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + // Report "not found" rather than a value-initialized 0, which callers could mistake for a real index + return -1; + } + if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end()) + { + return LatestOp->second; + } + return -1; +} + +std::optional<CbObject> +ProjectStore::Oplog::GetOpByKey(const Oid& Key) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return {}; + } + + if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end()) + { + const auto AddressEntry = m_OpAddressMap.find(LatestOp->second); + ZEN_ASSERT(AddressEntry != m_OpAddressMap.end()); + + return m_Storage->GetOp(AddressEntry->second); + } + + return {}; +} + +std::optional<CbObject> +ProjectStore::Oplog::GetOpByIndex(int Index) +{ +
RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return {}; + } + + if (const auto AddressEntryIt = m_OpAddressMap.find(Index); AddressEntryIt != m_OpAddressMap.end()) + { + return m_Storage->GetOp(AddressEntryIt->second); + } + + return {}; +} + +// Records the server/client paths for FileId in m_FileMap and, when Hash is non-zero, maps FileId to Hash in m_ChunkMap. +// The unnamed lock parameter only documents that the caller must hold the exclusive oplog lock +void +ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&, + Oid FileId, + IoHash Hash, + std::string_view ServerPath, + std::string_view ClientPath) +{ + // A single insert_or_assign is sufficient; the original repeated the identical call after updating m_FileMap + if (Hash != IoHash::Zero) + { + m_ChunkMap.insert_or_assign(FileId, Hash); + } + + FileMapEntry Entry; + Entry.ServerPath = ServerPath; + Entry.ClientPath = ClientPath; + + m_FileMap[FileId] = std::move(Entry); +} + +void +ProjectStore::Oplog::AddChunkMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash) +{ + m_ChunkMap.insert_or_assign(ChunkId, Hash); +} + +void +ProjectStore::Oplog::AddMetaMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash) +{ + m_MetaMap.insert_or_assign(ChunkId, Hash); +} + +ProjectStore::Oplog::OplogEntryMapping +ProjectStore::Oplog::GetMapping(CbObject Core) +{ + using namespace std::literals; + + OplogEntryMapping Result; + + // Update chunk id maps + CbObjectView PackageObj = Core["package"sv].AsObjectView(); + CbArrayView BulkDataArray = Core["bulkdata"sv].AsArrayView(); + CbArrayView PackageDataArray = Core["packagedata"sv].AsArrayView(); + // Parenthesize the conditional: without it the expression parses as PackageObj ? 1 : (0 + ... + ...) and only one slot is reserved + Result.Chunks.reserve((PackageObj ?
1 : 0) + BulkDataArray.Num() + PackageDataArray.Num()); + + if (PackageObj) + { + Oid Id = PackageObj["id"sv].AsObjectId(); + IoHash Hash = PackageObj["data"sv].AsBinaryAttachment(); + Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + ZEN_DEBUG("package data {} -> {}", Id, Hash); + } + + for (CbFieldView& Entry : PackageDataArray) + { + CbObjectView PackageDataObj = Entry.AsObjectView(); + Oid Id = PackageDataObj["id"sv].AsObjectId(); + IoHash Hash = PackageDataObj["data"sv].AsBinaryAttachment(); + Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + ZEN_DEBUG("package {} -> {}", Id, Hash); + } + + for (CbFieldView& Entry : BulkDataArray) + { + CbObjectView BulkObj = Entry.AsObjectView(); + Oid Id = BulkObj["id"sv].AsObjectId(); + IoHash Hash = BulkObj["data"sv].AsBinaryAttachment(); + Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + ZEN_DEBUG("bulkdata {} -> {}", Id, Hash); + } + + CbArrayView FilesArray = Core["files"sv].AsArrayView(); + Result.Files.reserve(FilesArray.Num()); + for (CbFieldView& Entry : FilesArray) + { + CbObjectView FileObj = Entry.AsObjectView(); + + std::string_view ServerPath = FileObj["serverpath"sv].AsString(); + std::string_view ClientPath = FileObj["clientpath"sv].AsString(); + if (ServerPath.empty() || ClientPath.empty()) + { + ZEN_WARN("invalid file"); + continue; + } + + Oid Id = FileObj["id"sv].AsObjectId(); + IoHash Hash = FileObj["data"sv].AsBinaryAttachment(); + Result.Files.emplace_back( + OplogEntryMapping::FileMapping{OplogEntryMapping::Mapping{Id, Hash}, std::string(ServerPath), std::string(ClientPath)}); + ZEN_DEBUG("file {} -> {}, ServerPath: {}, ClientPath: {}", Id, Hash, ServerPath, ClientPath); + } + + CbArrayView MetaArray = Core["meta"sv].AsArrayView(); + Result.Meta.reserve(MetaArray.Num()); + for (CbFieldView& Entry : MetaArray) + { + CbObjectView MetaObj = Entry.AsObjectView(); + Oid Id = MetaObj["id"sv].AsObjectId(); + IoHash Hash =
MetaObj["data"sv].AsBinaryAttachment(); + Result.Meta.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + auto NameString = MetaObj["name"sv].AsString(); + ZEN_DEBUG("meta data ({}) {} -> {}", NameString, Id, Hash); + } + + return Result; +} + +uint32_t +ProjectStore::Oplog::RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock, + const OplogEntryMapping& OpMapping, + const OplogEntry& OpEntry, + UpdateType TypeOfUpdate) +{ + ZEN_TRACE_CPU("ProjectStore::Oplog::RegisterOplogEntry"); + + ZEN_UNUSED(TypeOfUpdate); + + // For now we're assuming the update is all in-memory so we can hold an exclusive lock without causing + // too many problems. Longer term we'll probably want to ensure we can do concurrent updates however + + using namespace std::literals; + + // Update chunk id maps + for (const OplogEntryMapping::Mapping& Chunk : OpMapping.Chunks) + { + AddChunkMapping(OplogLock, Chunk.Id, Chunk.Hash); + } + + for (const OplogEntryMapping::FileMapping& File : OpMapping.Files) + { + AddFileMapping(OplogLock, File.Id, File.Hash, File.ServerPath, File.ClientPath); + } + + for (const OplogEntryMapping::Mapping& Meta : OpMapping.Meta) + { + AddMetaMapping(OplogLock, Meta.Id, Meta.Hash); + } + + m_OpAddressMap.emplace(OpEntry.OpLsn, OplogEntryAddress{.Offset = OpEntry.OpCoreOffset, .Size = OpEntry.OpCoreSize}); + m_LatestOpMap[OpEntry.OpKeyAsOId()] = OpEntry.OpLsn; + + return OpEntry.OpLsn; +} + +uint32_t +ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage) +{ + ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry"); + + const CbObject& Core = OpPackage.GetObject(); + const uint32_t EntryId = AppendNewOplogEntry(Core); + if (EntryId == 0xffffffffu) + { + // The oplog has been deleted so just drop this + return EntryId; + } + + // Persist attachments after oplog entry so GC won't find attachments without references + + uint64_t AttachmentBytes = 0; + uint64_t NewAttachmentBytes = 0; + + auto Attachments = OpPackage.GetAttachments(); + + for (const auto& 
Attach : Attachments) + { + ZEN_ASSERT(Attach.IsCompressedBinary()); + + CompressedBuffer AttachmentData = Attach.AsCompressedBinary(); + const uint64_t AttachmentSize = AttachmentData.DecodeRawSize(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(AttachmentData.GetCompressed().Flatten().AsIoBuffer(), Attach.GetHash()); + + if (InsertResult.New) + { + NewAttachmentBytes += AttachmentSize; + } + AttachmentBytes += AttachmentSize; + } + + ZEN_DEBUG("oplog entry #{} attachments: {} new, {} total", EntryId, NiceBytes(NewAttachmentBytes), NiceBytes(AttachmentBytes)); + + return EntryId; +} + +// Appends Core to the oplog storage and registers the resulting entry in the in-memory maps. +// Returns the value produced by RegisterOplogEntry, or 0xffffffff when the oplog no longer has any storage +uint32_t +ProjectStore::Oplog::AppendNewOplogEntry(CbObject Core) +{ + ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry"); + + using namespace std::literals; + + OplogEntryMapping Mapping = GetMapping(Core); + + SharedBuffer Buffer = Core.GetBuffer(); + const uint64_t WriteSize = Buffer.GetSize(); + const auto OpCoreHash = uint32_t(XXH3_64bits(Buffer.GetData(), WriteSize) & 0xffffFFFF); + + ZEN_ASSERT(WriteSize != 0); + + XXH3_128Stream KeyHasher; + Core["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); }); + XXH3_128 KeyHash = KeyHasher.GetHash(); + + // Take our own reference to the storage while holding the lock. Everything below must go through this local: + // reading m_Storage again without the lock would race with anything that resets it + RefPtr<OplogStorage> Storage; + { + RwLock::SharedLockScope _(m_OplogLock); + Storage = m_Storage; + } + if (!Storage) + { + return 0xffffffffu; + } + const OplogEntry OpEntry = Storage->AppendOp(Buffer, OpCoreHash, KeyHash); + + RwLock::ExclusiveLockScope OplogLock(m_OplogLock); + const uint32_t EntryId = RegisterOplogEntry(OplogLock, Mapping, OpEntry, kUpdateNewEntry); + + return EntryId; +} + +////////////////////////////////////////////////////////////////////////// + +ProjectStore::Project::Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath) +: m_ProjectStore(PrjStore) +, m_CidStore(Store) +, m_OplogStoragePath(BasePath) +{ +} + +ProjectStore::Project::~Project() +{ +} + +bool +ProjectStore::Project::Exists(std::filesystem::path
BasePath) +{ + return std::filesystem::exists(BasePath / "Project.zcb"); +} + +void +ProjectStore::Project::Read() +{ + using namespace std::literals; + + std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv; + + ZEN_INFO("reading config for project '{}' from {}", Identifier, ProjectStateFilePath); + + BasicFile Blob; + Blob.Open(ProjectStateFilePath, BasicFile::Mode::kRead); + + IoBuffer Obj = Blob.ReadAll(); + CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All); + + if (ValidationError == CbValidateError::None) + { + CbObject Cfg = LoadCompactBinaryObject(Obj); + + Identifier = Cfg["id"sv].AsString(); + RootDir = Cfg["root"sv].AsString(); + ProjectRootDir = Cfg["project"sv].AsString(); + EngineRootDir = Cfg["engine"sv].AsString(); + ProjectFilePath = Cfg["projectfile"sv].AsString(); + } + else + { + ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), ProjectStateFilePath); + } +} + +void +ProjectStore::Project::Write() +{ + using namespace std::literals; + + BinaryWriter Mem; + + CbObjectWriter Cfg; + Cfg << "id"sv << Identifier; + Cfg << "root"sv << PathToUtf8(RootDir); + Cfg << "project"sv << ProjectRootDir; + Cfg << "engine"sv << EngineRootDir; + Cfg << "projectfile"sv << ProjectFilePath; + + Cfg.Save(Mem); + + CreateDirectories(m_OplogStoragePath); + + std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv; + + ZEN_INFO("persisting config for project '{}' to {}", Identifier, ProjectStateFilePath); + + BasicFile Blob; + Blob.Open(ProjectStateFilePath, BasicFile::Mode::kTruncate); + Blob.Write(Mem.Data(), Mem.Size(), 0); + Blob.Flush(); +} + +spdlog::logger& +ProjectStore::Project::Log() +{ + return m_ProjectStore->Log(); +} + +std::filesystem::path +ProjectStore::Project::BasePathForOplog(std::string_view OplogId) +{ + return m_OplogStoragePath / OplogId; +} + +ProjectStore::Oplog* 
+ProjectStore::Project::NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath) +{ + RwLock::ExclusiveLockScope _(m_ProjectLock); + + std::filesystem::path OplogBasePath = BasePathForOplog(OplogId); + + try + { + Oplog* Log = m_Oplogs + .try_emplace(std::string{OplogId}, + std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, MarkerPath)) + .first->second.get(); + + Log->Write(); + return Log; + } + catch (std::exception&) + { + // In case of failure we need to ensure there's no half constructed entry around + // + // (This is probably already ensured by the try_emplace implementation?) + + m_Oplogs.erase(std::string{OplogId}); + + return nullptr; + } +} + +ProjectStore::Oplog* +ProjectStore::Project::OpenOplog(std::string_view OplogId) +{ + { + RwLock::SharedLockScope _(m_ProjectLock); + + auto OplogIt = m_Oplogs.find(std::string(OplogId)); + + if (OplogIt != m_Oplogs.end()) + { + return OplogIt->second.get(); + } + } + + RwLock::ExclusiveLockScope _(m_ProjectLock); + + std::filesystem::path OplogBasePath = BasePathForOplog(OplogId); + + if (Oplog::ExistsAt(OplogBasePath)) + { + // Do open of existing oplog + + try + { + Oplog* Log = + m_Oplogs + .try_emplace(std::string{OplogId}, + std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, std::filesystem::path{})) + .first->second.get(); + Log->Read(); + + return Log; + } + catch (std::exception& ex) + { + ZEN_WARN("failed to open oplog '{}' @ '{}': {}", OplogId, OplogBasePath, ex.what()); + + m_Oplogs.erase(std::string{OplogId}); + } + } + + return nullptr; +} + +void +ProjectStore::Project::DeleteOplog(std::string_view OplogId) +{ + std::filesystem::path DeletePath; + { + RwLock::ExclusiveLockScope _(m_ProjectLock); + + auto OplogIt = m_Oplogs.find(std::string(OplogId)); + + if (OplogIt != m_Oplogs.end()) + { + std::unique_ptr<Oplog>& Oplog = OplogIt->second; + DeletePath = Oplog->PrepareForDelete(true); + 
m_DeletedOplogs.emplace_back(std::move(Oplog)); + m_Oplogs.erase(OplogIt); + } + } + + // Erase content on disk + if (!DeletePath.empty()) + { + OplogStorage::Delete(DeletePath); + } +} + +std::vector<std::string> +ProjectStore::Project::ScanForOplogs() const +{ + DirectoryContent DirContent; + GetDirectoryContent(m_OplogStoragePath, DirectoryContent::IncludeDirsFlag, DirContent); + std::vector<std::string> Oplogs; + Oplogs.reserve(DirContent.Directories.size()); + for (const std::filesystem::path& DirPath : DirContent.Directories) + { + Oplogs.push_back(DirPath.filename().string()); + } + return Oplogs; +} + +void +ProjectStore::Project::IterateOplogs(std::function<void(const Oplog&)>&& Fn) const +{ + RwLock::SharedLockScope _(m_ProjectLock); + + for (auto& Kv : m_Oplogs) + { + Fn(*Kv.second); + } +} + +void +ProjectStore::Project::IterateOplogs(std::function<void(Oplog&)>&& Fn) +{ + RwLock::SharedLockScope _(m_ProjectLock); + + for (auto& Kv : m_Oplogs) + { + Fn(*Kv.second); + } +} + +void +ProjectStore::Project::Flush() +{ + // We only need to flush oplogs that we have already loaded + IterateOplogs([&](Oplog& Ops) { Ops.Flush(); }); +} + +void +ProjectStore::Project::Scrub(ScrubContext& Ctx) +{ + // Scrubbing needs to check all existing oplogs + std::vector<std::string> OpLogs = ScanForOplogs(); + for (const std::string& OpLogId : OpLogs) + { + OpenOplog(OpLogId); + } + IterateOplogs([&](const Oplog& Ops) { + if (!Ops.IsExpired()) + { + Ops.Scrub(Ctx); + } + }); +} + +void +ProjectStore::Project::GatherReferences(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("ProjectStore::Project::GatherReferences"); + + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_DEBUG("gathered references from project store project {} in {}", Identifier, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + // GatherReferences needs to check all existing oplogs + std::vector<std::string> OpLogs = ScanForOplogs(); + for (const std::string& OpLogId : OpLogs) + { + OpenOplog(OpLogId); + } 
+ IterateOplogs([&](Oplog& Ops) { + if (!Ops.IsExpired()) + { + Ops.GatherReferences(GcCtx); + } + }); +} + +uint64_t +ProjectStore::Project::TotalSize() const +{ + uint64_t Result = 0; + { + RwLock::SharedLockScope _(m_ProjectLock); + for (const auto& It : m_Oplogs) + { + Result += It.second->TotalSize(); + } + } + return Result; +} + +bool +ProjectStore::Project::PrepareForDelete(std::filesystem::path& OutDeletePath) +{ + RwLock::ExclusiveLockScope _(m_ProjectLock); + + for (auto& It : m_Oplogs) + { + // We don't care about the moved folder + It.second->PrepareForDelete(false); + m_DeletedOplogs.emplace_back(std::move(It.second)); + } + + m_Oplogs.clear(); + + bool Success = PrepareDirectoryDelete(m_OplogStoragePath, OutDeletePath); + if (!Success) + { + return false; + } + m_OplogStoragePath.clear(); + return true; +} + +bool +ProjectStore::Project::IsExpired() const +{ + if (ProjectFilePath.empty()) + { + return false; + } + return !std::filesystem::exists(ProjectFilePath); +} + +////////////////////////////////////////////////////////////////////////// + +ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc) +: GcStorage(Gc) +, GcContributor(Gc) +, m_Log(logging::Get("project")) +, m_CidStore(Store) +, m_ProjectBasePath(BasePath) +{ + ZEN_INFO("initializing project store at '{}'", BasePath); + // m_Log.set_level(spdlog::level::debug); +} + +ProjectStore::~ProjectStore() +{ + ZEN_INFO("closing project store ('{}')", m_ProjectBasePath); +} + +std::filesystem::path +ProjectStore::BasePathForProject(std::string_view ProjectId) +{ + return m_ProjectBasePath / ProjectId; +} + +void +ProjectStore::DiscoverProjects() +{ + if (!std::filesystem::exists(m_ProjectBasePath)) + { + return; + } + + DirectoryContent DirContent; + GetDirectoryContent(m_ProjectBasePath, DirectoryContent::IncludeDirsFlag, DirContent); + + for (const std::filesystem::path& DirPath : DirContent.Directories) + { + std::string DirName = 
PathToUtf8(DirPath.filename()); + OpenProject(DirName); + } +} + +void +ProjectStore::IterateProjects(std::function<void(Project& Prj)>&& Fn) +{ + RwLock::SharedLockScope _(m_ProjectsLock); + + for (auto& Kv : m_Projects) + { + Fn(*Kv.second.Get()); + } +} + +void +ProjectStore::Flush() +{ + std::vector<Ref<Project>> Projects; + { + RwLock::SharedLockScope _(m_ProjectsLock); + Projects.reserve(m_Projects.size()); + + for (auto& Kv : m_Projects) + { + Projects.push_back(Kv.second); + } + } + for (const Ref<Project>& Project : Projects) + { + Project->Flush(); + } +} + +void +ProjectStore::Scrub(ScrubContext& Ctx) +{ + DiscoverProjects(); + + std::vector<Ref<Project>> Projects; + { + RwLock::SharedLockScope _(m_ProjectsLock); + Projects.reserve(m_Projects.size()); + + for (auto& Kv : m_Projects) + { + if (Kv.second->IsExpired()) + { + continue; + } + Projects.push_back(Kv.second); + } + } + for (const Ref<Project>& Project : Projects) + { + Project->Scrub(Ctx); + } +} + +void +ProjectStore::GatherReferences(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("ProjectStore::GatherReferences"); + + size_t ProjectCount = 0; + size_t ExpiredProjectCount = 0; + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_DEBUG("gathered references from '{}' in {}, found {} active projects and {} expired projects", + m_ProjectBasePath.string(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs()), + ProjectCount, + ExpiredProjectCount); + }); + + DiscoverProjects(); + + std::vector<Ref<Project>> Projects; + { + RwLock::SharedLockScope _(m_ProjectsLock); + Projects.reserve(m_Projects.size()); + + for (auto& Kv : m_Projects) + { + if (Kv.second->IsExpired()) + { + ExpiredProjectCount++; + continue; + } + Projects.push_back(Kv.second); + } + } + ProjectCount = Projects.size(); + for (const Ref<Project>& Project : Projects) + { + Project->GatherReferences(GcCtx); + } +} + +void +ProjectStore::CollectGarbage(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("ProjectStore::CollectGarbage"); + + size_t ProjectCount = 
0; + size_t ExpiredProjectCount = 0; + + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_DEBUG("garbage collect from '{}' DONE after {}, found {} active projects and {} expired projects", + m_ProjectBasePath.string(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs()), + ProjectCount, + ExpiredProjectCount); + }); + std::vector<Ref<Project>> ExpiredProjects; + std::vector<Ref<Project>> Projects; + + { + RwLock::SharedLockScope _(m_ProjectsLock); + for (auto& Kv : m_Projects) + { + if (Kv.second->IsExpired()) + { + ExpiredProjects.push_back(Kv.second); + ExpiredProjectCount++; + continue; + } + Projects.push_back(Kv.second); + ProjectCount++; + } + } + + if (!GcCtx.IsDeletionMode()) + { + ZEN_DEBUG("garbage collect DISABLED, for '{}' ", m_ProjectBasePath.string()); + return; + } + + for (const Ref<Project>& Project : Projects) + { + std::vector<std::string> ExpiredOplogs; + { + RwLock::ExclusiveLockScope _(m_ProjectsLock); + Project->IterateOplogs([&ExpiredOplogs](ProjectStore::Oplog& Oplog) { + if (Oplog.IsExpired()) + { + ExpiredOplogs.push_back(Oplog.OplogId()); + } + }); + } + for (const std::string& OplogId : ExpiredOplogs) + { + ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected oplog '{}' in project '{}'. Removing storage on disk", + OplogId, + Project->Identifier); + Project->DeleteOplog(OplogId); + } + } + + if (ExpiredProjects.empty()) + { + ZEN_DEBUG("garbage collect for '{}', no expired projects found", m_ProjectBasePath.string()); + return; + } + + for (const Ref<Project>& Project : ExpiredProjects) + { + std::filesystem::path PathToRemove; + std::string ProjectId; + { + RwLock::ExclusiveLockScope _(m_ProjectsLock); + if (!Project->IsExpired()) + { + ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. 
Project no longer expired.", Project->Identifier); + continue; + } + bool Success = Project->PrepareForDelete(PathToRemove); + if (!Success) + { + // ProjectId is only assigned once the project has been removed from m_Projects, so log the live identifier here + ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project folder is locked.", Project->Identifier); + continue; + } + m_Projects.erase(Project->Identifier); + ProjectId = Project->Identifier; + } + + ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected project '{}'. Removing storage on disk", ProjectId); + if (PathToRemove.empty()) + { + continue; + } + + DeleteDirectories(PathToRemove); + } +} + +GcStorageSize +ProjectStore::StorageSize() const +{ + GcStorageSize Result; + { + RwLock::SharedLockScope _(m_ProjectsLock); + for (auto& Kv : m_Projects) + { + const Ref<Project>& Project = Kv.second; + Result.DiskSize += Project->TotalSize(); + } + } + return Result; +} + +Ref<ProjectStore::Project> +ProjectStore::OpenProject(std::string_view ProjectId) +{ + { + RwLock::SharedLockScope _(m_ProjectsLock); + + auto ProjIt = m_Projects.find(std::string{ProjectId}); + + if (ProjIt != m_Projects.end()) + { + return ProjIt->second; + } + } + + RwLock::ExclusiveLockScope _(m_ProjectsLock); + + std::filesystem::path BasePath = BasePathForProject(ProjectId); + + if (Project::Exists(BasePath)) + { + try + { + ZEN_INFO("opening project {} @ {}", ProjectId, BasePath); + + Ref<Project>& Prj = + m_Projects + .try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath))) + .first->second; + Prj->Identifier = ProjectId; + Prj->Read(); + return Prj; + } + catch (std::exception& e) + { + ZEN_WARN("failed to open {} @ {} ({})", ProjectId, BasePath, e.what()); + m_Projects.erase(std::string{ProjectId}); + } + } + + return {}; +} + +Ref<ProjectStore::Project> +ProjectStore::NewProject(std::filesystem::path BasePath, + std::string_view ProjectId, + std::string_view RootDir, + std::string_view EngineRootDir, + std::string_view ProjectRootDir, + std::string_view
ProjectFilePath) +{ + RwLock::ExclusiveLockScope _(m_ProjectsLock); + + Ref<Project>& Prj = + m_Projects.try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath))) + .first->second; + Prj->Identifier = ProjectId; + Prj->RootDir = RootDir; + Prj->EngineRootDir = EngineRootDir; + Prj->ProjectRootDir = ProjectRootDir; + Prj->ProjectFilePath = ProjectFilePath; + Prj->Write(); + + return Prj; +} + +bool +ProjectStore::DeleteProject(std::string_view ProjectId) +{ + ZEN_INFO("deleting project {}", ProjectId); + + RwLock::ExclusiveLockScope ProjectsLock(m_ProjectsLock); + + auto ProjIt = m_Projects.find(std::string{ProjectId}); + + if (ProjIt == m_Projects.end()) + { + return true; + } + + std::filesystem::path DeletePath; + bool Success = ProjIt->second->PrepareForDelete(DeletePath); + + if (!Success) + { + return false; + } + m_Projects.erase(ProjIt); + ProjectsLock.ReleaseNow(); + + if (!DeletePath.empty()) + { + DeleteDirectories(DeletePath); + } + return true; +} + +bool +ProjectStore::Exists(std::string_view ProjectId) +{ + return Project::Exists(BasePathForProject(ProjectId)); +} + +CbArray +ProjectStore::GetProjectsList() +{ + using namespace std::literals; + + DiscoverProjects(); + + CbWriter Response; + Response.BeginArray(); + + IterateProjects([&Response](ProjectStore::Project& Prj) { + Response.BeginObject(); + Response << "Id"sv << Prj.Identifier; + Response << "RootDir"sv << Prj.RootDir.string(); + Response << "ProjectRootDir"sv << Prj.ProjectRootDir; + Response << "EngineRootDir"sv << Prj.EngineRootDir; + Response << "ProjectFilePath"sv << Prj.ProjectFilePath; + Response.EndObject(); + }); + Response.EndArray(); + return Response.Save().AsArray(); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetProjectFiles(const std::string_view ProjectId, const std::string_view OplogId, bool FilterClient, CbObject& OutPayload) +{ + using namespace std::literals; + + Ref<ProjectStore::Project> 
Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Project files request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Project files for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + CbObjectWriter Response; + Response.BeginArray("files"sv); + + FoundLog->IterateFileMap([&](const Oid& Id, const std::string_view& ServerPath, const std::string_view& ClientPath) { + Response.BeginObject(); + Response << "id"sv << Id; + Response << "clientpath"sv << ClientPath; + if (!FilterClient) + { + Response << "serverpath"sv << ServerPath; + } + Response.EndObject(); + }); + + Response.EndArray(); + OutPayload = Response.Save(); + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetChunkInfo(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + CbObject& OutPayload) +{ + using namespace std::literals; + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + if (ChunkId.size() != 2 * sizeof(Oid::OidBits)) + { + return {HttpResponseCode::BadRequest, + fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)}; + } + + const Oid Obj = Oid::FromHexString(ChunkId); + + IoBuffer Chunk = FoundLog->FindChunk(Obj); + if (!Chunk) + { + return {HttpResponseCode::NotFound, {}}; + } + + uint64_t ChunkSize = Chunk.GetSize(); + if (Chunk.GetContentType() == HttpContentType::kCompressedBinary) + { + IoHash RawHash; + 
uint64_t RawSize; + bool IsCompressed = CompressedBuffer::ValidateCompressedHeader(Chunk, RawHash, RawSize); + ZEN_ASSERT(IsCompressed); + ChunkSize = RawSize; + } + + CbObjectWriter Response; + Response << "size"sv << ChunkSize; + OutPayload = Response.Save(); + return {HttpResponseCode::OK, {}}; +} + +// Looks up ChunkId in the given project/oplog and returns the chunk, or the [Offset, Offset + Size) part of it, in OutChunk. +// Offset == 0 together with Size == ~0ull means "whole chunk". A range reaching past the end of the data is clamped to the +// data size. Compressed chunks are decompressed when AcceptType is kBinary and otherwise returned in compressed form +std::pair<HttpResponseCode, std::string> +ProjectStore::GetChunkRange(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + uint64_t Offset, + uint64_t Size, + ZenContentType AcceptType, + IoBuffer& OutChunk) +{ + bool IsOffset = Offset != 0 || Size != ~(0ull); + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (ChunkId.size() != 2 * sizeof(Oid::OidBits)) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)}; + } + + const Oid Obj = Oid::FromHexString(ChunkId); + + IoBuffer Chunk = FoundLog->FindChunk(Obj); + if (!Chunk) + { + return {HttpResponseCode::NotFound, {}}; + } + + OutChunk = Chunk; + HttpContentType ContentType = Chunk.GetContentType(); + + if (Chunk.GetContentType() == HttpContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunk)), RawHash, RawSize); + ZEN_ASSERT(!Compressed.IsNull()); + + if (IsOffset) + { + // Clamp without computing Offset + Size: that sum wraps around for the default Size of ~0ull combined with a + // non-zero Offset, and RawSize - Offset would underflow for an Offset past the end of the data + if (Offset > RawSize) + { + Offset = RawSize; + } + if (Size > (RawSize - Offset)) + { + Size = RawSize - Offset; + } + + if (AcceptType == HttpContentType::kBinary) + { + OutChunk = Compressed.Decompress(Offset, Size).AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kBinary); + } + else + { + // Value
will be a range of compressed blocks that covers the requested range + // The client will have to compensate for any offsets that do not land on an even block size multiple + OutChunk = Compressed.CopyRange(Offset, Size).GetCompressed().Flatten().AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kCompressedBinary); + } + } + else + { + if (AcceptType == HttpContentType::kBinary) + { + OutChunk = Compressed.Decompress().AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kBinary); + } + else + { + OutChunk = Compressed.GetCompressed().Flatten().AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kCompressedBinary); + } + } + } + else if (IsOffset) + { + // Same overflow-safe clamping as for compressed chunks, against the stored size of the uncompressed chunk + const uint64_t StoredSize = Chunk.GetSize(); + if (Offset > StoredSize) + { + Offset = StoredSize; + } + if (Size > (StoredSize - Offset)) + { + Size = StoredSize - Offset; + } + OutChunk = IoBuffer(std::move(Chunk), Offset, Size); + OutChunk.SetContentType(ContentType); + } + + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType AcceptType, + IoBuffer& OutChunk) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (Cid.length() != IoHash::StringLength) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, Cid)}; + } + + const IoHash Hash = IoHash::FromHexString(Cid); + OutChunk = m_CidStore.FindChunkByCid(Hash); + + if (!OutChunk) + { + return {HttpResponseCode::NotFound, fmt::format("chunk - '{}' MISSING", Cid)}; + } + + if (AcceptType == ZenContentType::kUnknownContentType || AcceptType ==
ZenContentType::kBinary) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(OutChunk)); + OutChunk = Compressed.Decompress().AsIoBuffer(); + OutChunk.SetContentType(ZenContentType::kBinary); + } + else + { + OutChunk.SetContentType(ZenContentType::kCompressedBinary); + } + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::PutChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType ContentType, + IoBuffer&& Chunk) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (Cid.length() != IoHash::StringLength) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk put request for invalid chunk hash '{}'", Cid)}; + } + + const IoHash Hash = IoHash::FromHexString(Cid); + + if (ContentType != HttpContentType::kCompressedBinary) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid content type for chunk '{}'", Cid)}; + } + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize); + if (RawHash != Hash) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid payload format for chunk '{}'", Cid)}; + } + + CidStore::InsertResult Result = m_CidStore.AddChunk(Chunk, Hash); + return {Result.New ? 
HttpResponseCode::Created : HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::WriteOplog(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload, CbObject& OutResponse) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + CbObject ContainerObject = LoadCompactBinaryObject(Payload); + if (!ContainerObject) + { + return {HttpResponseCode::BadRequest, "Invalid payload format"}; + } + + CidStore& ChunkStore = m_CidStore; + RwLock AttachmentsLock; + std::unordered_set<IoHash, IoHash::Hasher> Attachments; + + auto HasAttachment = [&ChunkStore](const IoHash& RawHash) { return ChunkStore.ContainsChunk(RawHash); }; + auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector<IoHash>&& ChunkHashes) { + RwLock::ExclusiveLockScope _(AttachmentsLock); + if (BlockHash != IoHash::Zero) + { + Attachments.insert(BlockHash); + } + else + { + Attachments.insert(ChunkHashes.begin(), ChunkHashes.end()); + } + }; + auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) { + RwLock::ExclusiveLockScope _(AttachmentsLock); + Attachments.insert(RawHash); + }; + + RemoteProjectStore::Result RemoteResult = SaveOplogContainer(*Oplog, ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment); + + if (RemoteResult.ErrorCode) + { + return ConvertResult(RemoteResult); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + { + for (const IoHash& Hash : Attachments) + { + ZEN_DEBUG("Need attachment {}", Hash); + Cbo << Hash; + } + } + Cbo.EndArray(); // "need" + + OutResponse = Cbo.Save(); + return 
{HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::ReadOplog(const std::string_view ProjectId, + const std::string_view OplogId, + const HttpServerRequest::QueryParams& Params, + CbObject& OutResponse) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + size_t MaxBlockSize = 128u * 1024u * 1024u; + if (auto Param = Params.GetValue("maxblocksize"); Param.empty() == false) + { + if (auto Value = ParseInt<size_t>(Param)) + { + MaxBlockSize = Value.value(); + } + } + size_t MaxChunkEmbedSize = 1024u * 1024u; + if (auto Param = Params.GetValue("maxchunkembedsize"); Param.empty() == false) + { + if (auto Value = ParseInt<size_t>(Param)) + { + MaxChunkEmbedSize = Value.value(); + } + } + + CidStore& ChunkStore = m_CidStore; + + RemoteProjectStore::LoadContainerResult ContainerResult = BuildContainer( + ChunkStore, + *Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + false, + [](CompressedBuffer&&, const IoHash) {}, + [](const IoHash&) {}, + [](const std::unordered_set<IoHash, IoHash::Hasher>) {}); + + OutResponse = std::move(ContainerResult.ContainerObject); + return ConvertResult(ContainerResult); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::WriteBlock(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return {HttpResponseCode::NotFound, 
fmt::format("Write block request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (!IterateBlock(std::move(Payload), [this](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + IoBuffer Compressed = Chunk.GetCompressed().Flatten().AsIoBuffer(); + m_CidStore.AddChunk(Compressed, AttachmentRawHash); + ZEN_DEBUG("Saved attachment {} from block, size {}", AttachmentRawHash, Compressed.GetSize()); + })) + { + return {HttpResponseCode::BadRequest, "Invalid chunk in block"}; + } + + return {HttpResponseCode::OK, {}}; +} + +void +ProjectStore::Rpc(HttpServerRequest& HttpReq, + const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload, + AuthMgr& AuthManager) +{ + using namespace std::literals; + HttpContentType PayloadContentType = HttpReq.RequestContentType(); + CbPackage Package; + CbObject Cb; + switch (PayloadContentType) + { + case HttpContentType::kJSON: + case HttpContentType::kUnknownContentType: + case HttpContentType::kText: + { + std::string JsonText(reinterpret_cast<const char*>(Payload.GetData()), Payload.GetSize()); + Cb = LoadCompactBinaryFromJson(JsonText).AsObject(); + if (!Cb) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Content format not supported, expected JSON format"); + } + } + break; + case HttpContentType::kCbObject: + Cb = LoadCompactBinaryObject(Payload); + if (!Cb) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Content format not supported, expected compact binary format"); + } + break; + case HttpContentType::kCbPackage: + Package = ParsePackageMessage(Payload); + Cb = Package.GetObject(); + if (!Cb) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Content format not supported, expected package message format"); + } + break; + default: + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid request content type"); + 
} + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("Rpc oplog request for unknown project '{}'", ProjectId)); + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("Rpc oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)); + } + + std::string_view Method = Cb["method"sv].AsString(); + + if (Method == "import") + { + std::pair<HttpResponseCode, std::string> Result = Import(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager); + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + else if (Method == "export") + { + std::pair<HttpResponseCode, std::string> Result = Export(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager); + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + else if (Method == "getchunks") + { + CbPackage ResponsePackage; + { + CbArrayView ChunksArray = Cb["chunks"sv].AsArrayView(); + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("chunks"sv); + for (CbFieldView FieldView : ChunksArray) + { + IoHash RawHash = FieldView.AsHash(); + IoBuffer ChunkBuffer = m_CidStore.FindChunkByCid(RawHash); + if (ChunkBuffer) + { + ResponseWriter.AddHash(RawHash); + ResponsePackage.AddAttachment( + CbAttachment(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkBuffer)), RawHash)); + } + } + ResponseWriter.EndArray(); + ResponsePackage.SetObject(ResponseWriter.Save()); + } + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage, FormatFlags::kDefault); + return 
HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer); + } + else if (Method == "putchunks") + { + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + for (const CbAttachment& Attachment : Attachments) + { + IoHash RawHash = Attachment.GetHash(); + CompressedBuffer Compressed = Attachment.AsCompressedBinary(); + m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), RawHash, CidStore::InsertMode::kCopyOnly); + } + return HttpReq.WriteResponse(HttpResponseCode::OK); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("Unknown rpc method '{}'", Method)); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::Export(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager) +{ + using namespace std::literals; + + size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u); + size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u); + bool Force = Params["force"sv].AsBool(false); + + std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult = + CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize); + + if (RemoteStoreResult.first == nullptr) + { + return {HttpResponseCode::BadRequest, RemoteStoreResult.second}; + } + std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first); + RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo(); + + ZEN_INFO("Saving oplog '{}/{}' to {}, maxblocksize {}, maxchunkembedsize {}", + Project.Identifier, + Oplog.OplogId(), + StoreInfo.Description, + NiceBytes(MaxBlockSize), + NiceBytes(MaxChunkEmbedSize)); + + RemoteProjectStore::Result Result = SaveOplog(m_CidStore, + *RemoteStore, + Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + StoreInfo.CreateBlocks, + StoreInfo.UseTempBlockFiles, + Force); + + return ConvertResult(Result); +} + 
+std::pair<HttpResponseCode, std::string> +ProjectStore::Import(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager) +{ + using namespace std::literals; + + size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u); + size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u); + bool Force = Params["force"sv].AsBool(false); + + std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult = + CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize); + + if (RemoteStoreResult.first == nullptr) + { + return {HttpResponseCode::BadRequest, RemoteStoreResult.second}; + } + std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first); + RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo(); + + ZEN_INFO("Loading oplog '{}/{}' from {}", Project.Identifier, Oplog.OplogId(), StoreInfo.Description); + RemoteProjectStore::Result Result = LoadOplog(m_CidStore, *RemoteStore, Oplog, Force); + return ConvertResult(Result); +} + +////////////////////////////////////////////////////////////////////////// + +HttpProjectService::HttpProjectService(CidStore& Store, ProjectStore* Projects, HttpStatsService& StatsService, AuthMgr& AuthMgr) +: m_Log(logging::Get("project")) +, m_CidStore(Store) +, m_ProjectStore(Projects) +, m_StatsService(StatsService) +, m_AuthMgr(AuthMgr) +{ + using namespace std::literals; + + m_StatsService.RegisterHandler("prj", *this); + + m_Router.AddPattern("project", "([[:alnum:]_.]+)"); + m_Router.AddPattern("log", "([[:alnum:]_.]+)"); + m_Router.AddPattern("op", "([[:digit:]]+?)"); + m_Router.AddPattern("chunk", "([[:xdigit:]]{24})"); + m_Router.AddPattern("hash", "([[:xdigit:]]{40})"); + + m_Router.RegisterRoute( + "", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + 
"list", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/batch", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + // Parse Request + + IoBuffer Payload = HttpReq.ReadPayload(); + BinaryReader Reader(Payload); + + struct RequestHeader + { + enum + { + kMagic = 0xAAAA'77AC + }; + uint32_t Magic; + uint32_t ChunkCount; + uint32_t Reserved1; + uint32_t Reserved2; + }; + + struct RequestChunkEntry + { + Oid ChunkId; + uint32_t CorrelationId; + uint64_t Offset; + uint64_t RequestBytes; + }; + + if (Payload.Size() <= sizeof(RequestHeader)) + { + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + RequestHeader RequestHdr; + Reader.Read(&RequestHdr, sizeof RequestHdr); + + if (RequestHdr.Magic != RequestHeader::kMagic) + { + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + std::vector<RequestChunkEntry> RequestedChunks; + RequestedChunks.resize(RequestHdr.ChunkCount); + Reader.Read(RequestedChunks.data(), sizeof(RequestChunkEntry) * RequestHdr.ChunkCount); + + // Make Response + + struct ResponseHeader + { + uint32_t Magic = 0xbada'b00f; + uint32_t ChunkCount; + uint32_t Reserved1 = 0; + uint32_t Reserved2 = 0; + }; + + struct ResponseChunkEntry + { + uint32_t CorrelationId; + uint32_t Flags = 0; + uint64_t ChunkSize; + }; + + std::vector<IoBuffer> OutBlobs; + OutBlobs.emplace_back(sizeof(ResponseHeader) + RequestHdr.ChunkCount * sizeof(ResponseChunkEntry)); + for 
(uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex) + { + const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex]; + IoBuffer FoundChunk = FoundLog->FindChunk(RequestedChunk.ChunkId); + if (FoundChunk) + { + if (RequestedChunk.Offset > 0 || RequestedChunk.RequestBytes < uint64_t(-1)) + { + uint64_t Offset = RequestedChunk.Offset; + if (Offset > FoundChunk.Size()) + { + Offset = FoundChunk.Size(); + } + uint64_t Size = RequestedChunk.RequestBytes; + if ((Offset + Size) > FoundChunk.Size()) + { + Size = FoundChunk.Size() - Offset; + } + FoundChunk = IoBuffer(FoundChunk, Offset, Size); + } + } + OutBlobs.emplace_back(std::move(FoundChunk)); + } + uint8_t* ResponsePtr = reinterpret_cast<uint8_t*>(OutBlobs[0].MutableData()); + ResponseHeader ResponseHdr; + ResponseHdr.ChunkCount = RequestHdr.ChunkCount; + memcpy(ResponsePtr, &ResponseHdr, sizeof(ResponseHdr)); + ResponsePtr += sizeof(ResponseHdr); + for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex) + { + // const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex]; + const IoBuffer& FoundChunk(OutBlobs[ChunkIndex + 1]); + ResponseChunkEntry ResponseChunk; + ResponseChunk.CorrelationId = ChunkIndex; + if (FoundChunk) + { + ResponseChunk.ChunkSize = FoundChunk.Size(); + } + else + { + ResponseChunk.ChunkSize = uint64_t(-1); + } + memcpy(ResponsePtr, &ResponseChunk, sizeof(ResponseChunk)); + ResponsePtr += sizeof(ResponseChunk); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, OutBlobs); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/files", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // File manifest fetch, returns the client file list + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + + const bool FilterClient = 
Params.GetValue("filter"sv) == "client"sv; + + CbObject ResponsePayload; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->GetProjectFiles(ProjectId, OplogId, FilterClient, ResponsePayload); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{chunk}/info", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& ChunkId = Req.GetCapture(3); + + CbObject ResponsePayload; + std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->GetChunkInfo(ProjectId, OplogId, ChunkId, ResponsePayload); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload); + } + else if (Result.first == HttpResponseCode::NotFound) + { + ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. 
Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{chunk}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& ChunkId = Req.GetCapture(3); + + uint64_t Offset = 0; + uint64_t Size = ~(0ull); + + auto QueryParms = Req.ServerRequest().GetQueryParams(); + + if (auto OffsetParm = QueryParms.GetValue("offset"); OffsetParm.empty() == false) + { + if (auto OffsetVal = ParseInt<uint64_t>(OffsetParm)) + { + Offset = OffsetVal.value(); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + } + + if (auto SizeParm = QueryParms.GetValue("size"); SizeParm.empty() == false) + { + if (auto SizeVal = ParseInt<uint64_t>(SizeParm)) + { + Size = SizeVal.value(); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + } + + HttpContentType AcceptType = HttpReq.AcceptContentType(); + + IoBuffer Chunk; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->GetChunkRange(ProjectId, OplogId, ChunkId, Offset, Size, AcceptType, Chunk); + if (Result.first == HttpResponseCode::OK) + { + ZEN_DEBUG("chunk - '{}/{}/{}' '{}'", ProjectId, OplogId, ChunkId, ToString(Chunk.GetContentType())); + return HttpReq.WriteResponse(HttpResponseCode::OK, Chunk.GetContentType(), Chunk); + } + else if (Result.first == HttpResponseCode::NotFound) + { + ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. 
Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet | HttpVerb::kHead); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{hash}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& Cid = Req.GetCapture(3); + HttpContentType AcceptType = HttpReq.AcceptContentType(); + HttpContentType RequestType = HttpReq.RequestContentType(); + + switch (Req.ServerRequest().RequestVerb()) + { + case HttpVerb::kGet: + { + IoBuffer Value; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->GetChunk(ProjectId, OplogId, Cid, AcceptType, Value); + + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Value.GetContentType(), Value); + } + else if (Result.first == HttpResponseCode::NotFound) + { + ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, Cid); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + case HttpVerb::kPost: + { + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->PutChunk(ProjectId, OplogId, Cid, RequestType, HttpReq.ReadPayload()); + if (Result.first == HttpResponseCode::OK || Result.first == HttpResponseCode::Created) + { + return HttpReq.WriteResponse(Result.first); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. 
Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/prep", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + // This operation takes a list of referenced hashes and decides which + // chunks are not present on this server. 
This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject RequestObject = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + for (auto Entry : RequestObject["have"sv]) + { + const IoHash FileHash = Entry.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + ZEN_DEBUG("prep - NEED: {}", FileHash); + + NeedList.push_back(FileHash); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/new", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + + bool IsUsingSalt = false; + IoHash SaltHash = IoHash::Zero; + + if (std::string_view SaltParam = Params.GetValue("salt"); SaltParam.empty() == false) + { + const uint32_t Salt = std::stoi(std::string(SaltParam)); + SaltHash = IoHash::HashBuffer(&Salt, sizeof Salt); + IsUsingSalt = true; + } + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog& Oplog = *FoundLog; + + IoBuffer Payload = HttpReq.ReadPayload(); + + // This will attempt to open files which may not exist for the case where + // the prep step rejected the chunk. 
This should be fixed since there's + // a performance cost associated with any file system activity + + bool IsValid = true; + std::vector<IoHash> MissingChunks; + + CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer { + if (m_CidStore.ContainsChunk(Hash)) + { + // Return null attachment as we already have it, no point in reading it and storing it again + return {}; + } + + IoHash AttachmentId; + if (IsUsingSalt) + { + IoHash AttachmentSpec[]{SaltHash, Hash}; + AttachmentId = IoHash::HashBuffer(MakeMemoryView(AttachmentSpec)); + } + else + { + AttachmentId = Hash; + } + + std::filesystem::path AttachmentPath = Oplog.TempPath() / AttachmentId.ToHexString(); + if (IoBuffer Data = IoBufferBuilder::MakeFromTemporaryFile(AttachmentPath)) + { + return SharedBuffer(std::move(Data)); + } + else + { + IsValid = false; + MissingChunks.push_back(Hash); + + return {}; + } + }; + + CbPackage Package; + + if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver)) + { + std::filesystem::path BadPackagePath = + Oplog.TempPath() / "bad_packages"sv / fmt::format("session{}_request{}"sv, HttpReq.SessionId(), HttpReq.RequestId()); + + ZEN_WARN("Received malformed package! 
Saving payload to '{}'", BadPackagePath); + + WriteFile(BadPackagePath, Payload); + + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"); + } + + if (!IsValid) + { + // TODO: emit diagnostics identifying missing chunks + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Missing chunk reference"); + } + + CbObject Core = Package.GetObject(); + + if (!Core["key"sv]) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "No oplog entry key specified"); + } + + // Write core to oplog + + const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Package); + + if (OpLsn == ProjectStore::Oplog::kInvalidOp) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + ZEN_DEBUG("'{}/{}' op #{} ({}) - '{}'", ProjectId, OplogId, OpLsn, NiceBytes(Payload.Size()), Core["key"sv].AsString()); + + HttpReq.WriteResponse(HttpResponseCode::Created); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{op}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const std::string& ProjectId = Req.GetCapture(1); + const std::string& OplogId = Req.GetCapture(2); + const std::string& OpIdString = Req.GetCapture(3); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog& Oplog = *FoundLog; + + if (const std::optional<int32_t> OpId = ParseInt<uint32_t>(OpIdString)) + { + if (std::optional<CbObject> MaybeOp = Oplog.GetOpByIndex(OpId.value())) + { + CbObject& Op = MaybeOp.value(); + if (Req.ServerRequest().AcceptContentType() == ZenContentType::kCbPackage) + { + CbPackage Package; + Package.SetObject(Op); + + 
Op.IterateAttachments([&](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + IoBuffer Payload = m_CidStore.FindChunkByCid(AttachmentHash); + + // We force this for now as content type is not consistently tracked (will + // be fixed in CidStore refactor) + Payload.SetContentType(ZenContentType::kCompressedBinary); + + if (Payload) + { + switch (Payload.GetContentType()) + { + case ZenContentType::kCbObject: + if (CbObject Object = LoadCompactBinaryObject(Payload)) + { + Package.AddAttachment(CbAttachment(Object)); + } + else + { + // Error - malformed object + + ZEN_WARN("malformed object returned for {}", AttachmentHash); + } + break; + + case ZenContentType::kCompressedBinary: + if (CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Payload))) + { + Package.AddAttachment(CbAttachment(Compressed, AttachmentHash)); + } + else + { + // Error - not compressed! + + ZEN_WARN("invalid compressed binary returned for {}", AttachmentHash); + } + break; + + default: + Package.AddAttachment(CbAttachment(SharedBuffer(Payload))); + break; + } + } + }); + + return HttpReq.WriteResponse(HttpResponseCode::Accepted, Package); + } + else + { + // Client cannot accept a package, so we only send the core object + return HttpReq.WriteResponse(HttpResponseCode::Accepted, Op); + } + } + } + + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}", + [this](HttpRouterRequest& Req) { + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + + if (!Project) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("project {} not found", ProjectId)); + } + + switch (Req.ServerRequest().RequestVerb()) + { + case HttpVerb::kGet: + { + ProjectStore::Oplog* OplogIt = 
Project->OpenOplog(OplogId); + + if (!OplogIt) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("oplog {} not found in project {}", OplogId, ProjectId)); + } + + ProjectStore::Oplog& Log = *OplogIt; + + CbObjectWriter Cb; + Cb << "id"sv << Log.OplogId() << "project"sv << Project->Identifier << "tempdir"sv << Log.TempPath().c_str() + << "markerpath"sv << Log.MarkerPath().c_str() << "totalsize"sv << Log.TotalSize() << "opcount" + << Log.OplogCount() << "expired"sv << Log.IsExpired(); + + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cb.Save()); + } + break; + + case HttpVerb::kPost: + { + std::filesystem::path OplogMarkerPath; + if (CbObject Params = Req.ServerRequest().ReadPayloadObject()) + { + OplogMarkerPath = Params["gcpath"sv].AsString(); + } + + ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId); + + if (!OplogIt) + { + if (!Project->NewOplog(OplogId, OplogMarkerPath)) + { + // TODO: indicate why the operation failed! 
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::InternalServerError); + } + + ZEN_INFO("established oplog '{}/{}', gc marker file at '{}'", ProjectId, OplogId, OplogMarkerPath); + + return Req.ServerRequest().WriteResponse(HttpResponseCode::Created); + } + + // I guess this should ultimately be used to execute RPCs but for now, it + // does absolutely nothing + + return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + break; + + case HttpVerb::kDelete: + { + ZEN_INFO("deleting oplog '{}/{}'", ProjectId, OplogId); + + Project->DeleteOplog(OplogId); + + return Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + } + break; + + default: + break; + } + }, + HttpVerb::kPost | HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/entries", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Response; + + if (FoundLog->OplogCount() > 0) + { + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + + if (auto OpKey = Params.GetValue("opkey"); !OpKey.empty()) + { + Oid OpKeyId = OpKeyStringAsOId(OpKey); + std::optional<CbObject> Op = FoundLog->GetOpByKey(OpKeyId); + + if (Op.has_value()) + { + Response << "entry"sv << Op.value(); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + } + else + { + Response.BeginArray("entries"sv); + + FoundLog->IterateOplog([&Response](CbObject Op) { Response << Op; }); + + Response.EndArray(); + } + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); 
+ }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}", + [this](HttpRouterRequest& Req) { + const std::string ProjectId = Req.GetCapture(1); + + switch (Req.ServerRequest().RequestVerb()) + { + case HttpVerb::kPost: + { + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + CbObject Params = LoadCompactBinaryObject(Payload); + std::string_view Id = Params["id"sv].AsString(); + std::string_view Root = Params["root"sv].AsString(); + std::string_view EngineRoot = Params["engine"sv].AsString(); + std::string_view ProjectRoot = Params["project"sv].AsString(); + std::string_view ProjectFilePath = Params["projectfile"sv].AsString(); + + const std::filesystem::path BasePath = m_ProjectStore->BasePath() / ProjectId; + m_ProjectStore->NewProject(BasePath, ProjectId, Root, EngineRoot, ProjectRoot, ProjectFilePath); + + ZEN_INFO("established project - {} (id: '{}', roots: '{}', '{}', '{}', '{}'{})", + ProjectId, + Id, + Root, + EngineRoot, + ProjectRoot, + ProjectFilePath, + ProjectFilePath.empty() ? 
", project will not be GCd due to empty project file path" : ""); + + Req.ServerRequest().WriteResponse(HttpResponseCode::Created); + } + break; + + case HttpVerb::kGet: + { + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + + if (!Project) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("project {} not found", ProjectId)); + } + + std::vector<std::string> OpLogs = Project->ScanForOplogs(); + + CbObjectWriter Response; + Response << "id"sv << Project->Identifier; + Response << "root"sv << PathToUtf8(Project->RootDir); + Response << "engine"sv << PathToUtf8(Project->EngineRootDir); + Response << "project"sv << PathToUtf8(Project->ProjectRootDir); + Response << "projectfile"sv << PathToUtf8(Project->ProjectFilePath); + + Response.BeginArray("oplogs"sv); + for (const std::string& OplogId : OpLogs) + { + Response.BeginObject(); + Response << "id"sv << OplogId; + Response.EndObject(); + } + Response.EndArray(); // oplogs + + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save()); + } + break; + + case HttpVerb::kDelete: + { + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + + if (!Project) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("project {} not found", ProjectId)); + } + + ZEN_INFO("deleting project '{}'", ProjectId); + if (!m_ProjectStore->DeleteProject(ProjectId)) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::Locked, + HttpContentType::kText, + fmt::format("project {} is in use", ProjectId)); + } + + return Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kDelete); + + // Push a oplog container + m_Router.RegisterRoute( + "{project}/oplog/{log}/save", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = 
Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + if (HttpReq.RequestContentType() != HttpContentType::kCbObject) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid content type"); + } + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + + CbObject Response; + std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->WriteOplog(ProjectId, OplogId, std::move(Payload), Response); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kPost); + + // Pull a oplog container + m_Router.RegisterRoute( + "{project}/oplog/{log}/load", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + if (HttpReq.AcceptContentType() != HttpContentType::kCbObject) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid accept content type"); + } + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + + CbObject Response; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->ReadOplog(ProjectId, OplogId, Req.ServerRequest().GetQueryParams(), Response); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet); + + // Do an rpc style operation on project/oplog + m_Router.RegisterRoute( + "{project}/oplog/{log}/rpc", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + 
const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + + m_ProjectStore->Rpc(HttpReq, ProjectId, OplogId, std::move(Payload), m_AuthMgr); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "details\\$", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) { + Project.IterateOplogs([&](ProjectStore::Oplog& Oplog) { + Oplog.IterateOplogWithKey( + [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) { + CSVWriteOp(m_CidStore, + Project.Identifier, + Oplog.OplogId(), + Details, + AttachmentDetails, + LSN, + Key, + Op, + CSVWriter); + }); + }); + }); + + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("projects"); + { + m_ProjectStore->DiscoverProjects(); + + m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) { + std::vector<std::string> OpLogs = Project.ScanForOplogs(); + CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo); + }); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "details\\$/{project}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + + HttpServerRequest::QueryParams Params = 
HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId); + if (!FoundProject) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + ProjectStore::Project& Project = *FoundProject.Get(); + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + FoundProject->IterateOplogs([&](ProjectStore::Oplog& Oplog) { + Oplog.IterateOplogWithKey([this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, + const Oid& Key, + CbObject Op) { + CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter); + }); + }); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + std::vector<std::string> OpLogs = FoundProject->ScanForOplogs(); + Cbo.BeginArray("projects"); + { + CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "details\\$/{project}/{log}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId); + if 
(!FoundProject) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Project& Project = *FoundProject.Get(); + ProjectStore::Oplog& Oplog = *FoundLog; + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + Oplog.IterateOplogWithKey( + [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) { + CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter); + }); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("oplogs"); + { + CbWriteOplog(m_CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "details\\$/{project}/{log}/{chunk}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& ChunkId = Req.GetCapture(3); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId); + if (!FoundProject) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + if 
(ChunkId.size() != 2 * sizeof(Oid::OidBits)) + { + return HttpReq.WriteResponse( + HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)); + } + + const Oid ObjId = Oid::FromHexString(ChunkId); + ProjectStore::Project& Project = *FoundProject.Get(); + ProjectStore::Oplog& Oplog = *FoundLog; + + int LSN = Oplog.GetOpIndexByKey(ObjId); + if (LSN == -1) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + std::optional<CbObject> Op = Oplog.GetOpByIndex(LSN); + if (!Op.has_value()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, ObjId, Op.value(), CSVWriter); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("ops"); + { + CbWriteOp(m_CidStore, Details, OpDetails, AttachmentDetails, LSN, ObjId, Op.value(), Cbo); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); +} + +HttpProjectService::~HttpProjectService() +{ + m_StatsService.UnregisterHandler("prj", *this); +} + +const char* +HttpProjectService::BaseUri() const +{ + return "/prj/"; +} + +void +HttpProjectService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq) +{ + const GcStorageSize StoreSize = m_ProjectStore->StorageSize(); + const CidStoreSize CidSize = m_CidStore.TotalSize(); + + CbObjectWriter Cbo; + Cbo.BeginObject("store"); + { + Cbo.BeginObject("size"); + { + Cbo << "disk" << StoreSize.DiskSize; + Cbo << "memory" << 
StoreSize.MemorySize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + Cbo.BeginObject("cid"); + { + Cbo.BeginObject("size"); + { + Cbo << "tiny" << CidSize.TinySize; + Cbo << "small" << CidSize.SmallSize; + Cbo << "large" << CidSize.LargeSize; + Cbo << "total" << CidSize.TotalSize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace testutils { + using namespace std::literals; + + std::string OidAsString(const Oid& Id) + { + StringBuilder<25> OidStringBuilder; + Id.ToString(OidStringBuilder); + return OidStringBuilder.ToString(); + } + + CbPackage CreateOplogPackage(const Oid& Id, const std::span<const std::pair<Oid, CompressedBuffer>>& Attachments) + { + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(Id); + if (!Attachments.empty()) + { + Object.BeginArray("bulkdata"); + for (const auto& Attachment : Attachments) + { + CbAttachment Attach(Attachment.second, Attachment.second.DecodeRawHash()); + Object.BeginObject(); + Object << "id"sv << Attachment.first; + Object << "type"sv + << "Standard"sv; + Object << "data"sv << Attach; + Object.EndObject(); + + Package.AddAttachment(Attach); + } + Object.EndArray(); + } + Package.SetObject(Object.Save()); + return Package; + }; + + std::vector<std::pair<Oid, CompressedBuffer>> CreateAttachments(const std::span<const size_t>& Sizes) + { + std::vector<std::pair<Oid, CompressedBuffer>> Result; + Result.reserve(Sizes.size()); + for (size_t Size : Sizes) + { + std::vector<uint8_t> Data; + Data.resize(Size); + uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data()); + for (size_t Idx = 0; Idx < Size / 2; ++Idx) + { + DataPtr[Idx] = static_cast<uint16_t>(Idx % 0xffffu); + } + if (Size & 1) + { + Data[Size - 1] = static_cast<uint8_t>((Size - 1) & 0xff); + } + CompressedBuffer Compressed = 
CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size())); + Result.emplace_back(std::pair<Oid, CompressedBuffer>(Oid::NewOid(), Compressed)); + } + return Result; + } + + uint64 GetCompressedOffset(const CompressedBuffer& Buffer, uint64 RawOffset) + { + if (RawOffset > 0) + { + uint64 BlockSize = 0; + OodleCompressor Compressor; + OodleCompressionLevel CompressionLevel; + if (!Buffer.TryGetCompressParameters(Compressor, CompressionLevel, BlockSize)) + { + return 0; + } + return BlockSize > 0 ? RawOffset % BlockSize : 0; + } + return 0; + } + +} // namespace testutils + +TEST_CASE("project.store.create") +{ + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::string_view ProjectName("proj1"sv); + std::filesystem::path BasePath = TempDir.Path() / "projectstore"; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + + Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / ProjectName, + ProjectName, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + CHECK(ProjectStore.DeleteProject(ProjectName)); + CHECK(!Project->Exists(BasePath)); +} + +TEST_CASE("project.store.lifetimes") +{ + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + 
std::filesystem::path BasePath = TempDir.Path() / "projectstore"; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + + Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + ProjectStore::Oplog* Oplog = Project->NewOplog("oplog1", {}); + CHECK(Oplog != nullptr); + + std::filesystem::path DeletePath; + CHECK(Project->PrepareForDelete(DeletePath)); + CHECK(!DeletePath.empty()); + CHECK(Project->OpenOplog("oplog1") == nullptr); + // Oplog is now invalid, but pointer can still be accessed since we store old oplog pointers + CHECK(Oplog->OplogCount() == 0); + // Project is still valid since we have a Ref to it + CHECK(Project->Identifier == "proj1"sv); +} + +TEST_CASE("project.store.gc") +{ + using namespace std::literals; + using namespace testutils; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::filesystem::path BasePath = TempDir.Path() / "projectstore"; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + + std::filesystem::path Project1RootDir = TempDir.Path() / "game1"; + std::filesystem::path Project1FilePath = TempDir.Path() / "game1" / "game.uproject"; + { + CreateDirectories(Project1FilePath.parent_path()); + BasicFile ProjectFile; + ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate); + } + + 
std::filesystem::path Project2RootDir = TempDir.Path() / "game2"; + std::filesystem::path Project2FilePath = TempDir.Path() / "game2" / "game.uproject"; + { + CreateDirectories(Project2FilePath.parent_path()); + BasicFile ProjectFile; + ProjectFile.Open(Project2FilePath, BasicFile::Mode::kTruncate); + } + + { + Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + Project1RootDir.string(), + Project1FilePath.string())); + ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1", {}); + CHECK(Oplog != nullptr); + + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {})); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{77}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{7123, 583, 690, 99}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{55, 122}))); + } + + { + Ref<ProjectStore::Project> Project2(ProjectStore.NewProject(BasePath / "proj2"sv, + "proj2"sv, + RootDir.string(), + EngineRootDir.string(), + Project2RootDir.string(), + Project2FilePath.string())); + ProjectStore::Oplog* Oplog = Project2->NewOplog("oplog1", {}); + CHECK(Oplog != nullptr); + + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {})); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{177}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{9123, 383, 590, 96}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{535, 221}))); + } + + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + ProjectStore.GatherReferences(GcCtx); + size_t RefCount = 0; + GcCtx.IterateCids([&RefCount](const IoHash&) { 
RefCount++; }); + CHECK(RefCount == 14); + ProjectStore.CollectGarbage(GcCtx); + CHECK(ProjectStore.OpenProject("proj1"sv)); + CHECK(ProjectStore.OpenProject("proj2"sv)); + } + + std::filesystem::remove(Project1FilePath); + + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + ProjectStore.GatherReferences(GcCtx); + size_t RefCount = 0; + GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; }); + CHECK(RefCount == 7); + ProjectStore.CollectGarbage(GcCtx); + CHECK(!ProjectStore.OpenProject("proj1"sv)); + CHECK(ProjectStore.OpenProject("proj2"sv)); + } +} + +TEST_CASE("project.store.partial.read") +{ + using namespace std::literals; + using namespace testutils; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas"sv, .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::filesystem::path BasePath = TempDir.Path() / "projectstore"sv; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"sv; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"sv; + + std::filesystem::path Project1RootDir = TempDir.Path() / "game1"sv; + std::filesystem::path Project1FilePath = TempDir.Path() / "game1"sv / "game.uproject"sv; + { + CreateDirectories(Project1FilePath.parent_path()); + BasicFile ProjectFile; + ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate); + } + + std::vector<Oid> OpIds; + OpIds.insert(OpIds.end(), {Oid::NewOid(), Oid::NewOid(), Oid::NewOid(), Oid::NewOid()}); + std::unordered_map<Oid, std::vector<std::pair<Oid, CompressedBuffer>>, Oid::Hasher> Attachments; + { + Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + Project1RootDir.string(), + Project1FilePath.string())); + ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1"sv, {}); 
+ CHECK(Oplog != nullptr); + Attachments[OpIds[0]] = {}; + Attachments[OpIds[1]] = CreateAttachments(std::initializer_list<size_t>{77}); + Attachments[OpIds[2]] = CreateAttachments(std::initializer_list<size_t>{7123, 9583, 690, 99}); + Attachments[OpIds[3]] = CreateAttachments(std::initializer_list<size_t>{55, 122}); + for (auto It : Attachments) + { + Oplog->AppendNewOplogEntry(CreateOplogPackage(It.first, It.second)); + } + } + { + IoBuffer Chunk; + CHECK(ProjectStore + .GetChunk("proj1"sv, + "oplog1"sv, + Attachments[OpIds[1]][0].second.DecodeRawHash().ToHexString(), + HttpContentType::kCompressedBinary, + Chunk) + .first == HttpResponseCode::OK); + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Attachment = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize); + CHECK(RawSize == Attachments[OpIds[1]][0].second.DecodeRawSize()); + } + + IoBuffer ChunkResult; + CHECK(ProjectStore + .GetChunkRange("proj1"sv, + "oplog1"sv, + OidAsString(Attachments[OpIds[2]][1].first), + 0, + ~0ull, + HttpContentType::kCompressedBinary, + ChunkResult) + .first == HttpResponseCode::OK); + CHECK(ChunkResult); + CHECK(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult)).DecodeRawSize() == + Attachments[OpIds[2]][1].second.DecodeRawSize()); + + IoBuffer PartialChunkResult; + CHECK(ProjectStore + .GetChunkRange("proj1"sv, + "oplog1"sv, + OidAsString(Attachments[OpIds[2]][1].first), + 5, + 1773, + HttpContentType::kCompressedBinary, + PartialChunkResult) + .first == HttpResponseCode::OK); + CHECK(PartialChunkResult); + IoHash PartialRawHash; + uint64_t PartialRawSize; + CompressedBuffer PartialCompressedResult = + CompressedBuffer::FromCompressed(SharedBuffer(PartialChunkResult), PartialRawHash, PartialRawSize); + CHECK(PartialRawSize >= 1773); + + uint64_t RawOffsetInPartialCompressed = GetCompressedOffset(PartialCompressedResult, 5); + SharedBuffer PartialDecompressed = PartialCompressedResult.Decompress(RawOffsetInPartialCompressed); + 
SharedBuffer FullDecompressed = Attachments[OpIds[2]][1].second.Decompress(); + const uint8_t* FullDataPtr = &(reinterpret_cast<const uint8_t*>(FullDecompressed.GetView().GetData())[5]); + const uint8_t* PartialDataPtr = reinterpret_cast<const uint8_t*>(PartialDecompressed.GetView().GetData()); + CHECK(FullDataPtr[0] == PartialDataPtr[0]); +} + +TEST_CASE("project.store.block") +{ + using namespace std::literals; + using namespace testutils; + + std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489, + 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251, + 491, 5464, 4607, 8135, 3767, 4045, 4415, 5007, 8876, 6761, 3359, 8526, 4097, 4855, 8225}); + + std::vector<std::pair<Oid, CompressedBuffer>> AttachmentsWithId = CreateAttachments(AttachmentSizes); + std::vector<SharedBuffer> Chunks; + Chunks.reserve(AttachmentSizes.size()); + for (const auto& It : AttachmentsWithId) + { + Chunks.push_back(It.second.GetCompressed().Flatten()); + } + CompressedBuffer Block = GenerateBlock(std::move(Chunks)); + IoBuffer BlockBuffer = Block.GetCompressed().Flatten().AsIoBuffer(); + CHECK(IterateBlock(std::move(BlockBuffer), [](CompressedBuffer&&, const IoHash&) {})); +} + +#endif + +void +prj_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenserver/projectstore/projectstore.h b/src/zenserver/projectstore/projectstore.h new file mode 100644 index 000000000..e4f664b85 --- /dev/null +++ b/src/zenserver/projectstore/projectstore.h @@ -0,0 +1,372 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/uid.h> +#include <zencore/xxhash.h> +#include <zenhttp/httpserver.h> +#include <zenstore/gc.h> + +#include "monitoring/httpstats.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CbPackage; +class CidStore; +class AuthMgr; +class ScrubContext; + +struct OplogEntry +{ + uint32_t OpLsn; + uint32_t OpCoreOffset; // note: Multiple of alignment! + uint32_t OpCoreSize; + uint32_t OpCoreHash; // Used as checksum + XXH3_128 OpKeyHash; // XXH128_canonical_t + + inline Oid OpKeyAsOId() const + { + Oid Id; + memcpy(Id.OidBits, &OpKeyHash, sizeof Id.OidBits); + return Id; + } +}; + +struct OplogEntryAddress +{ + uint64_t Offset; + uint64_t Size; +}; + +static_assert(IsPow2(sizeof(OplogEntry))); + +/** Project Store + + A project store consists of a number of Projects. + + Each project contains a number of oplogs (short for "operation log"). UE uses + one oplog per target platform to store the output of the cook process. + + An oplog consists of a sequence of "op" entries. Each entry is a structured object + containing references to attachments. Attachments are typically the serialized + package data split into separate chunks for bulk data, exports and header + information. 
+ */ +class ProjectStore : public RefCounted, public GcStorage, public GcContributor +{ + struct OplogStorage; + +public: + ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc); + ~ProjectStore(); + + struct Project; + + struct Oplog + { + Oplog(std::string_view Id, + Project* Project, + CidStore& Store, + std::filesystem::path BasePath, + const std::filesystem::path& MarkerPath); + ~Oplog(); + + [[nodiscard]] static bool ExistsAt(std::filesystem::path BasePath); + + void Read(); + void Write(); + + void IterateFileMap(std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn); + void IterateOplog(std::function<void(CbObject)>&& Fn); + void IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Fn); + std::optional<CbObject> GetOpByKey(const Oid& Key); + std::optional<CbObject> GetOpByIndex(int Index); + int GetOpIndexByKey(const Oid& Key); + + IoBuffer FindChunk(Oid ChunkId); + + inline static const uint32_t kInvalidOp = ~0u; + + /** Persist a new oplog entry + * + * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected + */ + uint32_t AppendNewOplogEntry(CbPackage Op); + + uint32_t AppendNewOplogEntry(CbObject Core); + + enum UpdateType + { + kUpdateNewEntry, + kUpdateReplay + }; + + const std::string& OplogId() const { return m_OplogId; } + + const std::filesystem::path& TempPath() const { return m_TempPath; } + const std::filesystem::path& MarkerPath() const { return m_MarkerPath; } + + spdlog::logger& Log() { return m_OuterProject->Log(); } + void Flush(); + void Scrub(ScrubContext& Ctx) const; + void GatherReferences(GcContext& GcCtx); + uint64_t TotalSize() const; + + std::size_t OplogCount() const + { + RwLock::SharedLockScope _(m_OplogLock); + return m_LatestOpMap.size(); + } + + bool IsExpired() const; + std::filesystem::path PrepareForDelete(bool MoveFolder); + + private: + struct FileMapEntry + { + std::string ServerPath; + 
std::string ClientPath; + }; + + template<class V> + using OidMap = tsl::robin_map<Oid, V, Oid::Hasher>; + + Project* m_OuterProject = nullptr; + CidStore& m_CidStore; + std::filesystem::path m_BasePath; + std::filesystem::path m_MarkerPath; + std::filesystem::path m_TempPath; + + mutable RwLock m_OplogLock; + OidMap<IoHash> m_ChunkMap; // output data chunk id -> CAS address + OidMap<IoHash> m_MetaMap; // meta chunk id -> CAS address + OidMap<FileMapEntry> m_FileMap; // file id -> file map entry + int32_t m_ManifestVersion; // File system manifest version + tsl::robin_map<int, OplogEntryAddress> m_OpAddressMap; // Index LSN -> op data in ops blob file + OidMap<int> m_LatestOpMap; // op key -> latest op LSN for key + + RefPtr<OplogStorage> m_Storage; + std::string m_OplogId; + + /** Scan oplog and register each entry, thus updating the in-memory tracking tables + */ + void ReplayLog(); + + struct OplogEntryMapping + { + struct Mapping + { + Oid Id; + IoHash Hash; + }; + struct FileMapping : public Mapping + { + std::string ServerPath; + std::string ClientPath; + }; + std::vector<Mapping> Chunks; + std::vector<Mapping> Meta; + std::vector<FileMapping> Files; + }; + + OplogEntryMapping GetMapping(CbObject Core); + + /** Update tracking metadata for a new oplog entry + * + * This is used during replay (and gets called as part of new op append) + * + * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected + */ + uint32_t RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock, + const OplogEntryMapping& OpMapping, + const OplogEntry& OpEntry, + UpdateType TypeOfUpdate); + + void AddFileMapping(const RwLock::ExclusiveLockScope& OplogLock, + Oid FileId, + IoHash Hash, + std::string_view ServerPath, + std::string_view ClientPath); + void AddChunkMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash); + void AddMetaMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash); + }; + + struct 
Project : public RefCounted + { + std::string Identifier; + std::filesystem::path RootDir; + std::string EngineRootDir; + std::string ProjectRootDir; + std::string ProjectFilePath; + + Oplog* NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath); + Oplog* OpenOplog(std::string_view OplogId); + void DeleteOplog(std::string_view OplogId); + void IterateOplogs(std::function<void(const Oplog&)>&& Fn) const; + void IterateOplogs(std::function<void(Oplog&)>&& Fn); + std::vector<std::string> ScanForOplogs() const; + bool IsExpired() const; + + Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath); + virtual ~Project(); + + void Read(); + void Write(); + [[nodiscard]] static bool Exists(std::filesystem::path BasePath); + void Flush(); + void Scrub(ScrubContext& Ctx); + spdlog::logger& Log(); + void GatherReferences(GcContext& GcCtx); + uint64_t TotalSize() const; + bool PrepareForDelete(std::filesystem::path& OutDeletePath); + + private: + ProjectStore* m_ProjectStore; + CidStore& m_CidStore; + mutable RwLock m_ProjectLock; + std::map<std::string, std::unique_ptr<Oplog>> m_Oplogs; + std::vector<std::unique_ptr<Oplog>> m_DeletedOplogs; + std::filesystem::path m_OplogStoragePath; + + std::filesystem::path BasePathForOplog(std::string_view OplogId); + }; + + // Oplog* OpenProjectOplog(std::string_view ProjectId, std::string_view OplogId); + + Ref<Project> OpenProject(std::string_view ProjectId); + Ref<Project> NewProject(std::filesystem::path BasePath, + std::string_view ProjectId, + std::string_view RootDir, + std::string_view EngineRootDir, + std::string_view ProjectRootDir, + std::string_view ProjectFilePath); + bool DeleteProject(std::string_view ProjectId); + bool Exists(std::string_view ProjectId); + void Flush(); + void Scrub(ScrubContext& Ctx); + void DiscoverProjects(); + void IterateProjects(std::function<void(Project& Prj)>&& Fn); + + spdlog::logger& Log() { return m_Log; } + const std::filesystem::path& BasePath() 
const { return m_ProjectBasePath; } + + virtual void GatherReferences(GcContext& GcCtx) override; + virtual void CollectGarbage(GcContext& GcCtx) override; + virtual GcStorageSize StorageSize() const override; + + CbArray GetProjectsList(); + std::pair<HttpResponseCode, std::string> GetProjectFiles(const std::string_view ProjectId, + const std::string_view OplogId, + bool FilterClient, + CbObject& OutPayload); + std::pair<HttpResponseCode, std::string> GetChunkInfo(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + CbObject& OutPayload); + std::pair<HttpResponseCode, std::string> GetChunkRange(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + uint64_t Offset, + uint64_t Size, + ZenContentType AcceptType, + IoBuffer& OutChunk); + std::pair<HttpResponseCode, std::string> GetChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType AcceptType, + IoBuffer& OutChunk); + + std::pair<HttpResponseCode, std::string> PutChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType ContentType, + IoBuffer&& Chunk); + + std::pair<HttpResponseCode, std::string> WriteOplog(const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload, + CbObject& OutResponse); + + std::pair<HttpResponseCode, std::string> ReadOplog(const std::string_view ProjectId, + const std::string_view OplogId, + const HttpServerRequest::QueryParams& Params, + CbObject& OutResponse); + + std::pair<HttpResponseCode, std::string> WriteBlock(const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload); + + void Rpc(HttpServerRequest& HttpReq, + const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload, + AuthMgr& AuthManager); + + std::pair<HttpResponseCode, std::string> Export(ProjectStore::Project& 
Project, + ProjectStore::Oplog& Oplog, + CbObjectView&& Params, + AuthMgr& AuthManager); + + std::pair<HttpResponseCode, std::string> Import(ProjectStore::Project& Project, + ProjectStore::Oplog& Oplog, + CbObjectView&& Params, + AuthMgr& AuthManager); + +private: + spdlog::logger& m_Log; + CidStore& m_CidStore; + std::filesystem::path m_ProjectBasePath; + mutable RwLock m_ProjectsLock; + std::map<std::string, Ref<Project>> m_Projects; + + std::filesystem::path BasePathForProject(std::string_view ProjectId); +}; + +////////////////////////////////////////////////////////////////////////// +// +// {project} a project identifier +// {target} a variation of the project, typically a build target +// {lsn} oplog entry sequence number +// +// /prj/{project} +// /prj/{project}/oplog/{target} +// /prj/{project}/oplog/{target}/{lsn} +// +// oplog entry +// +// id: {id} +// key: {} +// meta: {} +// data: [] +// refs: +// + +class HttpProjectService : public HttpService, public IHttpStatsProvider +{ +public: + HttpProjectService(CidStore& Store, ProjectStore* InProjectStore, HttpStatsService& StatsService, AuthMgr& AuthMgr); + ~HttpProjectService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + inline spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + CidStore& m_CidStore; + HttpRequestRouter m_Router; + Ref<ProjectStore> m_ProjectStore; + HttpStatsService& m_StatsService; + AuthMgr& m_AuthMgr; +}; + +void prj_forcelink(); + +} // namespace zen diff --git a/src/zenserver/projectstore/remoteprojectstore.cpp b/src/zenserver/projectstore/remoteprojectstore.cpp new file mode 100644 index 000000000..1e6ca51a1 --- /dev/null +++ b/src/zenserver/projectstore/remoteprojectstore.cpp @@ -0,0 +1,1036 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "remoteprojectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/workthreadpool.h> +#include <zenstore/cidstore.h> + +namespace zen { + +/* + OplogContainer + Binary("ops") // Compressed CompactBinary object to hide attachment references, also makes the oplog smaller + { + CbArray("ops") + { + CbObject Op + (CbFieldType::BinaryAttachment Attachments[]) + (OpData) + } + } + CbArray("blocks") + CbObject + CbFieldType::BinaryAttachment "rawhash" // Optional, only if we are creating blocks (Jupiter/File) + CbArray("chunks") + CbFieldType::Hash // Chunk hashes + CbArray("chunks") // Optional, only if we are not creating blocks (Zen) + CbFieldType::BinaryAttachment // Chunk attachment hashes + + CompressedBinary ChunkBlock + { + VarUInt ChunkCount + VarUInt ChunkSizes[ChunkCount] + uint8_t[chunksize])[ChunkCount] + } +*/ + +////////////////////////////// AsyncRemoteResult + +struct AsyncRemoteResult +{ + void SetError(int32_t ErrorCode, const std::string& ErrorReason, const std::string ErrorText) + { + int32_t Expected = 0; + if (m_ErrorCode.compare_exchange_weak(Expected, ErrorCode ? 
ErrorCode : -1)) + { + m_ErrorReason = ErrorReason; + m_ErrorText = ErrorText; + } + } + bool IsError() const { return m_ErrorCode.load() != 0; } + int GetError() const { return m_ErrorCode.load(); }; + const std::string& GetErrorReason() const { return m_ErrorReason; }; + const std::string& GetErrorText() const { return m_ErrorText; }; + RemoteProjectStore::Result ConvertResult(double ElapsedSeconds = 0.0) const + { + return RemoteProjectStore::Result{m_ErrorCode, ElapsedSeconds, m_ErrorReason, m_ErrorText}; + } + +private: + std::atomic<int32_t> m_ErrorCode = 0; + std::string m_ErrorReason; + std::string m_ErrorText; +}; + +bool +IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor) +{ + IoBuffer BlockPayload = CompressedBuffer::FromCompressedNoValidate(std::move(CompressedBlock)).Decompress().AsIoBuffer(); + + MemoryView BlockView = BlockPayload.GetView(); + const uint8_t* ReadPtr = reinterpret_cast<const uint8_t*>(BlockView.GetData()); + uint32_t NumberSize; + uint64_t ChunkCount = ReadVarUInt(ReadPtr, NumberSize); + ReadPtr += NumberSize; + std::vector<uint64_t> ChunkSizes; + ChunkSizes.reserve(ChunkCount); + while (ChunkCount--) + { + ChunkSizes.push_back(ReadVarUInt(ReadPtr, NumberSize)); + ReadPtr += NumberSize; + } + ptrdiff_t TempBufferLength = std::distance(reinterpret_cast<const uint8_t*>(BlockView.GetData()), ReadPtr); + ZEN_ASSERT(TempBufferLength > 0); + for (uint64_t ChunkSize : ChunkSizes) + { + IoBuffer Chunk(IoBuffer::Wrap, ReadPtr, ChunkSize); + IoHash AttachmentRawHash; + uint64_t AttachmentRawSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), AttachmentRawHash, AttachmentRawSize); + + if (!CompressedChunk) + { + ZEN_ERROR("Invalid chunk in block"); + return false; + } + Visitor(std::move(CompressedChunk), AttachmentRawHash); + ReadPtr += ChunkSize; + ZEN_ASSERT(ReadPtr <= BlockView.GetDataEnd()); + } + return true; +}; + 
+CompressedBuffer +GenerateBlock(std::vector<SharedBuffer>&& Chunks) +{ + size_t ChunkCount = Chunks.size(); + SharedBuffer SizeBuffer; + { + IoBuffer TempBuffer(ChunkCount * 9); + MutableMemoryView View = TempBuffer.GetMutableView(); + uint8_t* BufferStartPtr = reinterpret_cast<uint8_t*>(View.GetData()); + uint8_t* BufferEndPtr = BufferStartPtr; + BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(ChunkCount), BufferEndPtr); + auto It = Chunks.begin(); + while (It != Chunks.end()) + { + BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(It->GetSize()), BufferEndPtr); + It++; + } + ZEN_ASSERT(BufferEndPtr <= View.GetDataEnd()); + ptrdiff_t TempBufferLength = std::distance(BufferStartPtr, BufferEndPtr); + SizeBuffer = SharedBuffer(IoBuffer(TempBuffer, 0, gsl::narrow<size_t>(TempBufferLength))); + } + CompositeBuffer AllBuffers(std::move(SizeBuffer), CompositeBuffer(std::move(Chunks))); + + CompressedBuffer CompressedBlock = + CompressedBuffer::Compress(std::move(AllBuffers), OodleCompressor::Mermaid, OodleCompressionLevel::None); + + return CompressedBlock; +} + +struct Block +{ + IoHash BlockHash; + std::vector<IoHash> ChunksInBlock; +}; + +void +CreateBlock(WorkerThreadPool& WorkerPool, + Latch& OpSectionsLatch, + std::vector<SharedBuffer>&& ChunksInBlock, + RwLock& SectionsLock, + std::vector<Block>& Blocks, + size_t BlockIndex, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + AsyncRemoteResult& RemoteResult) +{ + OpSectionsLatch.AddCount(1); + WorkerPool.ScheduleWork( + [&Blocks, &SectionsLock, &OpSectionsLatch, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, &RemoteResult]() mutable { + auto _ = MakeGuard([&OpSectionsLatch] { OpSectionsLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + if (!Chunks.empty()) + { + CompressedBuffer CompressedBlock = GenerateBlock(std::move(Chunks)); // Move to callback and return IoHash + IoHash BlockHash = CompressedBlock.DecodeRawHash(); + 
AsyncOnBlock(std::move(CompressedBlock), BlockHash); + { + // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + RwLock::SharedLockScope __(SectionsLock); + Blocks[BlockIndex].BlockHash = BlockHash; + } + } + }); +} + +size_t +AddBlock(RwLock& BlocksLock, std::vector<Block>& Blocks) +{ + size_t BlockIndex; + { + RwLock::ExclusiveLockScope _(BlocksLock); + BlockIndex = Blocks.size(); + Blocks.resize(BlockIndex + 1); + } + return BlockIndex; +} + +CbObject +BuildContainer(CidStore& ChunkStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + WorkerThreadPool& WorkerPool, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + const std::function<void(const IoHash&)>& OnLargeAttachment, + const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks, + AsyncRemoteResult& RemoteResult) +{ + using namespace std::literals; + + std::unordered_set<IoHash, IoHash::Hasher> LargeChunkHashes; + CbObjectWriter SectionOpsWriter; + SectionOpsWriter.BeginArray("ops"sv); + + size_t OpCount = 0; + + CbObject OplogContainerObject; + { + RwLock BlocksLock; + std::vector<Block> Blocks; + CompressedBuffer OpsBuffer; + + Latch BlockCreateLatch(1); + + std::unordered_set<IoHash, IoHash::Hasher> BlockAttachmentHashes; + + size_t BlockSize = 0; + std::vector<SharedBuffer> ChunksInBlock; + + std::unordered_set<IoHash, IoHash::Hasher> Attachments; + Oplog.IterateOplog([&Attachments, &SectionOpsWriter, &OpCount](CbObject Op) { + Op.IterateAttachments([&](CbFieldView FieldView) { Attachments.insert(FieldView.AsAttachment()); }); + (SectionOpsWriter) << Op; + OpCount++; + }); + + for (const IoHash& AttachmentHash : Attachments) + { + IoBuffer Payload = ChunkStore.FindChunkByCid(AttachmentHash); + if (!Payload) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound), + fmt::format("Failed to find attachment {} for op", 
AttachmentHash), + {}); + ZEN_ERROR("Failed to build container ({}). Reason: '{}'", RemoteResult.GetError(), RemoteResult.GetErrorReason()); + return {}; + } + uint64_t PayloadSize = Payload.GetSize(); + if (PayloadSize > MaxChunkEmbedSize) + { + if (LargeChunkHashes.insert(AttachmentHash).second) + { + OnLargeAttachment(AttachmentHash); + } + continue; + } + + if (!BlockAttachmentHashes.insert(AttachmentHash).second) + { + continue; + } + + BlockSize += PayloadSize; + if (BuildBlocks) + { + ChunksInBlock.emplace_back(SharedBuffer(std::move(Payload))); + } + else + { + Payload = {}; + } + + if (BlockSize >= MaxBlockSize) + { + size_t BlockIndex = AddBlock(BlocksLock, Blocks); + if (BuildBlocks) + { + CreateBlock(WorkerPool, + BlockCreateLatch, + std::move(ChunksInBlock), + BlocksLock, + Blocks, + BlockIndex, + AsyncOnBlock, + RemoteResult); + } + else + { + OnBlockChunks(BlockAttachmentHashes); + } + { + // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + RwLock::SharedLockScope _(BlocksLock); + Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(), + BlockAttachmentHashes.begin(), + BlockAttachmentHashes.end()); + } + BlockAttachmentHashes.clear(); + ChunksInBlock.clear(); + BlockSize = 0; + } + } + if (BlockSize > 0) + { + size_t BlockIndex = AddBlock(BlocksLock, Blocks); + if (BuildBlocks) + { + CreateBlock(WorkerPool, + BlockCreateLatch, + std::move(ChunksInBlock), + BlocksLock, + Blocks, + BlockIndex, + AsyncOnBlock, + RemoteResult); + } + else + { + OnBlockChunks(BlockAttachmentHashes); + } + { + // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + RwLock::SharedLockScope _(BlocksLock); + Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(), + BlockAttachmentHashes.begin(), + BlockAttachmentHashes.end()); + } + BlockAttachmentHashes.clear(); + ChunksInBlock.clear(); + BlockSize = 0; + } + 
SectionOpsWriter.EndArray(); // "ops" + + CompressedBuffer CompressedOpsSection = CompressedBuffer::Compress(SectionOpsWriter.Save().GetBuffer()); + ZEN_DEBUG("Added oplog section {}, {}", CompressedOpsSection.DecodeRawHash(), NiceBytes(CompressedOpsSection.GetCompressedSize())); + + BlockCreateLatch.CountDown(); + while (!BlockCreateLatch.Wait(1000)) + { + ZEN_INFO("Creating blocks, {} remaining...", BlockCreateLatch.Remaining()); + } + + if (!RemoteResult.IsError()) + { + CbObjectWriter OplogContinerWriter; + RwLock::SharedLockScope _(BlocksLock); + OplogContinerWriter.AddBinary("ops"sv, CompressedOpsSection.GetCompressed().Flatten().AsIoBuffer()); + + OplogContinerWriter.BeginArray("blocks"sv); + { + for (const Block& B : Blocks) + { + ZEN_ASSERT(!B.ChunksInBlock.empty()); + if (BuildBlocks) + { + ZEN_ASSERT(B.BlockHash != IoHash::Zero); + + OplogContinerWriter.BeginObject(); + { + OplogContinerWriter.AddBinaryAttachment("rawhash"sv, B.BlockHash); + OplogContinerWriter.BeginArray("chunks"sv); + { + for (const IoHash& RawHash : B.ChunksInBlock) + { + OplogContinerWriter.AddHash(RawHash); + } + } + OplogContinerWriter.EndArray(); // "chunks" + } + OplogContinerWriter.EndObject(); + continue; + } + + ZEN_ASSERT(B.BlockHash == IoHash::Zero); + OplogContinerWriter.BeginObject(); + { + OplogContinerWriter.BeginArray("chunks"sv); + { + for (const IoHash& RawHash : B.ChunksInBlock) + { + OplogContinerWriter.AddBinaryAttachment(RawHash); + } + } + OplogContinerWriter.EndArray(); + } + OplogContinerWriter.EndObject(); + } + } + OplogContinerWriter.EndArray(); // "blocks"sv + + OplogContinerWriter.BeginArray("chunks"sv); + { + for (const IoHash& AttachmentHash : LargeChunkHashes) + { + OplogContinerWriter.AddBinaryAttachment(AttachmentHash); + } + } + OplogContinerWriter.EndArray(); // "chunks" + + OplogContainerObject = OplogContinerWriter.Save(); + } + } + return OplogContainerObject; +} + +RemoteProjectStore::LoadContainerResult +BuildContainer(CidStore& ChunkStore, + 
ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + const std::function<void(const IoHash&)>& OnLargeAttachment, + const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks) +{ + // We are creating a worker thread pool here since we are uploading a lot of attachments in one go and we dont want to keep a + // WorkerThreadPool alive + size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u); + WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount)); + + AsyncRemoteResult RemoteResult; + CbObject ContainerObject = BuildContainer(ChunkStore, + Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + BuildBlocks, + WorkerPool, + AsyncOnBlock, + OnLargeAttachment, + OnBlockChunks, + RemoteResult); + return RemoteProjectStore::LoadContainerResult{RemoteResult.ConvertResult(), ContainerObject}; +} + +RemoteProjectStore::Result +SaveOplog(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + bool UseTempBlocks, + bool ForceUpload) +{ + using namespace std::literals; + + Stopwatch Timer; + + // We are creating a worker thread pool here since we are uploading a lot of attachments in one go + // Doing upload is a rare and transient occation so we don't want to keep a WorkerThreadPool alive. 
+ size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u); + WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount)); + + std::filesystem::path AttachmentTempPath; + if (UseTempBlocks) + { + AttachmentTempPath = Oplog.TempPath(); + AttachmentTempPath.append(".pending"); + CreateDirectories(AttachmentTempPath); + } + + AsyncRemoteResult RemoteResult; + RwLock AttachmentsLock; + std::unordered_set<IoHash, IoHash::Hasher> LargeAttachments; + std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> CreatedBlocks; + + auto MakeTempBlock = [AttachmentTempPath, &RemoteResult, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock, + const IoHash& BlockHash) { + std::filesystem::path BlockPath = AttachmentTempPath; + BlockPath.append(BlockHash.ToHexString()); + if (!std::filesystem::exists(BlockPath)) + { + IoBuffer BlockBuffer; + try + { + BasicFile BlockFile; + BlockFile.Open(BlockPath, BasicFile::Mode::kTruncateDelete); + uint64_t Offset = 0; + for (const SharedBuffer& Buffer : CompressedBlock.GetCompressed().GetSegments()) + { + BlockFile.Write(Buffer.GetView(), Offset); + Offset += Buffer.GetSize(); + } + void* FileHandle = BlockFile.Detach(); + BlockBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, Offset); + } + catch (std::exception& Ex) + { + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + Ex.what(), + "Unable to create temp block file"); + return; + } + + BlockBuffer.MarkAsDeleteOnClose(); + { + RwLock::ExclusiveLockScope __(AttachmentsLock); + CreatedBlocks.insert({BlockHash, std::move(BlockBuffer)}); + } + ZEN_DEBUG("Saved temp block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize())); + } + }; + + auto UploadBlock = [&RemoteStore, &RemoteResult](CompressedBuffer&& CompressedBlock, const IoHash& BlockHash) { + RemoteProjectStore::SaveAttachmentResult Result = RemoteStore.SaveAttachment(CompressedBlock.GetCompressed(), BlockHash); + if (Result.ErrorCode) + { + 
RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachment ({}). Reason: '{}'", RemoteResult.GetErrorReason(), RemoteResult.GetError()); + return; + } + ZEN_DEBUG("Saved block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize())); + }; + + std::vector<std::vector<IoHash>> BlockChunks; + auto OnBlockChunks = [&BlockChunks](const std::unordered_set<IoHash, IoHash::Hasher>& Chunks) { + BlockChunks.push_back({Chunks.begin(), Chunks.end()}); + ZEN_DEBUG("Found {} block chunks", Chunks.size()); + }; + + auto OnLargeAttachment = [&AttachmentsLock, &LargeAttachments](const IoHash& AttachmentHash) { + { + RwLock::ExclusiveLockScope _(AttachmentsLock); + LargeAttachments.insert(AttachmentHash); + } + ZEN_DEBUG("Found attachment {}", AttachmentHash); + }; + + std::function<void(CompressedBuffer&&, const IoHash&)> OnBlock; + if (UseTempBlocks) + { + OnBlock = MakeTempBlock; + } + else + { + OnBlock = UploadBlock; + } + + CbObject OplogContainerObject = BuildContainer(ChunkStore, + Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + BuildBlocks, + WorkerPool, + OnBlock, + OnLargeAttachment, + OnBlockChunks, + RemoteResult); + + if (!RemoteResult.IsError()) + { + uint64_t ChunkCount = OplogContainerObject["chunks"sv].AsArrayView().Num(); + uint64_t BlockCount = OplogContainerObject["blocks"sv].AsArrayView().Num(); + ZEN_INFO("Saving oplog container with {} attachments and {} blocks...", ChunkCount, BlockCount); + RemoteProjectStore::SaveResult ContainerSaveResult = RemoteStore.SaveContainer(OplogContainerObject.GetBuffer().AsIoBuffer()); + if (ContainerSaveResult.ErrorCode) + { + RemoteResult.SetError(ContainerSaveResult.ErrorCode, ContainerSaveResult.Reason, "Failed to save oplog container"); + ZEN_ERROR("Failed to save oplog container ({}). 
Reason: '{}'", RemoteResult.GetErrorReason(), RemoteResult.GetError()); + } + ZEN_DEBUG("Saved container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerSaveResult.ElapsedSeconds * 1000))); + if (!ContainerSaveResult.Needs.empty()) + { + ZEN_INFO("Filtering needed attachments..."); + std::vector<IoHash> NeededLargeAttachments; + std::unordered_set<IoHash, IoHash::Hasher> NeededOtherAttachments; + NeededLargeAttachments.reserve(LargeAttachments.size()); + NeededOtherAttachments.reserve(CreatedBlocks.size()); + if (ForceUpload) + { + NeededLargeAttachments.insert(NeededLargeAttachments.end(), LargeAttachments.begin(), LargeAttachments.end()); + } + else + { + for (const IoHash& RawHash : ContainerSaveResult.Needs) + { + if (LargeAttachments.contains(RawHash)) + { + NeededLargeAttachments.push_back(RawHash); + continue; + } + NeededOtherAttachments.insert(RawHash); + } + } + + Latch SaveAttachmentsLatch(1); + if (!NeededLargeAttachments.empty()) + { + ZEN_INFO("Saving large attachments..."); + for (const IoHash& RawHash : NeededLargeAttachments) + { + if (RemoteResult.IsError()) + { + break; + } + SaveAttachmentsLatch.AddCount(1); + WorkerPool.ScheduleWork([&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, RawHash, &CreatedBlocks]() { + auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + + IoBuffer Payload; + if (auto It = CreatedBlocks.find(RawHash); It != CreatedBlocks.end()) + { + Payload = std::move(It->second); + } + else + { + Payload = ChunkStore.FindChunkByCid(RawHash); + } + if (!Payload) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound), + fmt::format("Failed to find attachment {}", RawHash), + {}); + ZEN_ERROR("Failed to build container ({}). 
Reason: '{}'", + RemoteResult.GetErrorReason(), + RemoteResult.GetError()); + return; + } + + RemoteProjectStore::SaveAttachmentResult Result = + RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'", + RawHash, + NiceBytes(Payload.GetSize()), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Saved attachment {}, {} in {}", + RawHash, + NiceBytes(Payload.GetSize()), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + return; + }); + } + } + + if (!CreatedBlocks.empty()) + { + ZEN_INFO("Saving created block attachments..."); + for (auto& It : CreatedBlocks) + { + if (RemoteResult.IsError()) + { + break; + } + const IoHash& RawHash = It.first; + if (ForceUpload || NeededOtherAttachments.contains(RawHash)) + { + IoBuffer Payload = It.second; + ZEN_ASSERT(Payload); + SaveAttachmentsLatch.AddCount(1); + WorkerPool.ScheduleWork( + [&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, Payload = std::move(Payload), RawHash]() { + auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + + RemoteProjectStore::SaveAttachmentResult Result = + RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachment '{}', {} ({}). 
Reason: '{}'", + RawHash, + NiceBytes(Payload.GetSize()), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + + ZEN_DEBUG("Saved attachment {}, {} in {}", + RawHash, + NiceBytes(Payload.GetSize()), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + return; + }); + } + It.second = {}; + } + } + + if (!BlockChunks.empty()) + { + ZEN_INFO("Saving chunk block attachments..."); + for (const std::vector<IoHash>& Chunks : BlockChunks) + { + if (RemoteResult.IsError()) + { + break; + } + std::vector<IoHash> NeededChunks; + if (ForceUpload) + { + NeededChunks = Chunks; + } + else + { + NeededChunks.reserve(Chunks.size()); + for (const IoHash& Chunk : Chunks) + { + if (NeededOtherAttachments.contains(Chunk)) + { + NeededChunks.push_back(Chunk); + } + } + if (NeededChunks.empty()) + { + continue; + } + } + SaveAttachmentsLatch.AddCount(1); + WorkerPool.ScheduleWork([&RemoteStore, + &ChunkStore, + &SaveAttachmentsLatch, + &RemoteResult, + &Chunks, + NeededChunks = std::move(NeededChunks), + ForceUpload]() { + auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); }); + std::vector<SharedBuffer> ChunkBuffers; + ChunkBuffers.reserve(NeededChunks.size()); + for (const IoHash& Chunk : NeededChunks) + { + IoBuffer ChunkPayload = ChunkStore.FindChunkByCid(Chunk); + if (!ChunkPayload) + { + RemoteResult.SetError(static_cast<int32_t>(HttpResponseCode::NotFound), + fmt::format("Missing chunk {}"sv, Chunk), + fmt::format("Unable to fetch attachment {} required by the oplog"sv, Chunk)); + ChunkBuffers.clear(); + break; + } + ChunkBuffers.emplace_back(SharedBuffer(std::move(ChunkPayload))); + } + RemoteProjectStore::SaveAttachmentsResult Result = RemoteStore.SaveAttachments(ChunkBuffers); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachments with {} chunks ({}). 
Reason: '{}'", + Chunks.size(), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Saved {} bulk attachments in {}", + Chunks.size(), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + }); + } + } + SaveAttachmentsLatch.CountDown(); + while (!SaveAttachmentsLatch.Wait(1000)) + { + ZEN_INFO("Saving attachments, {} remaining...", SaveAttachmentsLatch.Remaining()); + } + SaveAttachmentsLatch.Wait(); + } + + if (!RemoteResult.IsError()) + { + ZEN_INFO("Finalizing oplog container..."); + RemoteProjectStore::Result ContainerFinalizeResult = RemoteStore.FinalizeContainer(ContainerSaveResult.RawHash); + if (ContainerFinalizeResult.ErrorCode) + { + RemoteResult.SetError(ContainerFinalizeResult.ErrorCode, ContainerFinalizeResult.Reason, ContainerFinalizeResult.Text); + ZEN_ERROR("Failed to finalize oplog container {} ({}). Reason: '{}'", + ContainerSaveResult.RawHash, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + } + ZEN_DEBUG("Finalized container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerFinalizeResult.ElapsedSeconds * 1000))); + } + } + + RemoteProjectStore::Result Result = RemoteResult.ConvertResult(); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + ZEN_INFO("Saved oplog {} in {}", + RemoteResult.GetError() == 0 ? 
"SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + return Result; +}; + +RemoteProjectStore::Result +SaveOplogContainer(ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment) +{ + using namespace std::literals; + + Stopwatch Timer; + + CbArrayView LargeChunksArray = ContainerObject["chunks"sv].AsArrayView(); + for (CbFieldView LargeChunksField : LargeChunksArray) + { + IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment(); + if (HasAttachment(AttachmentHash)) + { + continue; + } + OnNeedAttachment(AttachmentHash); + }; + + CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView(); + for (CbFieldView BlockField : BlocksArray) + { + CbObjectView BlockView = BlockField.AsObjectView(); + IoHash BlockHash = BlockView["rawhash"sv].AsBinaryAttachment(); + + CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView(); + if (BlockHash == IoHash::Zero) + { + std::vector<IoHash> NeededChunks; + NeededChunks.reserve(ChunksArray.GetSize()); + for (CbFieldView ChunkField : ChunksArray) + { + IoHash ChunkHash = ChunkField.AsBinaryAttachment(); + if (HasAttachment(ChunkHash)) + { + continue; + } + NeededChunks.emplace_back(ChunkHash); + } + + if (!NeededChunks.empty()) + { + OnNeedBlock(IoHash::Zero, std::move(NeededChunks)); + } + continue; + } + + for (CbFieldView ChunkField : ChunksArray) + { + IoHash ChunkHash = ChunkField.AsHash(); + if (HasAttachment(ChunkHash)) + { + continue; + } + + OnNeedBlock(BlockHash, {}); + break; + } + }; + + MemoryView OpsSection = ContainerObject["ops"sv].AsBinaryView(); + IoBuffer OpsBuffer(IoBuffer::Wrap, OpsSection.GetData(), OpsSection.GetSize()); + IoBuffer SectionPayload = 
CompressedBuffer::FromCompressedNoValidate(std::move(OpsBuffer)).Decompress().AsIoBuffer(); + + CbObject SectionObject = LoadCompactBinaryObject(SectionPayload); + if (!SectionObject) + { + ZEN_ERROR("Failed to save oplog container. Reason: '{}'", "Section has unexpected data type"); + return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest), + Timer.GetElapsedTimeMs() / 1000.500, + "Section has unexpected data type", + "Failed to save oplog container"}; + } + + CbArrayView OpsArray = SectionObject["ops"sv].AsArrayView(); + for (CbFieldView OpEntry : OpsArray) + { + CbObjectView Core = OpEntry.AsObjectView(); + BinaryWriter Writer; + Core.CopyTo(Writer); + MemoryView OpView = Writer.GetView(); + IoBuffer OpBuffer(IoBuffer::Wrap, OpView.GetData(), OpView.GetSize()); + CbObject Op(SharedBuffer(OpBuffer), CbFieldType::HasFieldType); + const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Op); + if (OpLsn == ProjectStore::Oplog::kInvalidOp) + { + return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest), + Timer.GetElapsedTimeMs() / 1000.500, + "Failed saving op", + "Failed to save oplog container"}; + } + ZEN_DEBUG("oplog entry #{}", OpLsn); + } + return RemoteProjectStore::Result{.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500}; +} + +RemoteProjectStore::Result +LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload) +{ + using namespace std::literals; + + Stopwatch Timer; + + // We are creating a worker thread pool here since we are download a lot of attachments in one go and we dont want to keep a + // WorkerThreadPool alive + size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u); + WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount)); + + std::unordered_set<IoHash, IoHash::Hasher> Attachments; + std::vector<std::vector<IoHash>> ChunksInBlocks; + + RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer(); + if 
(LoadContainerResult.ErrorCode) + { + ZEN_WARN("Failed to load oplog container, reason: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode); + return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode, + .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500, + .Reason = LoadContainerResult.Reason, + .Text = LoadContainerResult.Text}; + } + ZEN_DEBUG("Loaded container in {}", NiceTimeSpanMs(static_cast<uint64_t>(LoadContainerResult.ElapsedSeconds * 1000))); + + AsyncRemoteResult RemoteResult; + Latch AttachmentsWorkLatch(1); + + auto HasAttachment = [&ChunkStore, ForceDownload](const IoHash& RawHash) { + return !ForceDownload && ChunkStore.ContainsChunk(RawHash); + }; + auto OnNeedBlock = [&RemoteStore, &ChunkStore, &WorkerPool, &ChunksInBlocks, &AttachmentsWorkLatch, &RemoteResult]( + const IoHash& BlockHash, + std::vector<IoHash>&& Chunks) { + if (BlockHash == IoHash::Zero) + { + AttachmentsWorkLatch.AddCount(1); + WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &AttachmentsWorkLatch, &RemoteResult, Chunks = std::move(Chunks)]() { + auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + + RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to attachments with {} chunks ({}). 
Reason: '{}'", + Chunks.size(), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Loaded {} bulk attachments in {}", + Chunks.size(), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + for (const auto& It : Result.Chunks) + { + ChunkStore.AddChunk(It.second.GetCompressed().Flatten().AsIoBuffer(), It.first, CidStore::InsertMode::kCopyOnly); + } + }); + return; + } + AttachmentsWorkLatch.AddCount(1); + WorkerPool.ScheduleWork([&AttachmentsWorkLatch, &ChunkStore, &RemoteStore, BlockHash, &RemoteResult]() { + auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash); + if (BlockResult.ErrorCode) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + ZEN_ERROR("Failed to load oplog container, missing attachment {} ({}). Reason: '{}'", + BlockHash, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Loaded block attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000))); + + if (!IterateBlock(std::move(BlockResult.Bytes), [&ChunkStore](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + ChunkStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), AttachmentRawHash); + })) + { + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + fmt::format("Invalid format for block {}", BlockHash), + {}); + ZEN_ERROR("Failed to load oplog container, attachment {} has invalid format ({}). 
Reason: '{}'", + BlockHash, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + }); + }; + + auto OnNeedAttachment = + [&RemoteStore, &ChunkStore, &WorkerPool, &AttachmentsWorkLatch, &RemoteResult, &Attachments](const IoHash& RawHash) { + if (!Attachments.insert(RawHash).second) + { + return; + } + + AttachmentsWorkLatch.AddCount(1); + WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &RemoteResult, &AttachmentsWorkLatch, RawHash]() { + auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash); + if (AttachmentResult.ErrorCode) + { + RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text); + ZEN_ERROR("Failed to download attachment {}, reason: '{}', error code: {}", + RawHash, + AttachmentResult.Reason, + AttachmentResult.ErrorCode); + return; + } + ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000))); + ChunkStore.AddChunk(AttachmentResult.Bytes, RawHash); + }); + }; + + RemoteProjectStore::Result Result = + SaveOplogContainer(Oplog, LoadContainerResult.ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment); + + AttachmentsWorkLatch.CountDown(); + while (!AttachmentsWorkLatch.Wait(1000)) + { + ZEN_INFO("Loading attachments, {} remaining...", AttachmentsWorkLatch.Remaining()); + } + AttachmentsWorkLatch.Wait(); + if (Result.ErrorCode == 0) + { + Result = RemoteResult.ConvertResult(); + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + + ZEN_INFO("Loaded oplog {} in {}", + RemoteResult.GetError() == 0 ? 
"SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0))); + + return Result; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/remoteprojectstore.h b/src/zenserver/projectstore/remoteprojectstore.h new file mode 100644 index 000000000..dcabaedd4 --- /dev/null +++ b/src/zenserver/projectstore/remoteprojectstore.h @@ -0,0 +1,111 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "projectstore.h" + +#include <unordered_set> + +namespace zen { + +class CidStore; +class WorkerThreadPool; + +class RemoteProjectStore +{ +public: + struct Result + { + int32_t ErrorCode{}; + double ElapsedSeconds{}; + std::string Reason; + std::string Text; + }; + + struct SaveResult : public Result + { + std::unordered_set<IoHash, IoHash::Hasher> Needs; + IoHash RawHash; + }; + + struct SaveAttachmentResult : public Result + { + }; + + struct SaveAttachmentsResult : public Result + { + }; + + struct LoadAttachmentResult : public Result + { + IoBuffer Bytes; + }; + + struct LoadContainerResult : public Result + { + CbObject ContainerObject; + }; + + struct LoadAttachmentsResult : public Result + { + std::vector<std::pair<IoHash, CompressedBuffer>> Chunks; + }; + + struct RemoteStoreInfo + { + bool CreateBlocks; + bool UseTempBlockFiles; + std::string Description; + }; + + virtual ~RemoteProjectStore() {} + + virtual RemoteStoreInfo GetInfo() const = 0; + + virtual SaveResult SaveContainer(const IoBuffer& Payload) = 0; + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) = 0; + virtual Result FinalizeContainer(const IoHash& RawHash) = 0; + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Payloads) = 0; + + virtual LoadContainerResult LoadContainer() = 0; + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0; + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0; +}; + +struct 
RemoteStoreOptions +{ + size_t MaxBlockSize = 128u * 1024u * 1024u; + size_t MaxChunkEmbedSize = 1024u * 1024u; +}; + +RemoteProjectStore::LoadContainerResult BuildContainer( + CidStore& ChunkStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + const std::function<void(const IoHash&)>& OnLargeAttachment, + const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks); + +RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment); + +RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + bool UseTempBlocks, + bool ForceUpload); + +RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload); + +CompressedBuffer GenerateBlock(std::vector<SharedBuffer>&& Chunks); +bool IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor); + +} // namespace zen diff --git a/src/zenserver/projectstore/zenremoteprojectstore.cpp b/src/zenserver/projectstore/zenremoteprojectstore.cpp new file mode 100644 index 000000000..6ff471ae5 --- /dev/null +++ b/src/zenserver/projectstore/zenremoteprojectstore.cpp @@ -0,0 +1,341 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "zenremoteprojectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compositebuffer.h> +#include <zencore/fmtutils.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zenhttp/httpshared.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +class ZenRemoteStore : public RemoteProjectStore +{ +public: + ZenRemoteStore(std::string_view HostAddress, + std::string_view Project, + std::string_view Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize) + : m_HostAddress(HostAddress) + , m_ProjectStoreUrl(fmt::format("{}/prj"sv, m_HostAddress)) + , m_Project(Project) + , m_Oplog(Oplog) + , m_MaxBlockSize(MaxBlockSize) + , m_MaxChunkEmbedSize(MaxChunkEmbedSize) + { + } + + virtual RemoteStoreInfo GetInfo() const override + { + return {.CreateBlocks = false, .UseTempBlockFiles = false, .Description = fmt::format("[zen] {}"sv, m_HostAddress)}; + } + + virtual SaveResult SaveContainer(const IoBuffer& Payload) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/save"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + Session->SetUrl({SaveRequest}); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))}}); + MemoryView Data(Payload.GetView()); + Session->SetBody({reinterpret_cast<const char*>(Data.GetData()), Data.GetSize()}); + cpr::Response Response = Session->Post(); + SaveResult Result = SaveResult{ConvertResult(Response)}; + + if (Result.ErrorCode) + { + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + IoBuffer ResponsePayload(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + CbObject 
ResponseObject = LoadCompactBinaryObject(ResponsePayload); + if (!ResponseObject) + { + Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv, + m_ProjectStoreUrl, + m_Project, + m_Oplog); + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + CbArrayView NeedsArray = ResponseObject["need"sv].AsArrayView(); + for (CbFieldView FieldView : NeedsArray) + { + IoHash ChunkHash = FieldView.AsHash(); + Result.Needs.insert(ChunkHash); + } + + Result.RawHash = IoHash::HashBuffer(Payload); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash); + Session->SetUrl({SaveRequest}); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}}); + uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + cpr::Response Response = Session->Post(); + SaveAttachmentResult Result = SaveAttachmentResult{ConvertResult(Response)}; + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentsResult SaveAttachments(const 
std::vector<SharedBuffer>& Chunks) override + { + Stopwatch Timer; + + CbPackage RequestPackage; + { + CbObjectWriter RequestWriter; + RequestWriter.AddString("method"sv, "putchunks"sv); + RequestWriter.BeginArray("chunks"sv); + { + for (const SharedBuffer& Chunk : Chunks) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(Chunk, RawHash, RawSize); + RequestWriter.AddHash(RawHash); + RequestPackage.AddAttachment(CbAttachment(Compressed, RawHash)); + } + } + RequestWriter.EndArray(); // "chunks" + RequestPackage.SetObject(RequestWriter.Save()); + } + CompositeBuffer Payload = FormatPackageMessageBuffer(RequestPackage, FormatFlags::kDefault); + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + Session->SetUrl({SaveRequest}); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}}); + + uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + cpr::Response Response = Session->Post(); + SaveAttachmentsResult Result = SaveAttachmentsResult{ConvertResult(Response)}; + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { 
ReleaseSession(std::move(Session)); }); + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + + CbObject Request; + { + CbObjectWriter RequestWriter; + RequestWriter.AddString("method"sv, "getchunks"sv); + RequestWriter.BeginArray("chunks"sv); + { + for (const IoHash& RawHash : RawHashes) + { + RequestWriter.AddHash(RawHash); + } + } + RequestWriter.EndArray(); // "chunks" + Request = RequestWriter.Save(); + } + IoBuffer Payload = Request.GetBuffer().AsIoBuffer(); + Session->SetBody(cpr::Body{(const char*)Payload.GetData(), Payload.GetSize()}); + Session->SetUrl(SaveRequest); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))}, + {"Accept", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}}); + + cpr::Response Response = Session->Post(); + LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)}; + if (!Result.ErrorCode) + { + CbPackage Package = ParsePackageMessage(IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size())); + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + Result.Chunks.reserve(Attachments.size()); + for (const CbAttachment& Attachment : Attachments) + { + Result.Chunks.emplace_back( + std::pair<IoHash, CompressedBuffer>{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()}); + } + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + }; + + virtual Result FinalizeContainer(const IoHash&) override + { + Stopwatch Timer; + + RwLock::ExclusiveLockScope _(SessionsLock); + Sessions.clear(); + return {.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500}; + } + + virtual LoadContainerResult LoadContainer() override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + std::string SaveRequest = 
fmt::format("{}/{}/oplog/{}/load"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + Session->SetUrl(SaveRequest); + Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCbObject))}}); + Session->SetParameters( + {{"maxblocksize", fmt::format("{}", m_MaxBlockSize)}, {"maxchunkembedsize", fmt::format("{}", m_MaxChunkEmbedSize)}}); + cpr::Response Response = Session->Get(); + + LoadContainerResult Result = LoadContainerResult{ConvertResult(Response)}; + if (!Result.ErrorCode) + { + Result.ContainerObject = LoadCompactBinaryObject(IoBuffer(IoBuffer::Clone, Response.text.data(), Response.text.size())); + if (!Result.ContainerObject) + { + Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv, + m_ProjectStoreUrl, + m_Project, + m_Oplog); + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + + std::string LoadRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash); + Session->SetUrl({LoadRequest}); + Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}}); + cpr::Response Response = Session->Get(); + LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)}; + if (!Result.ErrorCode) + { + Result.Bytes = IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()); + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + +private: + std::unique_ptr<cpr::Session> AllocateSession() + { + 
RwLock::ExclusiveLockScope _(SessionsLock); + if (Sessions.empty()) + { + Sessions.emplace_back(std::make_unique<cpr::Session>()); + } + std::unique_ptr<cpr::Session> Session = std::move(Sessions.back()); + Sessions.pop_back(); + return Session; + } + + void ReleaseSession(std::unique_ptr<cpr::Session>&& Session) + { + RwLock::ExclusiveLockScope _(SessionsLock); + Sessions.emplace_back(std::move(Session)); + } + + static Result ConvertResult(const cpr::Response& Response) + { + std::string Text; + std::string Reason = Response.reason; + int32_t ErrorCode = 0; + if (Response.error.code != cpr::ErrorCode::OK) + { + ErrorCode = static_cast<int32_t>(Response.error.code); + if (!Response.error.message.empty()) + { + Reason = Response.error.message; + } + } + else if (!IsHttpSuccessCode(Response.status_code)) + { + ErrorCode = static_cast<int32_t>(Response.status_code); + + if (auto It = Response.header.find("Content-Type"); It != Response.header.end()) + { + zen::HttpContentType ContentType = zen::ParseContentType(It->second); + if (ContentType == zen::HttpContentType::kText) + { + Text = Response.text; + } + } + + Reason = fmt::format("{}"sv, Response.status_code); + } + return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.elapsed, .Reason = Reason, .Text = Text}; + } + + RwLock SessionsLock; + std::vector<std::unique_ptr<cpr::Session>> Sessions; + + const std::string m_HostAddress; + const std::string m_ProjectStoreUrl; + const std::string m_Project; + const std::string m_Oplog; + const size_t m_MaxBlockSize; + const size_t m_MaxChunkEmbedSize; +}; + +std::unique_ptr<RemoteProjectStore> +CreateZenRemoteStore(const ZenRemoteStoreOptions& Options) +{ + std::string Url = Options.Url; + if (Url.find("://"sv) == std::string::npos) + { + // Assume https URL + Url = fmt::format("http://{}"sv, Url); + } + std::unique_ptr<RemoteProjectStore> RemoteStore = + std::make_unique<ZenRemoteStore>(Url, Options.ProjectId, Options.OplogId, Options.MaxBlockSize, 
Options.MaxChunkEmbedSize); + return RemoteStore; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/zenremoteprojectstore.h b/src/zenserver/projectstore/zenremoteprojectstore.h new file mode 100644 index 000000000..ef9dcad8c --- /dev/null +++ b/src/zenserver/projectstore/zenremoteprojectstore.h @@ -0,0 +1,18 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "remoteprojectstore.h" + +namespace zen { + +struct ZenRemoteStoreOptions : RemoteStoreOptions +{ + std::string Url; + std::string ProjectId; + std::string OplogId; +}; + +std::unique_ptr<RemoteProjectStore> CreateZenRemoteStore(const ZenRemoteStoreOptions& Options); + +} // namespace zen diff --git a/src/zenserver/resource.h b/src/zenserver/resource.h new file mode 100644 index 000000000..f2e3b471b --- /dev/null +++ b/src/zenserver/resource.h @@ -0,0 +1,18 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +//{{NO_DEPENDENCIES}} +// Microsoft Visual C++ generated include file. +// Used by zenserver.rc +// +#define IDI_ICON1 101 + +// Next default values for new objects +// +#ifdef APSTUDIO_INVOKED +# ifndef APSTUDIO_READONLY_SYMBOLS +# define _APS_NEXT_RESOURCE_VALUE 102 +# define _APS_NEXT_COMMAND_VALUE 40001 +# define _APS_NEXT_CONTROL_VALUE 1001 +# define _APS_NEXT_SYMED_VALUE 101 +# endif +#endif diff --git a/src/zenserver/targetver.h b/src/zenserver/targetver.h new file mode 100644 index 000000000..d432d6993 --- /dev/null +++ b/src/zenserver/targetver.h @@ -0,0 +1,10 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// Including SDKDDKVer.h defines the highest available Windows platform. + +// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and +// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. 
+ +#include <SDKDDKVer.h> diff --git a/src/zenserver/testing/httptest.cpp b/src/zenserver/testing/httptest.cpp new file mode 100644 index 000000000..349a95ab3 --- /dev/null +++ b/src/zenserver/testing/httptest.cpp @@ -0,0 +1,207 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptest.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/timer.h> + +namespace zen { + +using namespace std::literals; + +HttpTestingService::HttpTestingService() +{ + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_slow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Stopwatch Timer; + Sleep(1000); + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kText, + fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs()))); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_veryslow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Stopwatch Timer; + Sleep(60000); + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kText, + fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs()))); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_throw", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) { throw std::runtime_error("intentional error"); }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_noresponse", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) {}); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "metrics", + [this](HttpRouterRequest& Req) { + metrics::OperationTiming::Scope _(m_TimingStats); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + }, + 
HttpVerb::kGet); + + m_Router.RegisterRoute( + "get_metrics", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + EmitSnapshot("requests", m_TimingStats, Cbo); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "json", + [this](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Obj.AddInteger("counter", ++m_Counter); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + IoBuffer Body = Req.ServerRequest().ReadPayload(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "package", + [](HttpRouterRequest& Req) { + CbPackage Pkg = Req.ServerRequest().ReadPayloadPackage(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Pkg); + }, + HttpVerb::kPost); +} + +HttpTestingService::~HttpTestingService() +{ +} + +const char* +HttpTestingService::BaseUri() const +{ + return "/testing/"; +} + +void +HttpTestingService::HandleRequest(HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +Ref<IHttpPackageHandler> +HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) +{ + RwLock::ExclusiveLockScope _(m_RwLock); + + const uint32_t RequestId = HttpServiceRequest.RequestId(); + + if (auto It = m_HandlerMap.find(RequestId); It != m_HandlerMap.end()) + { + Ref<HttpTestingService::PackageHandler> Handler = std::move(It->second); + + m_HandlerMap.erase(It); + + return Handler; + } + + auto InsertResult = m_HandlerMap.insert({RequestId, Ref<PackageHandler>()}); + + _.ReleaseNow(); + + return (InsertResult.first->second = Ref<PackageHandler>(new PackageHandler(*this, RequestId))); +} + +void +HttpTestingService::RegisterHandlers(WebSocketServer& Server) +{ + Server.RegisterRequestHandler("SayHello"sv, *this); +} + +bool 
+HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg) +{ + CbObjectView Request = RequestMsg.Body().GetObject(); + + std::string_view Method = Request["Method"].AsString(); + + if (Method != "SayHello"sv) + { + return false; + } + + CbObjectWriter Response; + Response.AddString("Result"sv, "Hello Friend!!"); + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId()); + ResponseMsg.SetSocketId(RequestMsg.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SocketServer().SendResponse(std::move(ResponseMsg)); + + return true; +} + +////////////////////////////////////////////////////////////////////////// + +HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId) +{ +} + +HttpTestingService::PackageHandler::~PackageHandler() +{ +} + +void +HttpTestingService::PackageHandler::FilterOffer(std::vector<IoHash>& OfferCids) +{ + ZEN_UNUSED(OfferCids); + // No-op + return; +} +void +HttpTestingService::PackageHandler::OnRequestBegin() +{ +} + +void +HttpTestingService::PackageHandler::OnRequestComplete() +{ +} + +IoBuffer +HttpTestingService::PackageHandler::CreateTarget(const IoHash& Cid, uint64_t StorageSize) +{ + ZEN_UNUSED(Cid); + return IoBuffer{StorageSize}; +} + +} // namespace zen diff --git a/src/zenserver/testing/httptest.h b/src/zenserver/testing/httptest.h new file mode 100644 index 000000000..57d2d63f3 --- /dev/null +++ b/src/zenserver/testing/httptest.h @@ -0,0 +1,55 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/logging.h> +#include <zencore/stats.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> + +#include <atomic> + +namespace zen { + +/** + * Test service to facilitate testing the HTTP framework and client interactions + */ +class HttpTestingService : public HttpService, public WebSocketService +{ +public: + HttpTestingService(); + ~HttpTestingService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override; + + class PackageHandler : public IHttpPackageHandler + { + public: + PackageHandler(HttpTestingService& Svc, uint32_t RequestId); + ~PackageHandler(); + + virtual void FilterOffer(std::vector<IoHash>& OfferCids) override; + virtual void OnRequestBegin() override; + virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) override; + virtual void OnRequestComplete() override; + + private: + HttpTestingService& m_Svc; + uint32_t m_RequestId; + }; + +private: + virtual void RegisterHandlers(WebSocketServer& Server) override; + virtual bool HandleRequest(const WebSocketMessage& Request) override; + + HttpRequestRouter m_Router; + std::atomic<uint32_t> m_Counter{0}; + metrics::OperationTiming m_TimingStats; + + RwLock m_RwLock; + std::unordered_map<uint32_t, Ref<PackageHandler>> m_HandlerMap; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/hordecompute.cpp b/src/zenserver/upstream/hordecompute.cpp new file mode 100644 index 000000000..64d9fff72 --- /dev/null +++ b/src/zenserver/upstream/hordecompute.cpp @@ -0,0 +1,1457 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "upstreamapply.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "jupiter.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compactbinaryvalidation.h> +# include <zencore/fmtutils.h> +# include <zencore/session.h> +# include <zencore/stream.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/workthreadpool.h> + +# include <zenstore/cidstore.h> + +# include <auth/authmgr.h> +# include <upstream/upstreamcache.h> + +# include "cache/structuredcachestore.h" +# include "diag/logging.h" + +# include <fmt/format.h> + +# include <algorithm> +# include <atomic> +# include <set> +# include <stack> + +namespace zen { + +using namespace std::literals; + +static const IoBuffer EmptyBuffer; +static const IoHash EmptyBufferId = IoHash::HashBuffer(EmptyBuffer); + +namespace detail { + + class HordeUpstreamApplyEndpoint final : public UpstreamApplyEndpoint + { + public: + HordeUpstreamApplyEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr) + : m_Log(logging::Get("upstream-apply")) + , m_CidStore(CidStore) + , m_AuthMgr(Mgr) + { + m_DisplayName = fmt::format("{} - '{}'+'{}'", ComputeOptions.Name, ComputeOptions.ServiceUrl, StorageOptions.ServiceUrl); + m_ChannelId = fmt::format("zen-{}", zen::GetSessionIdString()); + + { + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (ComputeAuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = ComputeAuthConfig.OAuthUrl, + .ClientId = ComputeAuthConfig.OAuthClientId, + .ClientSecret = ComputeAuthConfig.OAuthClientSecret}); + } + else if (ComputeAuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + 
CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(ComputeAuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(ComputeAuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(ComputeOptions, std::move(TokenProvider)); + } + + { + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (StorageAuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = StorageAuthConfig.OAuthUrl, + .ClientId = StorageAuthConfig.OAuthClientId, + .ClientSecret = StorageAuthConfig.OAuthClientSecret}); + } + else if (StorageAuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(StorageAuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(StorageAuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_StorageClient = new CloudCacheClient(StorageOptions, std::move(TokenProvider)); + } + } + + virtual ~HordeUpstreamApplyEndpoint() = default; + + virtual UpstreamEndpointHealth Initialize() override { return CheckHealth(); } + + virtual bool IsHealthy() const override { return m_HealthOk.load(); } + + virtual UpstreamEndpointHealth CheckHealth() override + { + try + { + CloudCacheSession 
Session(m_Client); + CloudCacheResult Result = Session.Authenticate(); + + m_HealthOk = Result.ErrorCode == 0; + + return {.Reason = std::move(Result.Reason), .Ok = Result.Success}; + } + catch (std::exception& Err) + { + return {.Reason = Err.what(), .Ok = false}; + } + } + + virtual std::string_view DisplayName() const override { return m_DisplayName; } + + virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) override + { + PostUpstreamApplyResult ApplyResult{}; + ApplyResult.Timepoints.merge(ApplyRecord.Timepoints); + + try + { + UpstreamData UpstreamData; + if (!ProcessApplyKey(ApplyRecord, UpstreamData)) + { + return {.Error{.ErrorCode = -1, .Reason = "Failed to generate task data"}}; + } + + { + ApplyResult.Timepoints["zen-storage-build-ref"] = DateTime::NowTicks(); + + bool AlreadyQueued; + { + std::scoped_lock Lock(m_TaskMutex); + AlreadyQueued = m_PendingTasks.contains(UpstreamData.TaskId); + } + if (AlreadyQueued) + { + // Pending task is already queued, return success + ApplyResult.Success = true; + return ApplyResult; + } + m_PendingTasks[UpstreamData.TaskId] = std::move(ApplyRecord); + } + + CloudCacheSession ComputeSession(m_Client); + CloudCacheSession StorageSession(m_StorageClient); + + { + CloudCacheResult Result = BatchPutBlobsIfMissing(StorageSession, UpstreamData.Blobs, UpstreamData.CasIds); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-blobs"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? 
std::move(Result.Reason) : "Failed to upload blobs"}; + return ApplyResult; + } + UpstreamData.Blobs.clear(); + UpstreamData.CasIds.clear(); + } + + { + CloudCacheResult Result = BatchPutCompressedBlobsIfMissing(StorageSession, UpstreamData.Cids); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-compressed-blobs"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = { + .ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload compressed blobs"}; + return ApplyResult; + } + UpstreamData.Cids.clear(); + } + + { + CloudCacheResult Result = BatchPutObjectsIfMissing(StorageSession, UpstreamData.Objects); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-objects"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? 
std::move(Result.Reason) : "Failed to upload objects"}; + return ApplyResult; + } + } + + { + PutRefResult RefResult = StorageSession.PutRef(StorageSession.Client().DefaultBlobStoreNamespace(), + "requests"sv, + UpstreamData.TaskId, + UpstreamData.Objects[UpstreamData.TaskId].GetBuffer().AsIoBuffer(), + ZenContentType::kCbObject); + Log().debug("Put ref {} Need={} Bytes={} Duration={}s Result={}", + UpstreamData.TaskId, + RefResult.Needs.size(), + RefResult.Bytes, + RefResult.ElapsedSeconds, + RefResult.Success); + ApplyResult.Bytes += RefResult.Bytes; + ApplyResult.ElapsedSeconds += RefResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-put-ref"] = DateTime::NowTicks(); + + if (RefResult.Needs.size() > 0) + { + Log().error("Failed to add task ref {} due to {} missing blobs", UpstreamData.TaskId, RefResult.Needs.size()); + for (const auto& Hash : RefResult.Needs) + { + Log().debug("Task ref {} missing blob {}", UpstreamData.TaskId, Hash); + } + + ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode, + .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) + : "Failed to add task ref due to missing blob"}; + return ApplyResult; + } + + if (!RefResult.Success) + { + ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode, + .Reason = !RefResult.Reason.empty() ? 
std::move(RefResult.Reason) : "Failed to add task ref"}; + return ApplyResult; + } + UpstreamData.Objects.clear(); + } + + { + CbObjectWriter Writer; + Writer.AddString("c"sv, m_ChannelId); + Writer.AddObjectAttachment("r"sv, UpstreamData.RequirementsId); + Writer.BeginArray("t"sv); + Writer.AddObjectAttachment(UpstreamData.TaskId); + Writer.EndArray(); + CbObject TasksObject = Writer.Save(); + IoBuffer TasksData = TasksObject.GetBuffer().AsIoBuffer(); + + CloudCacheResult Result = ComputeSession.PostComputeTasks(TasksData); + Log().debug("Post compute task {} Bytes={} Duration={}s Result={}", + TasksObject.GetHash(), + Result.Bytes, + Result.ElapsedSeconds, + Result.Success); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-horde-post-task"] = DateTime::NowTicks(); + if (!Result.Success) + { + { + std::scoped_lock Lock(m_TaskMutex); + m_PendingTasks.erase(UpstreamData.TaskId); + } + + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? 
std::move(Result.Reason) : "Failed to post compute task"}; + return ApplyResult; + } + } + + Log().info("Task posted {}", UpstreamData.TaskId); + ApplyResult.Success = true; + return ApplyResult; + } + catch (std::exception& Err) + { + m_HealthOk = false; + return {.Error{.ErrorCode = -1, .Reason = Err.what()}}; + } + } + + [[nodiscard]] CloudCacheResult BatchPutBlobsIfMissing(CloudCacheSession& Session, + const std::map<IoHash, IoBuffer>& Blobs, + const std::set<IoHash>& CasIds) + { + if (Blobs.size() == 0 && CasIds.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing blobs + std::set<IoHash> Keys; + std::transform(Blobs.begin(), Blobs.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; }); + Keys.insert(CasIds.begin(), CasIds.end()); + + CloudCacheExistsResult ExistsResult = Session.BlobExists(Session.Client().DefaultBlobStoreNamespace(), Keys); + Log().debug("Queried {} missing blobs Need={} Duration={}s Result={}", + Keys.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? 
std::move(ExistsResult.Reason) : "Failed to check if blobs exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + IoBuffer DataBuffer; + if (Blobs.contains(Hash)) + { + DataBuffer = Blobs.at(Hash); + } + else + { + DataBuffer = m_CidStore.FindChunkByCid(Hash); + if (!DataBuffer) + { + Log().warn("Put blob FAILED, input chunk '{}' missing", Hash); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put blobs"}; + } + } + + CloudCacheResult Result = Session.PutBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer); + Log().debug("Put blob {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put blobs"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + [[nodiscard]] CloudCacheResult BatchPutCompressedBlobsIfMissing(CloudCacheSession& Session, const std::set<IoHash>& Cids) + { + if (Cids.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing compressed blobs + CloudCacheExistsResult ExistsResult = Session.CompressedBlobExists(Session.Client().DefaultBlobStoreNamespace(), Cids); + Log().debug("Queried {} missing compressed blobs Need={} Duration={}s Result={}", + Cids.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return { + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? 
std::move(ExistsResult.Reason) : "Failed to check if compressed blobs exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Hash); + if (!DataBuffer) + { + Log().warn("Put compressed blob FAILED, input CID chunk '{}' missing", Hash); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put compressed blobs"}; + } + + CloudCacheResult Result = Session.PutCompressedBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer); + Log().debug("Put compressed blob {} Bytes={} Duration={}s Result={}", + Hash, + Result.Bytes, + Result.ElapsedSeconds, + Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put compressed blobs"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + [[nodiscard]] CloudCacheResult BatchPutObjectsIfMissing(CloudCacheSession& Session, const std::map<IoHash, CbObject>& Objects) + { + if (Objects.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing objects + std::set<IoHash> Keys; + std::transform(Objects.begin(), Objects.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; }); + + CloudCacheExistsResult ExistsResult = Session.ObjectExists(Session.Client().DefaultBlobStoreNamespace(), Keys); + Log().debug("Queried {} missing objects Need={} Duration={}s Result={}", + Keys.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? 
ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if objects exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + CloudCacheResult Result = + Session.PutObject(Session.Client().DefaultBlobStoreNamespace(), Hash, Objects.at(Hash).GetBuffer().AsIoBuffer()); + Log().debug("Put object {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put objects"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + enum class ComputeTaskState : int32_t + { + Queued = 0, + Executing = 1, + Complete = 2, + }; + + enum class ComputeTaskOutcome : int32_t + { + Success = 0, + Failed = 1, + Cancelled = 2, + NoResult = 3, + Exipred = 4, + BlobNotFound = 5, + Exception = 6, + }; + + [[nodiscard]] static std::string_view ComputeTaskStateToString(const ComputeTaskState Outcome) + { + switch (Outcome) + { + case ComputeTaskState::Queued: + return "Queued"sv; + case ComputeTaskState::Executing: + return "Executing"sv; + case ComputeTaskState::Complete: + return "Complete"sv; + }; + return "Unknown"sv; + } + + [[nodiscard]] static std::string_view ComputeTaskOutcomeToString(const ComputeTaskOutcome Outcome) + { + switch (Outcome) + { + case ComputeTaskOutcome::Success: + return "Success"sv; + case ComputeTaskOutcome::Failed: + return "Failed"sv; + case ComputeTaskOutcome::Cancelled: + return "Cancelled"sv; + case ComputeTaskOutcome::NoResult: + return "NoResult"sv; + case ComputeTaskOutcome::Exipred: + return "Exipred"sv; + case ComputeTaskOutcome::BlobNotFound: + return "BlobNotFound"sv; + case ComputeTaskOutcome::Exception: + return 
"Exception"sv; + }; + return "Unknown"sv; + } + + virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) override + { + int64_t Bytes{}; + double ElapsedSeconds{}; + + { + std::scoped_lock Lock(m_TaskMutex); + if (m_PendingTasks.empty()) + { + if (m_CompletedTasks.empty()) + { + // Nothing to do. + return {.Success = true}; + } + + UpstreamApplyCompleted CompletedTasks; + std::swap(CompletedTasks, m_CompletedTasks); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true}; + } + } + + try + { + CloudCacheSession ComputeSession(m_Client); + + CloudCacheResult UpdatesResult = ComputeSession.GetComputeUpdates(m_ChannelId); + Log().debug("Get compute updates Bytes={} Duration={}s Result={}", + UpdatesResult.Bytes, + UpdatesResult.ElapsedSeconds, + UpdatesResult.Success); + Bytes += UpdatesResult.Bytes; + ElapsedSeconds += UpdatesResult.ElapsedSeconds; + if (!UpdatesResult.Success) + { + return {.Error{.ErrorCode = UpdatesResult.ErrorCode, .Reason = std::move(UpdatesResult.Reason)}, + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds}; + } + + if (!UpdatesResult.Success) + { + return {.Error{.ErrorCode = -1, .Reason = "Failed get task updates"}, .Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds}; + } + + CbObject TaskStatus = LoadCompactBinaryObject(std::move(UpdatesResult.Response)); + + for (auto& It : TaskStatus["u"sv]) + { + CbObjectView Status = It.AsObjectView(); + IoHash TaskId = Status["h"sv].AsHash(); + const ComputeTaskState State = (ComputeTaskState)Status["s"sv].AsInt32(); + const ComputeTaskOutcome Outcome = (ComputeTaskOutcome)Status["o"sv].AsInt32(); + + Log().info("Task {} State={}", TaskId, ComputeTaskStateToString(State)); + + // Only completed tasks need to be processed + if (State != ComputeTaskState::Complete) + { + continue; + } + + IoHash WorkerId{}; + IoHash ActionId{}; + UpstreamApplyType ApplyType{}; + + { + std::scoped_lock Lock(m_TaskMutex); + auto TaskIt = 
m_PendingTasks.find(TaskId); + if (TaskIt != m_PendingTasks.end()) + { + WorkerId = TaskIt->second.WorkerDescriptor.GetHash(); + ActionId = TaskIt->second.Action.GetHash(); + ApplyType = TaskIt->second.Type; + m_PendingTasks.erase(TaskIt); + } + } + + if (WorkerId == IoHash::Zero) + { + Log().warn("Task {} missing from pending tasks", TaskId); + continue; + } + + std::map<std::string, uint64_t> Timepoints; + ProcessQueueTimings(Status["qs"sv].AsObjectView(), Timepoints); + ProcessExecuteTimings(Status["es"sv].AsObjectView(), Timepoints); + + if (Outcome != ComputeTaskOutcome::Success) + { + const std::string_view Detail = Status["d"sv].AsString(); + { + std::scoped_lock Lock(m_TaskMutex); + m_CompletedTasks[WorkerId][ActionId] = { + .Error{.ErrorCode = -1, .Reason = fmt::format("Task {} {}", ComputeTaskOutcomeToString(Outcome), Detail)}, + .Timepoints = std::move(Timepoints)}; + } + continue; + } + + Timepoints["zen-complete-queue-added"] = DateTime::NowTicks(); + ThreadPool.ScheduleWork([this, + ApplyType, + ResultHash = Status["r"sv].AsHash(), + Timepoints = std::move(Timepoints), + TaskId = std::move(TaskId), + WorkerId = std::move(WorkerId), + ActionId = std::move(ActionId)]() mutable { + Timepoints["zen-complete-queue-dispatched"] = DateTime::NowTicks(); + GetUpstreamApplyResult Result = ProcessTaskStatus(ApplyType, ResultHash); + Timepoints["zen-complete-queue-complete"] = DateTime::NowTicks(); + Result.Timepoints.merge(Timepoints); + + Log().debug("Task Processed {} Files={} Attachments={} ExitCode={}", + TaskId, + Result.OutputFiles.size(), + Result.OutputPackage.GetAttachments().size(), + Result.Error.ErrorCode); + { + std::scoped_lock Lock(m_TaskMutex); + m_CompletedTasks[WorkerId][ActionId] = std::move(Result); + } + }); + } + + { + std::scoped_lock Lock(m_TaskMutex); + if (m_CompletedTasks.empty()) + { + // Nothing to do. 
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + UpstreamApplyCompleted CompletedTasks; + std::swap(CompletedTasks, m_CompletedTasks); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true}; + } + } + catch (std::exception& Err) + { + m_HealthOk = false; + return { + .Error{.ErrorCode = -1, .Reason = Err.what()}, + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + }; + } + } + + virtual UpstreamApplyEndpointStats& Stats() override { return m_Stats; } + + private: + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + CidStore& m_CidStore; + AuthMgr& m_AuthMgr; + std::string m_DisplayName; + RefPtr<CloudCacheClient> m_Client; + RefPtr<CloudCacheClient> m_StorageClient; + UpstreamApplyEndpointStats m_Stats; + std::atomic_bool m_HealthOk{false}; + std::string m_ChannelId; + + std::mutex m_TaskMutex; + std::unordered_map<IoHash, UpstreamApplyRecord> m_PendingTasks; + UpstreamApplyCompleted m_CompletedTasks; + + struct UpstreamData + { + std::map<IoHash, IoBuffer> Blobs; + std::map<IoHash, CbObject> Objects; + std::set<IoHash> CasIds; + std::set<IoHash> Cids; + IoHash TaskId; + IoHash RequirementsId; + }; + + struct UpstreamDirectory + { + std::filesystem::path Path; + std::map<std::string, UpstreamDirectory> Directories; + std::set<std::string> Files; + }; + + static void ProcessQueueTimings(CbObjectView QueueStats, std::map<std::string, uint64_t>& Timepoints) + { + uint64_t Ticks = QueueStats["t"sv].AsDateTimeTicks(); + if (Ticks == 0) + { + return; + } + + // Scope is an array of miliseconds after start time + // TODO: cleanup + Timepoints["horde-queue-added"] = Ticks; + int Index = 0; + for (auto& Item : QueueStats["s"sv].AsArrayView()) + { + Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond; + switch (Index) + { + case 0: + Timepoints["horde-queue-dispatched"] = Ticks; + break; + case 1: + Timepoints["horde-queue-complete"] = Ticks; + break; + 
} + Index++; + } + } + + static void ProcessExecuteTimings(CbObjectView ExecutionStats, std::map<std::string, uint64_t>& Timepoints) + { + uint64_t Ticks = ExecutionStats["t"sv].AsDateTimeTicks(); + if (Ticks == 0) + { + return; + } + + // Scope is an array of miliseconds after start time + // TODO: cleanup + Timepoints["horde-execution-start"] = Ticks; + int Index = 0; + for (auto& Item : ExecutionStats["s"sv].AsArrayView()) + { + Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond; + switch (Index) + { + case 0: + Timepoints["horde-execution-download-ref"] = Ticks; + break; + case 1: + Timepoints["horde-execution-download-input"] = Ticks; + break; + case 2: + Timepoints["horde-execution-execute"] = Ticks; + break; + case 3: + Timepoints["horde-execution-upload-log"] = Ticks; + break; + case 4: + Timepoints["horde-execution-upload-output"] = Ticks; + break; + case 5: + Timepoints["horde-execution-upload-ref"] = Ticks; + break; + } + Index++; + } + } + + [[nodiscard]] GetUpstreamApplyResult ProcessTaskStatus(const UpstreamApplyType ApplyType, const IoHash& ResultHash) + { + try + { + CloudCacheSession Session(m_StorageClient); + + GetUpstreamApplyResult ApplyResult{}; + + IoHash StdOutHash; + IoHash StdErrHash; + IoHash OutputHash; + + std::map<IoHash, IoBuffer> BinaryData; + + { + CloudCacheResult ObjectRefResult = + Session.GetRef(Session.Client().DefaultBlobStoreNamespace(), "responses"sv, ResultHash, ZenContentType::kCbObject); + Log().debug("Get ref {} Bytes={} Duration={}s Result={}", + ResultHash, + ObjectRefResult.Bytes, + ObjectRefResult.ElapsedSeconds, + ObjectRefResult.Success); + ApplyResult.Bytes += ObjectRefResult.Bytes; + ApplyResult.ElapsedSeconds += ObjectRefResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-get-ref"] = DateTime::NowTicks(); + + if (!ObjectRefResult.Success) + { + ApplyResult.Error.Reason = "Failed to get result object data"; + return ApplyResult; + } + + CbObject ResultObject = 
LoadCompactBinaryObject(ObjectRefResult.Response); + ApplyResult.Error.ErrorCode = ResultObject["e"sv].AsInt32(); + StdOutHash = ResultObject["so"sv].AsBinaryAttachment(); + StdErrHash = ResultObject["se"sv].AsBinaryAttachment(); + OutputHash = ResultObject["o"sv].AsObjectAttachment(); + } + + { + std::set<IoHash> NeededData; + if (OutputHash != IoHash::Zero) + { + GetObjectReferencesResult ObjectReferenceResult = + Session.GetObjectReferences(Session.Client().DefaultBlobStoreNamespace(), OutputHash); + Log().debug("Get object references {} References={} Bytes={} Duration={}s Result={}", + ResultHash, + ObjectReferenceResult.References.size(), + ObjectReferenceResult.Bytes, + ObjectReferenceResult.ElapsedSeconds, + ObjectReferenceResult.Success); + ApplyResult.Bytes += ObjectReferenceResult.Bytes; + ApplyResult.ElapsedSeconds += ObjectReferenceResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-get-object-references"] = DateTime::NowTicks(); + + if (!ObjectReferenceResult.Success) + { + ApplyResult.Error.Reason = "Failed to get result object references"; + return ApplyResult; + } + + NeededData = std::move(ObjectReferenceResult.References); + } + + NeededData.insert(OutputHash); + NeededData.insert(StdOutHash); + NeededData.insert(StdErrHash); + + for (const auto& Hash : NeededData) + { + if (Hash == IoHash::Zero) + { + continue; + } + CloudCacheResult BlobResult = Session.GetBlob(Session.Client().DefaultBlobStoreNamespace(), Hash); + Log().debug("Get blob {} Bytes={} Duration={}s Result={}", + Hash, + BlobResult.Bytes, + BlobResult.ElapsedSeconds, + BlobResult.Success); + ApplyResult.Bytes += BlobResult.Bytes; + ApplyResult.ElapsedSeconds += BlobResult.ElapsedSeconds; + if (!BlobResult.Success) + { + ApplyResult.Error.Reason = "Failed to get blob"; + return ApplyResult; + } + BinaryData[Hash] = std::move(BlobResult.Response); + } + ApplyResult.Timepoints["zen-storage-get-blobs"] = DateTime::NowTicks(); + } + + ApplyResult.StdOut = StdOutHash != 
IoHash::Zero + ? std::string((const char*)BinaryData[StdOutHash].GetData(), BinaryData[StdOutHash].GetSize()) + : ""; + ApplyResult.StdErr = StdErrHash != IoHash::Zero + ? std::string((const char*)BinaryData[StdErrHash].GetData(), BinaryData[StdErrHash].GetSize()) + : ""; + + if (OutputHash == IoHash::Zero) + { + ApplyResult.Error.Reason = "Task completed with no output object"; + return ApplyResult; + } + + CbObject OutputObject = LoadCompactBinaryObject(BinaryData[OutputHash]); + + switch (ApplyType) + { + case UpstreamApplyType::Simple: + { + ResolveMerkleTreeDirectory(""sv, OutputHash, BinaryData, ApplyResult.OutputFiles); + for (const auto& Pair : BinaryData) + { + ApplyResult.FileData[Pair.first] = std::move(BinaryData.at(Pair.first)); + } + + ApplyResult.Success = ApplyResult.Error.ErrorCode == 0; + return ApplyResult; + } + break; + case UpstreamApplyType::Asset: + { + if (ApplyResult.Error.ErrorCode != 0) + { + ApplyResult.Error.Reason = "Task completed with errors"; + return ApplyResult; + } + + // Get build.output + IoHash BuildOutputId; + IoBuffer BuildOutput; + for (auto& It : OutputObject["f"sv]) + { + const CbObjectView FileObject = It.AsObjectView(); + if (FileObject["n"sv].AsString() == "Build.output"sv) + { + BuildOutputId = FileObject["h"sv].AsBinaryAttachment(); + BuildOutput = BinaryData[BuildOutputId]; + break; + } + } + + if (BuildOutput.GetSize() == 0) + { + ApplyResult.Error.Reason = "Build.output file not found in task results"; + return ApplyResult; + } + + // Get Output directory node + IoBuffer OutputDirectoryTree; + for (auto& It : OutputObject["d"sv]) + { + const CbObjectView DirectoryObject = It.AsObjectView(); + if (DirectoryObject["n"sv].AsString() == "Outputs"sv) + { + OutputDirectoryTree = BinaryData[DirectoryObject["h"sv].AsObjectAttachment()]; + break; + } + } + + if (OutputDirectoryTree.GetSize() == 0) + { + ApplyResult.Error.Reason = "Outputs directory not found in task results"; + return ApplyResult; + } + + // load 
build.output as CbObject + + // Move Outputs from Horde to CbPackage + + std::unordered_map<IoHash, IoHash> CidToCompressedId; + CbPackage OutputPackage; + CbObject OutputDirectoryTreeObject = LoadCompactBinaryObject(OutputDirectoryTree); + + for (auto& It : OutputDirectoryTreeObject["f"sv]) + { + CbObjectView FileObject = It.AsObjectView(); + // Name is the uncompressed hash + IoHash DecompressedId = IoHash::FromHexString(FileObject["n"sv].AsString()); + // Hash is the compressed data hash, and how it is stored in Horde + IoHash CompressedId = FileObject["h"sv].AsBinaryAttachment(); + + if (!BinaryData.contains(CompressedId)) + { + Log().warn("Object attachment chunk not retrieved from Horde {}", CompressedId); + ApplyResult.Error.Reason = "Object attachment chunk not retrieved from Horde"; + return ApplyResult; + } + CidToCompressedId[DecompressedId] = CompressedId; + } + + // Iterate attachments, verify all chunks exist, and add to CbPackage + bool AnyErrors = false; + CbObject BuildOutputObject = LoadCompactBinaryObject(BuildOutput); + BuildOutputObject.IterateAttachments([&](CbFieldView Field) { + const IoHash DecompressedId = Field.AsHash(); + if (!CidToCompressedId.contains(DecompressedId)) + { + Log().warn("Attachment not found {}", DecompressedId); + AnyErrors = true; + return; + } + const IoHash& CompressedId = CidToCompressedId.at(DecompressedId); + + if (!BinaryData.contains(CompressedId)) + { + Log().warn("Missing output {} compressed {} uncompressed", CompressedId, DecompressedId); + AnyErrors = true; + return; + } + + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(BinaryData[CompressedId]), RawHash, RawSize); + + if (!AttachmentBuffer || RawHash != DecompressedId) + { + Log().warn( + "Invalid output encountered (not valid CompressedBuffer format) {} compressed {} uncompressed", + CompressedId, + DecompressedId); + AnyErrors = true; + return; + } + + 
ApplyResult.TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + ApplyResult.TotalRawAttachmentBytes += RawSize; + + CbAttachment Attachment(AttachmentBuffer, DecompressedId); + OutputPackage.AddAttachment(Attachment); + }); + + if (AnyErrors) + { + ApplyResult.Error.Reason = "Failed to get result object attachment data"; + return ApplyResult; + } + + OutputPackage.SetObject(BuildOutputObject); + ApplyResult.OutputPackage = std::move(OutputPackage); + + ApplyResult.Success = ApplyResult.Error.ErrorCode == 0; + return ApplyResult; + } + break; + } + + ApplyResult.Error.Reason = "Unknown apply type"; + return ApplyResult; + } + catch (std::exception& Err) + { + return {.Error{.ErrorCode = -1, .Reason = Err.what()}}; + } + } + + [[nodiscard]] bool ProcessApplyKey(const UpstreamApplyRecord& ApplyRecord, UpstreamData& Data) + { + std::string ExecutablePath; + std::string WorkingDirectory; + std::vector<std::string> Arguments; + std::map<std::string, std::string> Environment; + std::set<std::filesystem::path> InputFiles; + std::set<std::string> Outputs; + std::map<std::filesystem::path, IoHash> InputFileHashes; + + ExecutablePath = ApplyRecord.WorkerDescriptor["path"sv].AsString(); + if (ExecutablePath.empty()) + { + Log().warn("process apply upstream FAILED, '{}', path missing from worker descriptor", + ApplyRecord.WorkerDescriptor.GetHash()); + return false; + } + + WorkingDirectory = ApplyRecord.WorkerDescriptor["workdir"sv].AsString(); + + for (auto& It : ApplyRecord.WorkerDescriptor["executables"sv]) + { + CbObjectView FileEntry = It.AsObjectView(); + if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds)) + { + return false; + } + } + + for (auto& It : ApplyRecord.WorkerDescriptor["files"sv]) + { + CbObjectView FileEntry = It.AsObjectView(); + if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds)) + { + return false; + } + } + + for (auto& It : ApplyRecord.WorkerDescriptor["dirs"sv]) + { + std::string_view 
Directory = It.AsString(); + std::string DummyFile = fmt::format("{}/.zen_empty_file", Directory); + InputFiles.insert(DummyFile); + Data.Blobs[EmptyBufferId] = EmptyBuffer; + InputFileHashes[DummyFile] = EmptyBufferId; + } + + if (!WorkingDirectory.empty()) + { + std::string DummyFile = fmt::format("{}/.zen_empty_file", WorkingDirectory); + InputFiles.insert(DummyFile); + Data.Blobs[EmptyBufferId] = EmptyBuffer; + InputFileHashes[DummyFile] = EmptyBufferId; + } + + for (auto& It : ApplyRecord.WorkerDescriptor["environment"sv]) + { + std::string_view Env = It.AsString(); + auto Index = Env.find('='); + if (Index == std::string_view::npos) + { + Log().warn("process apply upstream FAILED, environment '{}' malformed", Env); + return false; + } + + Environment[std::string(Env.substr(0, Index))] = Env.substr(Index + 1); + } + + switch (ApplyRecord.Type) + { + case UpstreamApplyType::Simple: + { + for (auto& It : ApplyRecord.WorkerDescriptor["arguments"sv]) + { + Arguments.push_back(std::string(It.AsString())); + } + + for (auto& It : ApplyRecord.WorkerDescriptor["outputs"sv]) + { + Outputs.insert(std::string(It.AsString())); + } + } + break; + case UpstreamApplyType::Asset: + { + static const std::filesystem::path BuildActionPath = "Build.action"sv; + static const std::filesystem::path InputPath = "Inputs"sv; + const IoHash ActionId = ApplyRecord.Action.GetHash(); + + Arguments.push_back("-Build=build.action"); + Outputs.insert("Build.output"); + Outputs.insert("Outputs"); + + InputFiles.insert(BuildActionPath); + InputFileHashes[BuildActionPath] = ActionId; + Data.Blobs[ActionId] = IoBufferBuilder::MakeCloneFromMemory(ApplyRecord.Action.GetBuffer().GetData(), + ApplyRecord.Action.GetBuffer().GetSize()); + + bool AnyErrors = false; + ApplyRecord.Action.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + const std::filesystem::path FilePath = {InputPath / Cid.ToHexString()}; + + if (!m_CidStore.ContainsChunk(Cid)) + { + Log().warn("process 
apply upstream FAILED, input CID chunk '{}' missing", Cid); + AnyErrors = true; + return; + } + + if (InputFiles.contains(FilePath)) + { + return; + } + + InputFiles.insert(FilePath); + InputFileHashes[FilePath] = Cid; + Data.Cids.insert(Cid); + }); + + if (AnyErrors) + { + return false; + } + } + break; + } + + const UpstreamDirectory RootDirectory = BuildDirectoryTree(InputFiles); + + CbObject Sandbox = BuildMerkleTreeDirectory(RootDirectory, InputFileHashes, Data.Cids, Data.Objects); + const IoHash SandboxHash = Sandbox.GetHash(); + Data.Objects[SandboxHash] = std::move(Sandbox); + + { + std::string_view HostPlatform = ApplyRecord.WorkerDescriptor["host"sv].AsString(); + if (HostPlatform.empty()) + { + Log().warn("process apply upstream FAILED, 'host' platform not provided"); + return false; + } + + int32_t LogicalCores = ApplyRecord.WorkerDescriptor["cores"sv].AsInt32(); + int64_t Memory = ApplyRecord.WorkerDescriptor["memory"sv].AsInt64(); + bool Exclusive = ApplyRecord.WorkerDescriptor["exclusive"sv].AsBool(); + + std::string Condition = fmt::format("Platform == '{}'", HostPlatform); + if (HostPlatform == "Win64") + { + // TODO + // Condition += " && Pool == 'Win-RemoteExec'"; + } + + std::map<std::string_view, int64_t> Resources; + if (LogicalCores > 0) + { + Resources["LogicalCores"sv] = LogicalCores; + } + if (Memory > 0) + { + Resources["RAM"sv] = std::max(Memory / 1024LL / 1024LL / 1024LL, 1LL); + } + + CbObject Requirements = BuildRequirements(Condition, Resources, Exclusive); + const IoHash RequirementsId = Requirements.GetHash(); + Data.Objects[RequirementsId] = std::move(Requirements); + Data.RequirementsId = RequirementsId; + } + + CbObject Task = BuildTask(ExecutablePath, Arguments, Environment, WorkingDirectory, SandboxHash, Data.RequirementsId, Outputs); + + const IoHash TaskId = Task.GetHash(); + Data.Objects[TaskId] = std::move(Task); + Data.TaskId = TaskId; + + return true; + } + + [[nodiscard]] bool ProcessFileEntry(const CbObjectView& 
FileEntry, + std::set<std::filesystem::path>& InputFiles, + std::map<std::filesystem::path, IoHash>& InputFileHashes, + std::set<IoHash>& CasIds) + { /* ProcessFileEntry: validates one worker file entry ("name", "hash", "size"). Returns false after logging a warning when the chunk is absent from m_CidStore or the file name is already in InputFiles; otherwise records the file in InputFiles, InputFileHashes and CasIds and returns true. */ + const std::filesystem::path FilePath = FileEntry["name"sv].AsString(); + const IoHash ChunkId = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); /* Size is only used in the duplicate-name warning below. */ + + if (!m_CidStore.ContainsChunk(ChunkId)) + { + Log().warn("process apply upstream FAILED, worker CAS chunk '{}' missing", ChunkId); + return false; + } + + if (InputFiles.contains(FilePath)) + { + Log().warn("process apply upstream FAILED, worker CAS chunk '{}' size: {} duplicate filename {}", ChunkId, Size, FilePath); + return false; + } + + InputFiles.insert(FilePath); + InputFileHashes[FilePath] = ChunkId; + CasIds.insert(ChunkId); + return true; + } + + /* BuildDirectoryTree: turns the flat set of input file paths into a tree of UpstreamDirectory nodes rooted at an empty path. AllDirectories maps every directory path created so far to its node inside RootDirectory, so a parent that was already created is found with one lookup. */ [[nodiscard]] UpstreamDirectory BuildDirectoryTree(const std::set<std::filesystem::path>& InputFiles) + { + static const std::filesystem::path RootPath; + std::map<std::filesystem::path, UpstreamDirectory*> AllDirectories; + UpstreamDirectory RootDirectory = {.Path = RootPath}; + + AllDirectories[RootPath] = &RootDirectory; + + // Build tree from flat list + for (const auto& Path : InputFiles) + { + if (Path.has_parent_path()) + { + if (!AllDirectories.contains(Path.parent_path())) + { + /* PathSplit collects the parent directory names innermost first, so the outermost component ends up on top of the stack. */ std::stack<std::string> PathSplit; + { + std::filesystem::path ParentPath = Path.parent_path(); + PathSplit.push(ParentPath.filename().string()); + /* NOTE(review): std::filesystem reports a parent path for a root directory as well (the root itself), so an absolute input path looks like it would never leave this loop - confirm InputFiles only ever holds relative paths. */ while (ParentPath.has_parent_path()) + { + ParentPath = ParentPath.parent_path(); + PathSplit.push(ParentPath.filename().string()); + } + } + /* Walk down from the root, creating each missing directory node and registering it in AllDirectories. */ UpstreamDirectory* ParentPtr = &RootDirectory; + while (!PathSplit.empty()) + { + if (!ParentPtr->Directories.contains(PathSplit.top())) + { + std::filesystem::path NewParentPath = {ParentPtr->Path / PathSplit.top()}; + ParentPtr->Directories[PathSplit.top()] = {.Path = NewParentPath}; + AllDirectories[NewParentPath] = &ParentPtr->Directories[PathSplit.top()]; + } + ParentPtr = 
&ParentPtr->Directories[PathSplit.top()]; + PathSplit.pop(); + } + } + + /* Attach the file name to its parent directory node. */ AllDirectories[Path.parent_path()]->Files.insert(Path.filename().string()); + } + else + { + /* No parent path: the file belongs directly to the root directory. */ RootDirectory.Files.insert(Path.filename().string()); + } + } + + return RootDirectory; + } + + /* BuildMerkleTreeDirectory: serialises one directory node as a compact-binary object. Array "f" lists the files ("n" name, "h" binary attachment hash, "c" true when the hash is a member of Cids); array "d" lists the sub-directories ("n" name, "h" object attachment hash). Each sub-directory is built recursively and stored in Objects under its own hash; the object for RootDirectory itself is returned, not stored. Empty arrays are omitted. InputFileHashes.at() throws std::out_of_range when a file has no recorded hash. */ [[nodiscard]] CbObject BuildMerkleTreeDirectory(const UpstreamDirectory& RootDirectory, + const std::map<std::filesystem::path, IoHash>& InputFileHashes, + const std::set<IoHash>& Cids, + std::map<IoHash, CbObject>& Objects) + { + CbObjectWriter DirectoryTreeWriter; + + if (!RootDirectory.Files.empty()) + { + DirectoryTreeWriter.BeginArray("f"sv); + for (const auto& File : RootDirectory.Files) + { + const std::filesystem::path FilePath = {RootDirectory.Path / File}; + const IoHash& FileHash = InputFileHashes.at(FilePath); + const bool Compressed = Cids.contains(FileHash); + DirectoryTreeWriter.BeginObject(); + DirectoryTreeWriter.AddString("n"sv, File); + DirectoryTreeWriter.AddBinaryAttachment("h"sv, FileHash); + DirectoryTreeWriter.AddBool("c"sv, Compressed); + DirectoryTreeWriter.EndObject(); + } + DirectoryTreeWriter.EndArray(); + } + + if (!RootDirectory.Directories.empty()) + { + DirectoryTreeWriter.BeginArray("d"sv); + for (const auto& Item : RootDirectory.Directories) + { + /* Children are serialised first because the parent entry refers to the child by hash. */ CbObject Directory = BuildMerkleTreeDirectory(Item.second, InputFileHashes, Cids, Objects); + const IoHash DirectoryHash = Directory.GetHash(); + Objects[DirectoryHash] = std::move(Directory); + + DirectoryTreeWriter.BeginObject(); + DirectoryTreeWriter.AddString("n"sv, Item.first); + DirectoryTreeWriter.AddObjectAttachment("h"sv, DirectoryHash); + DirectoryTreeWriter.EndObject(); + } + DirectoryTreeWriter.EndArray(); + } + + return DirectoryTreeWriter.Save(); + } + + /* ResolveMerkleTreeDirectory: reverse walk of a serialised directory tree. Loads the object stored under DirectoryHash, records every "f" entry in OutputFiles as ParentDirectory/name -> "h" hash, and recurses into every "d" entry. Objects.at() throws std::out_of_range when a referenced directory object is missing. */ void ResolveMerkleTreeDirectory(const std::filesystem::path& ParentDirectory, + const IoHash& DirectoryHash, + const std::map<IoHash, IoBuffer>& Objects, + std::map<std::filesystem::path, IoHash>& OutputFiles) + { + CbObject Directory = 
LoadCompactBinaryObject(Objects.at(DirectoryHash)); + + /* Files of this directory: key "n" is the name, key "h" the binary attachment hash. */ for (auto& It : Directory["f"sv]) + { + const CbObjectView FileObject = It.AsObjectView(); + const std::filesystem::path Path = ParentDirectory / FileObject["n"sv].AsString(); + + OutputFiles[Path] = FileObject["h"sv].AsBinaryAttachment(); + } + + /* Sub-directories: recurse with the child name appended to the parent path. */ for (auto& It : Directory["d"sv]) + { + const CbObjectView DirectoryObject = It.AsObjectView(); + + ResolveMerkleTreeDirectory(ParentDirectory / DirectoryObject["n"sv].AsString(), + DirectoryObject["h"sv].AsObjectAttachment(), + Objects, + OutputFiles); + } + } + + /* BuildRequirements: serialises the scheduling requirements. "c" is the condition string, "r" an array of [name, amount] pairs (omitted when Resources is empty) and "e" the exclusive flag. */ [[nodiscard]] CbObject BuildRequirements(const std::string_view Condition, + const std::map<std::string_view, int64_t>& Resources, + const bool Exclusive) + { + CbObjectWriter Writer; + Writer.AddString("c", Condition); + if (!Resources.empty()) + { + Writer.BeginArray("r"); + for (const auto& Resource : Resources) + { + Writer.BeginArray(); + Writer.AddString(Resource.first); + Writer.AddInteger(Resource.second); + Writer.EndArray(); + } + Writer.EndArray(); + } + Writer.AddBool("e", Exclusive); + return Writer.Save(); + } + + /* BuildTask: serialises a task object. "e" executable, "a" argument array, "v" environment as [name, value] pairs, "w" working directory, "s" and "r" object attachments for the sandbox and the requirements, "o" output array. Empty collections and an empty working directory are left out of the object. */ [[nodiscard]] CbObject BuildTask(const std::string_view Executable, + const std::vector<std::string>& Arguments, + const std::map<std::string, std::string>& Environment, + const std::string_view WorkingDirectory, + const IoHash& SandboxHash, + const IoHash& RequirementsId, + const std::set<std::string>& Outputs) + { + CbObjectWriter TaskWriter; + TaskWriter.AddString("e"sv, Executable); + + if (!Arguments.empty()) + { + TaskWriter.BeginArray("a"sv); + for (const auto& Argument : Arguments) + { + TaskWriter.AddString(Argument); + } + TaskWriter.EndArray(); + } + + if (!Environment.empty()) + { + TaskWriter.BeginArray("v"sv); + for (const auto& Env : Environment) + { + TaskWriter.BeginArray(); + TaskWriter.AddString(Env.first); + TaskWriter.AddString(Env.second); + TaskWriter.EndArray(); + } + TaskWriter.EndArray(); + } + + if (!WorkingDirectory.empty()) + { + 
TaskWriter.AddString("w"sv, WorkingDirectory); + } + + /* The sandbox and the requirements are referenced by hash as object attachments; both are always written. */ TaskWriter.AddObjectAttachment("s"sv, SandboxHash); + TaskWriter.AddObjectAttachment("r"sv, RequirementsId); + + // Outputs + if (!Outputs.empty()) + { + TaskWriter.BeginArray("o"sv); + for (const auto& Output : Outputs) + { + TaskWriter.AddString(Output); + } + TaskWriter.EndArray(); + } + + return TaskWriter.Save(); + } + }; +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + /* CreateHordeEndpoint: factory that forwards all six arguments unchanged to the detail::HordeUpstreamApplyEndpoint constructor and returns the new endpoint as a unique_ptr to the base type. */ +std::unique_ptr<UpstreamApplyEndpoint> +UpstreamApplyEndpoint::CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr) +{ + return std::make_unique<detail::HordeUpstreamApplyEndpoint>(ComputeOptions, + ComputeAuthConfig, + StorageOptions, + StorageAuthConfig, + CidStore, + Mgr); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/jupiter.cpp b/src/zenserver/upstream/jupiter.cpp new file mode 100644 index 000000000..dbb185bec --- /dev/null +++ b/src/zenserver/upstream/jupiter.cpp @@ -0,0 +1,965 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "jupiter.h" + +#include "diag/formatters.h" +#include "diag/logging.h" + +#include <zencore/compactbinary.h> +#include <zencore/compositebuffer.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "Crypt32.lib") +# pragma comment(lib, "Wldap32.lib") +#endif + +#include <json11.hpp> + +using namespace std::literals; + +namespace zen { + +namespace detail { + /* CloudCacheSessionState: per-session state holding a reference to the owning client, the most recently acquired access token and a cpr::Session that the request methods reuse. */ struct CloudCacheSessionState + { + CloudCacheSessionState(CloudCacheClient& Client) : m_Client(Client) {} + + /* Returns the stored token; when RefreshToken is true it is first replaced by a token newly acquired from the client. */ const CloudCacheAccessToken& GetAccessToken(bool RefreshToken) + { + if (RefreshToken) + { + m_AccessToken = m_Client.AcquireAccessToken(); + } + + return m_AccessToken; + } + + cpr::Session& GetSession() { return m_Session; } + + /* Clears the session body and headers and applies the given connect and request timeouts. */ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout) + { + m_Session.SetBody({}); + m_Session.SetHeader({}); + m_Session.SetConnectTimeout(ConnectTimeout); + m_Session.SetTimeout(Timeout); + } + + private: + friend class zen::CloudCacheClient; + + CloudCacheClient& m_Client; + CloudCacheAccessToken m_AccessToken; + cpr::Session m_Session; + }; + +} // namespace detail + /* The constructor obtains a session state object from the client via AllocSessionState(); the destructor hands it back via FreeSessionState(). CacheClient is dereferenced unconditionally, so it must not be null. */ +CloudCacheSession::CloudCacheSession(CloudCacheClient* CacheClient) : m_Log(CacheClient->Logger()), m_CacheClient(CacheClient) +{ + m_SessionState = m_CacheClient->AllocSessionState(); +} + +CloudCacheSession::~CloudCacheSession() +{ + m_CacheClient->FreeSessionState(m_SessionState); +} + /* Authenticate: forces a token refresh and reports in Success whether the resulting token is valid. */ +CloudCacheResult +CloudCacheSession::Authenticate() +{ + const bool RefreshToken = true; + const CloudCacheAccessToken& AccessToken = GetAccessToken(RefreshToken); + + return {.Success = AccessToken.IsValid()}; +} + /* GetRef: GET {ServiceUrl}/api/v1/refs/{Namespace}/{BucketId}/{Key}; the Accept header is chosen from RefType. */ +CloudCacheResult +CloudCacheSession::GetRef(std::string_view Namespace, std::string_view BucketId, const 
IoHash& Key, ZenContentType RefType) +{ + /* Accept is "application/x-ue-cb" for compact-binary objects and "application/octet-stream" for every other RefType. */ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", ContentType}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + /* A transport-level failure is reported with the cpr error code and message; an HTTP 401 is reported as "Invalid access token". */ if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + /* Only status 200 counts as success; any other status returns an empty buffer with Success == false and no ErrorCode. */ const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + /* GetBlob: GET {ServiceUrl}/api/v1/blobs/{Namespace}/{Key} with Accept "application/octet-stream". Error handling matches the other GET helpers: transport error -> cpr error code, HTTP 401 -> "Invalid access token", success only on status 200. NOTE(review): unlike most sibling methods this one opens no ZEN_TRACE_CPU scope - confirm whether that is intentional. */ +CloudCacheResult +CloudCacheSession::GetBlob(std::string_view Namespace, const IoHash& Key) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/octet-stream"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + /* Unlike the other GET helpers, a successful response with an empty body yields a default-constructed IoBuffer rather than a clone of zero bytes. */ const IoBuffer Buffer = + Success && Response.text.size() > 0 ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + /* GetCompressedBlob: GET {ServiceUrl}/api/v1/compressed-blobs/{Namespace}/{Key} with Accept "application/x-ue-comp". Transport error -> cpr error code, HTTP 401 -> "Invalid access token", success only on status 200; the body is returned as-is, no decompression happens here. */ +CloudCacheResult +CloudCacheSession::GetCompressedBlob(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-comp"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + /* GetInlineBlob: GET on the same refs URL as GetRef, but with Accept "application/x-jupiter-inline". The payload hash reported by the server in a response header is written to OutPayloadHash. Transport error -> cpr error code, HTTP 401 -> "Invalid access token", success only on status 200. */ +CloudCacheResult +CloudCacheSession::GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash) +{ + ZEN_TRACE_CPU("HordeClient::GetInlineBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-jupiter-inline"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + /* The payload hash arrives in the X-Jupiter-InlinePayloadHash header and is parsed only when it has exactly IoHash::StringLength characters. NOTE(review): OutPayloadHash is left untouched when the header is missing or has another length, and the header is read even for non-200 responses - callers presumably need to pre-initialise it; confirm. */ if (auto It = Response.header.find("X-Jupiter-InlinePayloadHash"); It != Response.header.end()) + { + const std::string& PayloadHashHeader = It->second; + if (PayloadHashHeader.length() == IoHash::StringLength) + { + OutPayloadHash = IoHash::FromHexString(PayloadHashHeader); + } + } + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + /* GetObject: GET {ServiceUrl}/api/v1/objects/{Namespace}/{Key} with Accept "application/x-ue-cb". Transport error -> cpr error code, HTTP 401 -> "Invalid access token", success only on status 200. */ +CloudCacheResult +CloudCacheSession::GetObject(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetObject"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + /* PutRef: PUT {ServiceUrl}/api/v1/refs/{Namespace}/{BucketId}/{Key}. The hash of the payload is sent in the X-Jupiter-IoHash header and the Content-Type is chosen from RefType. Status 200 or 201 counts as success; on success the JSON response's "needs" array is converted into Result.Needs. */ +PutRefResult +CloudCacheSession::PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType) +{ + ZEN_TRACE_CPU("HordeClient::PutRef"); + + IoHash Hash = IoHash::HashBuffer(Ref.Data(), Ref.Size()); + + const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption( + cpr::Header{{"Authorization", AccessToken.Value}, {"X-Jupiter-IoHash", Hash.ToHexString()}, {"Content-Type", ContentType}}); + Session.SetBody(cpr::Body{(const char*)Ref.Data(), Ref.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + /* A transport-level failure is reported with the cpr error code and message; an HTTP 401 is reported as "Invalid access token". */ if (Response.error) + { + PutRefResult Result; + Result.ErrorCode = static_cast<int32_t>(Response.error.code); + Result.Reason = std::move(Response.error.message); + return Result; + } + else if (!VerifyAccessToken(Response.status_code)) + { + PutRefResult Result; + Result.ErrorCode = 401; + Result.Reason = "Invalid access token"sv; + return Result; + } + + PutRefResult Result; + Result.Success = (Response.status_code == 200 || Response.status_code == 201); + Result.Bytes = Response.uploaded_bytes; + Result.ElapsedSeconds = Response.elapsed; + + if (Result.Success) + { + /* A body that fails to parse as JSON is ignored: Needs simply stays empty and Success stays true. */ std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + if (JsonError.empty()) + { + json11::Json::array Needs = Json["needs"].array_items(); + for (const auto& Need : 
Needs) + { + Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value())); + } + } + } + + return Result; +} + /* FinalizeRef: POST {ServiceUrl}/api/v1/refs/{Namespace}/{BucketId}/{Key}/finalize/{RefHash} with an empty body; RefHash is also sent in the X-Jupiter-IoHash header. Status 200 or 201 counts as success; on success the JSON response's "needs" array is converted into Result.Needs. */ +FinalizeRefResult +CloudCacheSession::FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHash) +{ + ZEN_TRACE_CPU("HordeClient::FinalizeRef"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString() << "/finalize/" + << RefHash.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, + {"X-Jupiter-IoHash", RefHash.ToHexString()}, + {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + /* A transport-level failure is reported with the cpr error code and message; an HTTP 401 is reported as "Invalid access token". */ if (Response.error) + { + FinalizeRefResult Result; + Result.ErrorCode = static_cast<int32_t>(Response.error.code); + Result.Reason = std::move(Response.error.message); + return Result; + } + else if (!VerifyAccessToken(Response.status_code)) + { + FinalizeRefResult Result; + Result.ErrorCode = 401; + Result.Reason = "Invalid access token"sv; + return Result; + } + + FinalizeRefResult Result; + Result.Success = (Response.status_code == 200 || Response.status_code == 201); + Result.Bytes = Response.uploaded_bytes; + Result.ElapsedSeconds = Response.elapsed; + + if (Result.Success) + { + /* A body that fails to parse as JSON is ignored: Needs simply stays empty and Success stays true. */ std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + if (JsonError.empty()) + { + json11::Json::array Needs = Json["needs"].array_items(); + for (const auto& Need : Needs) + { + Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value())); + } + } + } + + return Result; +} + /* PutBlob: PUT {ServiceUrl}/api/v1/blobs/{Namespace}/{Key} with Content-Type "application/octet-stream"; status 200 or 201 counts as success. */ +CloudCacheResult +CloudCacheSession::PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob) +{ + 
ZEN_TRACE_CPU("HordeClient::PutBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/octet-stream"}}); + Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + /* A transport-level failure is reported with the cpr error code and message; an HTTP 401 is reported as "Invalid access token". */ if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + /* Bytes reports the uploaded byte count; both 200 and 201 are accepted as success. */ return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + /* PutCompressedBlob (IoBuffer overload): PUT {ServiceUrl}/api/v1/compressed-blobs/{Namespace}/{Key} with Content-Type "application/x-ue-comp". The buffer is sent as-is, no compression happens here. Status 200 or 201 counts as success. */ +CloudCacheResult +CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob) +{ + ZEN_TRACE_CPU("HordeClient::PutCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}}); + Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = 
std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + /* PutCompressedBlob (CompositeBuffer overload): same URL and Content-Type as the IoBuffer overload, but the payload is streamed through a cpr read callback instead of being passed as one contiguous body. Status 200 or 201 counts as success. */ +CloudCacheResult +CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Payload) +{ + ZEN_TRACE_CPU("HordeClient::PutCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}}); + /* SizeLeft counts the bytes not yet handed to the transfer; BufferIt remembers the read position inside the composite buffer between callback invocations. */ uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + /* Each invocation copies at most the requested number of bytes, clamped to what is left, and reports the amount actually written through the size reference. NOTE(review): the callback captures locals of this call by reference and is installed on the shared session; CloudCacheSessionState::Reset only clears body, headers and timeouts - confirm the callback cannot be invoked by a later request on the same session. */ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + /* The total upload size is passed up front; at this point SizeLeft still equals the full payload size. */ Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + /* PutObject: PUT {ServiceUrl}/api/v1/objects/{Namespace}/{Key} with Content-Type "application/x-ue-cb"; status 200 or 201 counts as success. */ +CloudCacheResult +CloudCacheSession::PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object) +{ + ZEN_TRACE_CPU("HordeClient::PutObject"); + + ExtendableStringBuilder<256> Uri; + Uri << 
m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{(const char*)Object.Data(), Object.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + /* A transport-level failure is reported with the cpr error code and message; an HTTP 401 is reported as "Invalid access token". */ if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + /* RefExists: HEAD {ServiceUrl}/api/v1/refs/{Namespace}/{BucketId}/{Key}. Success is true only for status 200; the result carries the elapsed time but no byte count. Transport error -> cpr error code, HTTP 401 -> "Invalid access token". */ +CloudCacheResult +CloudCacheSession::RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::RefExists"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Head(); + ZEN_DEBUG("HEAD {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +GetObjectReferencesResult 
+CloudCacheSession::GetObjectReferences(std::string_view Namespace, const IoHash& Key) +{ + /* GET {ServiceUrl}/api/v1/objects/{Namespace}/{Key}/references with Accept "application/x-ue-cb". On status 200 the compact-binary response's "references" array is collected into Result.References. Transport error -> cpr error code, HTTP 401 -> "Invalid access token". */ ZEN_TRACE_CPU("HordeClient::GetObjectReferences"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString() << "/references"; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}}; + } + + GetObjectReferencesResult Result{ + CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}}; + + if (Result.Success) + { + /* The buffer is built with the IoBuffer::Wrap tag directly over Response.text and is only used inside this block, while Response is still alive. NOTE(review): the response body is handed to LoadCompactBinaryObject without any visible validation step - confirm malformed server data cannot cause trouble here. */ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + const CbObject ReferencesResponse = LoadCompactBinaryObject(Buffer); + for (auto& Item : ReferencesResponse["references"sv]) + { + Result.References.insert(Item.AsHash()); + } + } + + return Result; +} + /* Single-key existence checks: each forwards to CacheTypeExists with the URL segment for its storage type ("blobs", "compressed-blobs", "objects"). */ +CloudCacheResult +CloudCacheSession::BlobExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "blobs"sv, Key); +} + +CloudCacheResult +CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "compressed-blobs"sv, Key); +} + +CloudCacheResult +CloudCacheSession::ObjectExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "objects"sv, Key); +} + +CloudCacheExistsResult 
+CloudCacheSession::BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + /* Multi-key existence checks: this and the two overloads that follow forward to the set overload of CacheTypeExists with the URL segment for their storage type. */ return CacheTypeExists(Namespace, "blobs"sv, Keys); +} + +CloudCacheExistsResult +CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "compressed-blobs"sv, Keys); +} + +CloudCacheExistsResult +CloudCacheSession::ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "objects"sv, Keys); +} + /* PostComputeTasks: POST {ServiceUrl}/api/v1/compute/{ComputeCluster} with TasksData as the body and Content-Type "application/x-ue-cb". Only status 200 counts as success (201 is not accepted here, unlike the PUT helpers). Transport error -> cpr error code, HTTP 401 -> "Invalid access token". */ +CloudCacheResult +CloudCacheSession::PostComputeTasks(IoBuffer TasksData) +{ + ZEN_TRACE_CPU("HordeClient::PostComputeTasks"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{(const char*)TasksData.Data(), TasksData.Size()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + /* GetComputeUpdates: POST {ServiceUrl}/api/v1/compute/{ComputeCluster}/updates/{ChannelId}?wait={WaitSeconds} with an empty body; on status 200 the response body is returned in the result. */ +CloudCacheResult +CloudCacheSession::GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds) +{ + ZEN_TRACE_CPU("HordeClient::GetComputeUpdates"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster() << "/updates/" << ChannelId + << "?wait=" << WaitSeconds; + + cpr::Session& Session = 
GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + /* Despite the Get prefix in the method name, the updates endpoint is called with POST and an empty body. */ cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + /* Filter: NOTE(review) this is a stub - it builds a URI that is never used, performs no request, ignores BucketId and ChunkHashes, and always returns an empty vector. */ +std::vector<IoHash> +CloudCacheSession::Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl(); + Uri << "/api/v1/s/" << Namespace; + + ZEN_UNUSED(BucketId, ChunkHashes); + + return {}; +} + /* Accessors that forward to the session state object. */ +cpr::Session& +CloudCacheSession::GetSession() +{ + return m_SessionState->GetSession(); +} + /* Returns a copy of the state's token (the state itself hands out a const reference). */ +CloudCacheAccessToken +CloudCacheSession::GetAccessToken(bool RefreshToken) +{ + return m_SessionState->GetAccessToken(RefreshToken); +} + /* A token is considered rejected only when the server answered with HTTP 401; every other status passes. */ +bool +CloudCacheSession::VerifyAccessToken(long StatusCode) +{ + return StatusCode != 401; +} + /* CacheTypeExists (single key): HEAD {ServiceUrl}/api/v1/{TypeId}/{Namespace}/{Key}; Success is true only for status 200. */ +CloudCacheResult +CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::CacheTypeExists"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = 
GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Head(); + ZEN_DEBUG("HEAD {}", Response); + + /* A transport-level failure is reported with the cpr error code and message; an HTTP 401 is reported as "Invalid access token". */ if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + /* CacheTypeExists (key set): POST {ServiceUrl}/api/v1/{TypeId}/{Namespace}/exist with a JSON array of hex-encoded keys as the body. On status 200 the compact-binary response's "needs" array is collected into Result.Needs. */ +CloudCacheExistsResult +CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys) +{ + ZEN_TRACE_CPU("HordeClient::CacheTypeExists"); + + ExtendableStringBuilder<256> Body; + Body << "["; + for (const auto& Key : Keys) + { + /* The builder holds only the opening bracket (size 1) before the first key, so a separating comma is written for every key except the first. */ Body << (Body.Size() != 1 ? ",\"" : "\"") << Key.ToHexString() << "\""; + } + Body << "]"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/exist"; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption( + cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}, {"Content-Type", "application/json"}}); + Session.SetOption(cpr::Body(Body.ToString())); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}}; + } + + CloudCacheExistsResult Result{ + CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = 
Response.status_code == 200}}; + + if (Result.Success) + { + IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + const CbObject ExistsResponse = LoadCompactBinaryObject(Buffer); + for (auto& Item : ExistsResponse["needs"sv]) + { + Result.Needs.insert(Item.AsHash()); + } + } + + return Result; +} + +/** + * An access token provider that holds a token that will never change. + */ +class StaticTokenProvider final : public CloudCacheTokenProvider +{ +public: + StaticTokenProvider(CloudCacheAccessToken Token) : m_Token(std::move(Token)) {} + + virtual ~StaticTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Token; } + +private: + CloudCacheAccessToken m_Token; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromStaticToken(CloudCacheAccessToken Token) +{ + return std::make_unique<StaticTokenProvider>(std::move(Token)); +} + +class OAuthClientCredentialsTokenProvider final : public CloudCacheTokenProvider +{ +public: + OAuthClientCredentialsTokenProvider(const CloudCacheTokenProvider::OAuthClientCredentialsParams& Params) + { + m_Url = std::string(Params.Url); + m_ClientId = std::string(Params.ClientId); + m_ClientSecret = std::string(Params.ClientSecret); + } + + virtual ~OAuthClientCredentialsTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override + { + using namespace std::chrono; + + std::string Body = + fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}", m_ClientId, m_ClientSecret); + + cpr::Response Response = + cpr::Post(cpr::Url{m_Url}, cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}, cpr::Body{std::move(Body)}); + + if (Response.error || Response.status_code != 200) + { + return {}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return 
{}; + } + + std::string Token = Json["access_token"].string_value(); + int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()); + CloudCacheAccessToken::TimePoint ExpireTime = CloudCacheAccessToken::Clock::now() + seconds(ExpiresInSeconds); + + return {.Value = fmt::format("Bearer {}", Token), .ExpireTime = ExpireTime}; + } + +private: + std::string m_Url; + std::string m_ClientId; + std::string m_ClientSecret; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params) +{ + return std::make_unique<OAuthClientCredentialsTokenProvider>(Params); +} + +class CallbackTokenProvider final : public CloudCacheTokenProvider +{ +public: + CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {} + + virtual ~CallbackTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); } + +private: + std::function<CloudCacheAccessToken()> m_Callback; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback) +{ + return std::make_unique<CallbackTokenProvider>(std::move(Callback)); +} + +CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider) +: m_Log(zen::logging::Get("jupiter")) +, m_ServiceUrl(Options.ServiceUrl) +, m_DefaultDdcNamespace(Options.DdcNamespace) +, m_DefaultBlobStoreNamespace(Options.BlobStoreNamespace) +, m_ComputeCluster(Options.ComputeCluster) +, m_ConnectTimeout(Options.ConnectTimeout) +, m_Timeout(Options.Timeout) +, m_TokenProvider(std::move(TokenProvider)) +{ + ZEN_ASSERT(m_TokenProvider.get() != nullptr); +} + +CloudCacheClient::~CloudCacheClient() +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + + for (auto State : m_SessionStateCache) + { + delete State; + } +} + 
+CloudCacheAccessToken +CloudCacheClient::AcquireAccessToken() +{ + ZEN_TRACE_CPU("HordeClient::AcquireAccessToken"); + + return m_TokenProvider->AcquireAccessToken(); +} + +detail::CloudCacheSessionState* +CloudCacheClient::AllocSessionState() +{ + detail::CloudCacheSessionState* State = nullptr; + + bool IsTokenValid = false; + + { + RwLock::ExclusiveLockScope _(m_SessionStateLock); + + if (m_SessionStateCache.empty() == false) + { + State = m_SessionStateCache.front(); + IsTokenValid = State->m_AccessToken.IsValid(); + + m_SessionStateCache.pop_front(); + } + } + + if (State == nullptr) + { + State = new detail::CloudCacheSessionState(*this); + } + + State->Reset(m_ConnectTimeout, m_Timeout); + + if (IsTokenValid == false) + { + State->m_AccessToken = m_TokenProvider->AcquireAccessToken(); + } + + return State; +} + +void +CloudCacheClient::FreeSessionState(detail::CloudCacheSessionState* State) +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + m_SessionStateCache.push_front(State); +} + +} // namespace zen diff --git a/src/zenserver/upstream/jupiter.h b/src/zenserver/upstream/jupiter.h new file mode 100644 index 000000000..99e5c530f --- /dev/null +++ b/src/zenserver/upstream/jupiter.h @@ -0,0 +1,217 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +#include <atomic> +#include <chrono> +#include <list> +#include <memory> +#include <set> +#include <vector> + +struct ZenCacheValue; + +namespace cpr { +class Session; +} + +namespace zen { +namespace detail { + struct CloudCacheSessionState; +} + +class CbObjectView; +class CloudCacheClient; +class IoBuffer; +struct IoHash; + +/** + * Cached access token, for use with `Authorization:` header + */ +struct CloudCacheAccessToken +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + + static constexpr int64_t ExpireMarginInSeconds = 30; + + std::string Value; + TimePoint ExpireTime; + + bool IsValid() const + { + return Value.empty() == false && + ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count(); + } +}; + +struct CloudCacheResult +{ + IoBuffer Response; + int64_t Bytes{}; + double ElapsedSeconds{}; + int32_t ErrorCode{}; + std::string Reason; + bool Success = false; +}; + +struct PutRefResult : CloudCacheResult +{ + std::vector<IoHash> Needs; +}; + +struct FinalizeRefResult : CloudCacheResult +{ + std::vector<IoHash> Needs; +}; + +struct CloudCacheExistsResult : CloudCacheResult +{ + std::set<IoHash> Needs; +}; + +struct GetObjectReferencesResult : CloudCacheResult +{ + std::set<IoHash> References; +}; + +/** + * Context for performing Jupiter operations + * + * Maintains an HTTP connection so that subsequent operations don't need to go + * through the whole connection setup process + * + */ +class CloudCacheSession +{ +public: + CloudCacheSession(CloudCacheClient* CacheClient); + ~CloudCacheSession(); + + CloudCacheResult Authenticate(); + CloudCacheResult GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType); + CloudCacheResult GetBlob(std::string_view 
Namespace, const IoHash& Key); + CloudCacheResult GetCompressedBlob(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetObject(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash); + + PutRefResult PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType); + CloudCacheResult PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob); + CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob); + CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Blob); + CloudCacheResult PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object); + + FinalizeRefResult FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHah); + + CloudCacheResult RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key); + + GetObjectReferencesResult GetObjectReferences(std::string_view Namespace, const IoHash& Key); + + CloudCacheResult BlobExists(std::string_view Namespace, const IoHash& Key); + CloudCacheResult CompressedBlobExists(std::string_view Namespace, const IoHash& Key); + CloudCacheResult ObjectExists(std::string_view Namespace, const IoHash& Key); + + CloudCacheExistsResult BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys); + CloudCacheExistsResult CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys); + CloudCacheExistsResult ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys); + + CloudCacheResult PostComputeTasks(IoBuffer TasksData); + CloudCacheResult GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds = 0); + + std::vector<IoHash> Filter(std::string_view Namespace, std::string_view BucketId, const 
std::vector<IoHash>& ChunkHashes); + + CloudCacheClient& Client() { return *m_CacheClient; }; + +private: + inline spdlog::logger& Log() { return m_Log; } + cpr::Session& GetSession(); + CloudCacheAccessToken GetAccessToken(bool RefreshToken = false); + bool VerifyAccessToken(long StatusCode); + + CloudCacheResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key); + + CloudCacheExistsResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys); + + spdlog::logger& m_Log; + RefPtr<CloudCacheClient> m_CacheClient; + detail::CloudCacheSessionState* m_SessionState; +}; + +/** + * Access token provider interface + */ +class CloudCacheTokenProvider +{ +public: + virtual ~CloudCacheTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() = 0; + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromStaticToken(CloudCacheAccessToken Token); + + struct OAuthClientCredentialsParams + { + std::string_view Url; + std::string_view ClientId; + std::string_view ClientSecret; + }; + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params); + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback); +}; + +struct CloudCacheClientOptions +{ + std::string_view Name; + std::string_view ServiceUrl; + std::string_view DdcNamespace; + std::string_view BlobStoreNamespace; + std::string_view ComputeCluster; + std::chrono::milliseconds ConnectTimeout{5000}; + std::chrono::milliseconds Timeout{}; +}; + +/** + * Jupiter upstream cache client + */ +class CloudCacheClient : public RefCounted +{ +public: + CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider); + ~CloudCacheClient(); + + CloudCacheAccessToken AcquireAccessToken(); + std::string_view DefaultDdcNamespace() const { return 
m_DefaultDdcNamespace; } + std::string_view DefaultBlobStoreNamespace() const { return m_DefaultBlobStoreNamespace; } + std::string_view ComputeCluster() const { return m_ComputeCluster; } + std::string_view ServiceUrl() const { return m_ServiceUrl; } + + spdlog::logger& Logger() { return m_Log; } + +private: + spdlog::logger& m_Log; + std::string m_ServiceUrl; + std::string m_DefaultDdcNamespace; + std::string m_DefaultBlobStoreNamespace; + std::string m_ComputeCluster; + std::chrono::milliseconds m_ConnectTimeout{}; + std::chrono::milliseconds m_Timeout{}; + std::unique_ptr<CloudCacheTokenProvider> m_TokenProvider; + + RwLock m_SessionStateLock; + std::list<detail::CloudCacheSessionState*> m_SessionStateCache; + + detail::CloudCacheSessionState* AllocSessionState(); + void FreeSessionState(detail::CloudCacheSessionState*); + + friend class CloudCacheSession; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/upstream.h b/src/zenserver/upstream/upstream.h new file mode 100644 index 000000000..a57301206 --- /dev/null +++ b/src/zenserver/upstream/upstream.h @@ -0,0 +1,8 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <upstream/jupiter.h> +#include <upstream/upstreamcache.h> +#include <upstream/upstreamservice.h> +#include <upstream/zen.h> diff --git a/src/zenserver/upstream/upstreamapply.cpp b/src/zenserver/upstream/upstreamapply.cpp new file mode 100644 index 000000000..c719b225d --- /dev/null +++ b/src/zenserver/upstream/upstreamapply.cpp @@ -0,0 +1,459 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "upstreamapply.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/fmtutils.h> +# include <zencore/stream.h> +# include <zencore/timer.h> +# include <zencore/workthreadpool.h> + +# include <zenstore/cidstore.h> + +# include "diag/logging.h" + +# include <fmt/format.h> + +# include <atomic> + +namespace zen { + +using namespace std::literals; + +struct UpstreamApplyStats +{ + static constexpr uint64_t MaxSampleCount = 1000ull; + + UpstreamApplyStats(bool Enabled) : m_Enabled(Enabled) {} + + void Add(UpstreamApplyEndpoint& Endpoint, const PostUpstreamApplyResult& Result) + { + UpstreamApplyEndpointStats& Stats = Endpoint.Stats(); + + if (Result.Error) + { + Stats.ErrorCount.Increment(1); + } + else if (Result.Success) + { + Stats.PostCount.Increment(1); + Stats.UpBytes.Increment(Result.Bytes / 1024 / 1024); + } + } + + void Add(UpstreamApplyEndpoint& Endpoint, const GetUpstreamApplyUpdatesResult& Result) + { + UpstreamApplyEndpointStats& Stats = Endpoint.Stats(); + + if (Result.Error) + { + Stats.ErrorCount.Increment(1); + } + else if (Result.Success) + { + Stats.UpdateCount.Increment(1); + Stats.DownBytes.Increment(Result.Bytes / 1024 / 1024); + if (!Result.Completed.empty()) + { + uint64_t Completed = 0; + for (auto& It : Result.Completed) + { + Completed += It.second.size(); + } + Stats.CompleteCount.Increment(Completed); + } + } + } + + bool m_Enabled; +}; + +////////////////////////////////////////////////////////////////////////// + +class UpstreamApplyImpl final : public UpstreamApply +{ +public: + UpstreamApplyImpl(const UpstreamApplyOptions& Options, CidStore& CidStore) + : m_Log(logging::Get("upstream-apply")) + , m_Options(Options) + , m_CidStore(CidStore) + , m_Stats(Options.StatsEnabled) + , m_UpstreamAsyncWorkPool(Options.UpstreamThreadCount) + , m_DownstreamAsyncWorkPool(Options.DownstreamThreadCount) + { + } + + virtual ~UpstreamApplyImpl() { Shutdown(); 
} + + virtual bool Initialize() override + { + for (auto& Endpoint : m_Endpoints) + { + const UpstreamEndpointHealth Health = Endpoint->Initialize(); + if (Health.Ok) + { + Log().info("initialize endpoint '{}' OK", Endpoint->DisplayName()); + } + else + { + Log().warn("initialize endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason); + } + } + + m_RunState.IsRunning = !m_Endpoints.empty(); + + if (m_RunState.IsRunning) + { + m_ShutdownEvent.Reset(); + + m_UpstreamUpdatesThread = std::thread(&UpstreamApplyImpl::ProcessUpstreamUpdates, this); + + m_EndpointMonitorThread = std::thread(&UpstreamApplyImpl::MonitorEndpoints, this); + } + + return m_RunState.IsRunning; + } + + virtual bool IsHealthy() const override + { + if (m_RunState.IsRunning) + { + for (const auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + return true; + } + } + } + + return false; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) override + { + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) override + { + if (m_RunState.IsRunning) + { + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300); + + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + // Already in progress + return {.ApplyId = ActionId, .Success = true}; + } + + std::chrono::steady_clock::time_point ExpireTime = + TimeoutSeconds > 0 ? 
std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds) + : std::chrono::steady_clock::time_point::max(); + + m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)}; + } + + ApplyRecord.Timepoints["zen-queue-added"] = DateTime::NowTicks(); + m_UpstreamAsyncWorkPool.ScheduleWork( + [this, ApplyRecord = std::move(ApplyRecord)]() { ProcessApplyRecord(std::move(ApplyRecord)); }); + + return {.ApplyId = ActionId, .Success = true}; + } + + return {}; + } + + virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) override + { + if (m_RunState.IsRunning) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + return {.Status = *Status, .Success = true}; + } + } + + return {}; + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "upstream_worker_threads" << m_Options.UpstreamThreadCount; + Status << "upstream_queue_count" << m_UpstreamAsyncWorkPool.PendingWork(); + Status << "downstream_worker_threads" << m_Options.DownstreamThreadCount; + Status << "downstream_queue_count" << m_DownstreamAsyncWorkPool.PendingWork(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + Status.BeginObject(); + Status << "name" << Ep->DisplayName(); + Status << "health" << (Ep->IsHealthy() ? "ok"sv : "inactive"sv); + + UpstreamApplyEndpointStats& Stats = Ep->Stats(); + const uint64_t PostCount = Stats.PostCount.Value(); + const uint64_t CompleteCount = Stats.CompleteCount.Value(); + // const uint64_t UpdateCount = Stats.UpdateCount; + const double CompleteRate = CompleteCount > 0 ? 
(double(PostCount) / double(CompleteCount)) : 0.0; + + Status << "post_count" << PostCount; + Status << "complete_count" << PostCount; + Status << "update_count" << Stats.UpdateCount.Value(); + + Status << "complete_ratio" << CompleteRate; + Status << "downloaded_mb" << Stats.DownBytes.Value(); + Status << "uploaded_mb" << Stats.UpBytes.Value(); + Status << "error_count" << Stats.ErrorCount.Value(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + // The caller is responsible for locking if required + UpstreamApplyStatus* FindStatus(const IoHash& WorkerId, const IoHash& ActionId) + { + if (auto It = m_ApplyTasks.find(WorkerId); It != m_ApplyTasks.end()) + { + if (auto It2 = It->second.find(ActionId); It2 != It->second.end()) + { + return &It2->second; + } + } + return nullptr; + } + + void ProcessApplyRecord(UpstreamApplyRecord ApplyRecord) + { + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + try + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + ApplyRecord.Timepoints["zen-queue-dispatched"] = DateTime::NowTicks(); + PostUpstreamApplyResult Result = Endpoint->PostApply(std::move(ApplyRecord)); + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->Timepoints.merge(Result.Timepoints); + + if (Result.Success) + { + Status->State = UpstreamApplyState::Executing; + } + else + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error = std::move(Result.Error), + .Bytes = Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + } + } + m_Stats.Add(*Endpoint, Result); + return; + } + } + + Log().warn("process upstream apply ({}/{}) FAILED 'No available endpoint'", WorkerId, ActionId); + + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->State = 
UpstreamApplyState::Complete; + Status->Result = {.Error{.ErrorCode = -1, .Reason = "No available endpoint"}}; + } + } + } + catch (std::exception& e) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error{.ErrorCode = -1, .Reason = e.what()}}; + } + Log().warn("process upstream apply ({}/{}) FAILED '{}'", WorkerId, ActionId, e.what()); + } + } + + void ProcessApplyUpdates() + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + GetUpstreamApplyUpdatesResult Result = Endpoint->GetUpdates(m_DownstreamAsyncWorkPool); + m_Stats.Add(*Endpoint, Result); + + if (!Result.Success) + { + Log().warn("process upstream apply updates FAILED '{}'", Result.Error.Reason); + } + + if (!Result.Completed.empty()) + { + for (auto& It : Result.Completed) + { + for (auto& It2 : It.second) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(It.first, It2.first); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = std::move(It2.second); + Status->Result.Timepoints.merge(Status->Timepoints); + Status->Result.Timepoints["zen-queue-complete"] = DateTime::NowTicks(); + Status->Timepoints.clear(); + } + } + } + } + } + } + } + + void ProcessUpstreamUpdates() + { + const auto& UpdateSleep = std::chrono::milliseconds(m_Options.UpdatesInterval); + while (!m_ShutdownEvent.Wait(uint32_t(UpdateSleep.count()))) + { + if (!m_RunState.IsRunning) + { + break; + } + + ProcessApplyUpdates(); + + // Remove any expired tasks, regardless of state + { + std::scoped_lock Lock(m_ApplyTasksMutex); + for (auto& WorkerIt : m_ApplyTasks) + { + const auto Count = std::erase_if(WorkerIt.second, [](const auto& Item) { + return Item.second.ExpireTime < std::chrono::steady_clock::now(); + }); + if (Count > 0) + { + Log().debug("Removed '{}' expired tasks", Count); + } + } + const auto Count 
= std::erase_if(m_ApplyTasks, [](const auto& Item) { return Item.second.empty(); }); + if (Count > 0) + { + Log().debug("Removed '{}' empty task lists", Count); + } + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock Lock(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(Lock, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + for (auto& Endpoint : m_Endpoints) + { + if (!Endpoint->IsHealthy()) + { + if (const UpstreamEndpointHealth Health = Endpoint->CheckHealth(); Health.Ok) + { + Log().warn("health check endpoint '{}' OK", Endpoint->DisplayName(), Health.Reason); + } + else + { + Log().warn("health check endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason); + } + } + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_ShutdownEvent.Set(); + m_EndpointMonitorThread.join(); + m_UpstreamUpdatesThread.join(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::scoped_lock Lock(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamApplyOptions m_Options; + CidStore& m_CidStore; + UpstreamApplyStats m_Stats; + UpstreamApplyTasks m_ApplyTasks; + std::mutex m_ApplyTasksMutex; + std::vector<std::unique_ptr<UpstreamApplyEndpoint>> m_Endpoints; + Event m_ShutdownEvent; + WorkerThreadPool m_UpstreamAsyncWorkPool; + WorkerThreadPool m_DownstreamAsyncWorkPool; + std::thread m_UpstreamUpdatesThread; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +bool +UpstreamApply::IsHealthy() const +{ + return false; +} + +std::unique_ptr<UpstreamApply> 
+UpstreamApply::Create(const UpstreamApplyOptions& Options, CidStore& CidStore) +{ + return std::make_unique<UpstreamApplyImpl>(Options, CidStore); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/upstreamapply.h b/src/zenserver/upstream/upstreamapply.h new file mode 100644 index 000000000..4a095be6c --- /dev/null +++ b/src/zenserver/upstream/upstreamapply.h @@ -0,0 +1,192 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarypackage.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/stats.h> +# include <zencore/zencore.h> + +# include <chrono> +# include <map> +# include <unordered_map> +# include <unordered_set> + +namespace zen { + +class AuthMgr; +class CbObjectWriter; +class CidStore; +class CloudCacheTokenProvider; +class WorkerThreadPool; +class ZenCacheNamespace; +struct CloudCacheClientOptions; +struct UpstreamAuthConfig; + +enum class UpstreamApplyState : int32_t +{ + Queued = 0, + Executing = 1, + Complete = 2, +}; + +enum class UpstreamApplyType +{ + Simple = 0, + Asset = 1, +}; + +struct UpstreamApplyRecord +{ + CbObject WorkerDescriptor; + CbObject Action; + UpstreamApplyType Type; + std::map<std::string, uint64_t> Timepoints{}; +}; + +struct UpstreamApplyOptions +{ + std::chrono::seconds HealthCheckInterval{5}; + std::chrono::seconds UpdatesInterval{5}; + uint32_t UpstreamThreadCount = 4; + uint32_t DownstreamThreadCount = 4; + bool StatsEnabled = false; +}; + +struct UpstreamApplyError +{ + int32_t ErrorCode{}; + std::string Reason{}; + + explicit operator bool() const { return ErrorCode != 0; } +}; + +struct PostUpstreamApplyResult +{ + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + std::map<std::string, uint64_t> Timepoints{}; + bool Success = false; +}; + +struct GetUpstreamApplyResult +{ + // UpstreamApplyType::Simple + 
std::map<std::filesystem::path, IoHash> OutputFiles{}; + std::map<IoHash, IoBuffer> FileData{}; + + // UpstreamApplyType::Asset + CbPackage OutputPackage{}; + int64_t TotalAttachmentBytes{}; + int64_t TotalRawAttachmentBytes{}; + + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + std::string StdOut{}; + std::string StdErr{}; + std::string Agent{}; + std::string Detail{}; + std::map<std::string, uint64_t> Timepoints{}; + bool Success = false; +}; + +using UpstreamApplyCompleted = std::unordered_map<IoHash, std::unordered_map<IoHash, GetUpstreamApplyResult>>; + +struct GetUpstreamApplyUpdatesResult +{ + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + UpstreamApplyCompleted Completed{}; + bool Success = false; +}; + +struct UpstreamApplyStatus +{ + UpstreamApplyState State{}; + GetUpstreamApplyResult Result{}; + std::chrono::steady_clock::time_point ExpireTime{}; + std::map<std::string, uint64_t> Timepoints{}; +}; + +using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>; + +struct UpstreamEndpointHealth +{ + std::string Reason; + bool Ok = false; +}; + +struct UpstreamApplyEndpointStats +{ + metrics::Counter PostCount; + metrics::Counter CompleteCount; + metrics::Counter UpdateCount; + metrics::Counter ErrorCount; + metrics::Counter UpBytes; + metrics::Counter DownBytes; +}; + +/** + * The upstream apply endpoint is responsible for handling remote execution. 
+ */ +class UpstreamApplyEndpoint +{ +public: + virtual ~UpstreamApplyEndpoint() = default; + + virtual UpstreamEndpointHealth Initialize() = 0; + virtual bool IsHealthy() const = 0; + virtual UpstreamEndpointHealth CheckHealth() = 0; + virtual std::string_view DisplayName() const = 0; + virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) = 0; + virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) = 0; + virtual UpstreamApplyEndpointStats& Stats() = 0; + + static std::unique_ptr<UpstreamApplyEndpoint> CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr); +}; + +/** + * Manages one or more upstream compute endpoints. + */ +class UpstreamApply +{ +public: + virtual ~UpstreamApply() = default; + + virtual bool Initialize() = 0; + virtual bool IsHealthy() const = 0; + virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0; + + struct EnqueueResult + { + IoHash ApplyId{}; + bool Success = false; + }; + + struct StatusResult + { + UpstreamApplyStatus Status{}; + bool Success = false; + }; + + virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0; + virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0; + virtual void GetStatus(CbObjectWriter& CbO) = 0; + + static std::unique_ptr<UpstreamApply> Create(const UpstreamApplyOptions& Options, CidStore& CidStore); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp new file mode 100644 index 000000000..e838b5fe2 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.cpp @@ -0,0 +1,2112 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "upstreamcache.h" +#include "jupiter.h" +#include "zen.h" + +#include <zencore/blockingqueue.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/stats.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpshared.h> + +#include <zenstore/cidstore.h> + +#include <auth/authmgr.h> +#include "cache/structuredcache.h" +#include "cache/structuredcachestore.h" +#include "diag/logging.h" + +#include <fmt/format.h> + +#include <algorithm> +#include <atomic> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +namespace zen { + +using namespace std::literals; + +namespace detail { + + class UpstreamStatus + { + public: + UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); } + + UpstreamEndpointStatus EndpointStatus() const + { + const UpstreamEndpointState State = EndpointState(); + { + std::unique_lock _(m_Mutex); + return {.Reason = m_ErrorText, .State = State}; + } + } + + void Set(UpstreamEndpointState NewState) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText.clear(); + } + } + + void Set(UpstreamEndpointState NewState, std::string ErrorText) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText = std::move(ErrorText); + } + } + + void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText) + { + if (ErrorCode != 0) + { + Set(ErrorCode == 401 ? 
UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText)); + } + } + + private: + mutable std::mutex m_Mutex; + std::string m_ErrorText; + std::atomic_uint32_t m_State; + }; + + class JupiterUpstreamEndpoint final : public UpstreamEndpoint + { + public: + JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) + : m_AuthMgr(Mgr) + , m_Log(zen::logging::Get("upstream")) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + m_Info.Url = Options.ServiceUrl; + + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (AuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials( + {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret}); + } + else if (AuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(Options, std::move(TokenProvider)); + } + + virtual ~JupiterUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + CloudCacheSession Session(m_Client); + const CloudCacheResult Result = Session.Authenticate(); + 
+ if (Result.Success) + { + m_Status.Set(UpstreamEndpointState::kOk); + } + else if (Result.ErrorCode != 0) + { + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + } + else + { + m_Status.Set(UpstreamEndpointState::kUnauthorized); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultDdcNamespace(); + } + return Namespace; + } + + std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultBlobStoreNamespace(); + } + return Namespace; + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord"); + + try + { + CloudCacheSession Session(m_Client); + CloudCacheResult Result; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + + if (Type == ZenContentType::kCompressedBinary) + { + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + + if (Result.Success) + { + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + IoBuffer ContentBuffer; + int NumAttachments = 0; + + 
CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize)) + { + Result.Response = AttachmentResult.Response; + ++NumAttachments; + } + else + { + Result.Success = false; + } + }); + if (NumAttachments != 1) + { + Result.Success = false; + } + } + } + } + else + { + const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type; + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType); + + if (Result.Success && Type == ZenContentType::kCbPackage) + { + CbPackage Package; + + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash())); + } + else + { + Result.Success = false; + } + }); + + Package.SetObject(CacheRecord); + } + + if (Result.Success) + { + BinaryWriter MemStream; + Package.Save(MemStream); + + Result.Response = 
IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + } + } + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheKeyRequest* Request : Requests) + { + const CacheKey& CacheKey = Request->Key; + CbPackage Package; + CbObject Record; + + double ElapsedSeconds = 0.0; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + CloudCacheResult RefResult = + Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + AppendResult(RefResult, Result); + ElapsedSeconds = RefResult.ElapsedSeconds; + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (RefResult.ErrorCode == 0) + { + const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All); + if (ValidationResult == CbValidateError::None) + { + Record = LoadCompactBinaryObject(RefResult.Response); + Record.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, 
BlobResult.Reason); + + if (BlobResult.ErrorCode == 0) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize)) + { + if (RawHash == AttachmentHash.AsHash()) + { + Package.AddAttachment(CbAttachment(Chunk, RawHash)); + } + } + } + }); + } + } + } + + OnComplete( + {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info}); + } + + return Result; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey&, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk"); + + try + { + CloudCacheSession Session(m_Client); + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + CacheChunkRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t 
RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult BlobResult = + Request.ChunkId == IoHash::Zero + ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId) + : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload && IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize); + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + CacheValueRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + IoHash PayloadHash; + const CloudCacheResult BlobResult = + Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = 
BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload) + { + if (IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash; + } + else + { + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload)); + RawHash = Compressed.DecodeRawHash(); + if (RawHash == PayloadHash) + { + IsCompressed = true; + } + else + { + ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header", + Namespace, + Request.Key.Bucket, + Request.Key.Hash.ToHexString(), + RawHash.ToHexString(), + PayloadHash.ToHexString()); + } + } + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + CloudCacheSession Session(m_Client); + + if (CacheRecord.Type == ZenContentType::kBinary) + { + CloudCacheResult Result; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace); + Result = Session.PutRef(BlobStoreNamespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + ZenContentType::kBinary); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + return {.Reason = std::move(Result.Reason), + .Bytes 
= Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Success = Result.Success}; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize)) + { + return {.Reason = std::string("Invalid compressed value buffer"), .Success = false}; + } + + CbObjectWriter ReferencingObject; + ReferencingObject.AddBinaryAttachment("RawHash", RawHash); + ReferencingObject.AddInteger("RawSize", RawSize); + + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + ReferencingObject.Save().GetBuffer().AsIoBuffer(), + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + if (ValueContentId != RawHash) + { + OutReason = + fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash); + return false; + } + + OutBuffer = RecordValue; + return true; + }); + } + else + { + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + RecordValue, + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + const auto It = + std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId); + + if (It == std::end(CacheRecord.ValueContentIds)) + { + OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId); + return false; + } + + const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It); + + OutBuffer = Values[Idx]; + return true; + }); + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out) + { + Out.Success &= 
Result.Success; + Out.Bytes += Result.Bytes; + Out.ElapsedSeconds += Result.ElapsedSeconds; + + if (Result.ErrorCode) + { + Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}; + } + }; + + PutUpstreamCacheResult PerformStructuredPut( + CloudCacheSession& Session, + std::string_view Namespace, + const CacheKey& Key, + IoBuffer ObjectBuffer, + const int32_t MaxAttempts, + std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn) + { + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool { + for (const IoHash& ValueContentId : ValueContentIds) + { + IoBuffer BlobBuffer; + if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason)) + { + return false; + } + + CloudCacheResult BlobResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++) + { + BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer); + } + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (!BlobResult.Success) + { + OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason); + return false; + } + + TotalBytes += BlobResult.Bytes; + TotalElapsedSeconds += BlobResult.ElapsedSeconds; + } + + return true; + }; + + PutRefResult RefResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++) + { + RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject); + } + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (!RefResult.Success) + { + return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason), + .Success = false}; + } + + TotalBytes += 
RefResult.Bytes; + TotalElapsedSeconds += RefResult.ElapsedSeconds; + + std::string Reason; + if (!PutBlobs(RefResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer); + FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return { + .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + if (!PutBlobs(FinalizeResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + ExtendableStringBuilder<256> Sb; + for (const IoHash& MissingHash : FinalizeResult.Needs) + { + Sb << MissingHash.ToHexString() << ","; + } + + return { + .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()), + .Success = false}; + } + } + + TotalBytes += FinalizeResult.Bytes; + TotalElapsedSeconds += FinalizeResult.ElapsedSeconds; + + return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true}; + } + + spdlog::logger& Log() { return m_Log; } + + AuthMgr& m_AuthMgr; + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + RefPtr<CloudCacheClient> m_Client; + }; + + class ZenUpstreamEndpoint final : public UpstreamEndpoint + { + struct 
ZenEndpoint + { + std::string Url; + std::string Reason; + double Latency{}; + bool Ok = false; + + bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; } + }; + + public: + ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options) + : m_Log(zen::logging::Get("upstream")) + , m_ConnectTimeout(Options.ConnectTimeout) + , m_Timeout(Options.Timeout) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + + for (const auto& Url : Options.Urls) + { + m_Endpoints.push_back({.Url = Url}); + } + } + + ~ZenUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + const ZenEndpoint& Ep = GetEndpoint(); + + if (m_Info.Url != Ep.Url) + { + ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url); + m_Info.Url = Ep.Url; + } + + if (Ep.Ok) + { + RwLock::ExclusiveLockScope _(m_ClientLock); + m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout}); + m_Status.Set(UpstreamEndpointState::kOk); + } + else + { + m_Status.Set(UpstreamEndpointState::kError, Ep.Reason); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord"); + + try + { + ZenStructuredCacheSession 
Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords"); + ZEN_ASSERT(Requests.size() > 0); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheRecords"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy(); + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy); + + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + for (CacheKeyRequest* Request : Requests) + { + BatchRequest.BeginObject(); + { + const CacheKey& Key = Request->Key; + BatchRequest.BeginObject("Key"sv); + { + BatchRequest << "Bucket"sv << Key.Bucket; + BatchRequest << "Hash"sv << Key.Hash; + } + BatchRequest.EndObject(); + if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy) + { + BatchRequest.SetName("Policy"sv); + Request->Policy.Save(BatchRequest); + } + } + BatchRequest.EndObject(); + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = 
Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (Results.Num() != Requests.size()) + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream."); + } + else + { + for (size_t Index = 0; CbFieldView Record : Results) + { + CacheKeyRequest* Request = Requests[Index++]; + OnComplete({.Request = *Request, + .Record = Record.AsObjectView(), + .Package = BatchResponse, + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream."); + } + } + + for (CacheKeyRequest* Request : Requests) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + 
m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues"); + ZEN_ASSERT(!CacheValueRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheValues"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + { + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + const CacheValueRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheValueRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView 
ChunkField : Results) + { + CacheValueRequest& Request = *CacheValueRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream."); + } + } + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks"); + ZEN_ASSERT(!CacheChunkRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheChunks"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy 
DefaultPolicy = CacheChunkRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("ChunkRequests"sv); + { + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + const CacheChunkRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.ValueId) + { + BatchRequest.AddObjectId("ValueId"sv, Request.ValueId); + } + if (Request.ChunkId != Request.ChunkId.Zero) + { + BatchRequest << "ChunkId"sv << Request.ChunkId; + } + if (Request.RawOffset != 0) + { + BatchRequest << "RawOffset"sv << Request.RawOffset; + } + if (Request.RawSize != UINT64_MAX) + { + BatchRequest << "RawSize"sv << Request.RawSize; + } + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheChunkRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + 
uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream."); + } + } + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + ZenCacheResult Result; + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + if (CacheRecord.Type == ZenContentType::kCbPackage) + { + CbPackage Package; + Package.SetObject(CbObject(SharedBuffer(RecordValue))); + + for (const IoBuffer& Value : Values) + { + IoHash RawHash; + uint64_t RawSize; + if 
(CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash)); + } + else + { + return {.Reason = std::string("Invalid value buffer"), .Success = false}; + } + } + + BinaryWriter MemStream; + Package.Save(MemStream); + IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size()); + + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + PackagePayload, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes = Result.Bytes; + TotalElapsedSeconds = Result.ElapsedSeconds; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize); + if (!Compressed) + { + return {.Reason = std::string("Invalid value compressed buffer"), .Success = false}; + } + + CbPackage BatchPackage; + CbObjectWriter BatchWriter; + BatchWriter << "Method"sv + << "PutCacheValues"sv; + BatchWriter << "Accept"sv << kCbPkgMagic; + + BatchWriter.BeginObject("Params"sv); + { + // DefaultPolicy unspecified and expected to be Default + + BatchWriter << "Namespace"sv << CacheRecord.Namespace; + + BatchWriter.BeginArray("Requests"sv); + { + BatchWriter.BeginObject(); + { + const CacheKey& Key = CacheRecord.Key; + BatchWriter.BeginObject("Key"sv); + { + BatchWriter << "Bucket"sv << Key.Bucket; + BatchWriter << "Hash"sv << Key.Hash; + } + BatchWriter.EndObject(); + // Policy unspecified and expected to be Default + BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash); + BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash)); + } + BatchWriter.EndObject(); + } + BatchWriter.EndArray(); + } + BatchWriter.EndObject(); + 
BatchPackage.SetObject(BatchWriter.Save()); + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.InvokeRpc(BatchPackage); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + else + { + for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++) + { + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheValue(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + CacheRecord.ValueContentIds[Idx], + Values[Idx]); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + + if (!Result.Success) + { + return {.Reason = "Failed to upload value", + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = false}; + } + } + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + + return {.Reason = std::move(Result.Reason), + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = Result.Success}; + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + Ref<ZenStructuredCacheClient> GetClientRef() + { + // m_Client can be modified at any time by a different thread. 
+ // Make sure we safely bump the refcount inside a scope lock + RwLock::SharedLockScope _(m_ClientLock); + ZEN_ASSERT(m_Client); + Ref<ZenStructuredCacheClient> ClientRef(m_Client); + _.ReleaseNow(); + return ClientRef; + } + + const ZenEndpoint& GetEndpoint() + { + for (ZenEndpoint& Ep : m_Endpoints) + { + Ref<ZenStructuredCacheClient> Client( + new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)})); + ZenStructuredCacheSession Session(std::move(Client)); + const int32_t SampleCount = 2; + + Ep.Ok = false; + Ep.Latency = {}; + + for (int32_t Sample = 0; Sample < SampleCount; ++Sample) + { + ZenCacheResult Result = Session.CheckHealth(); + Ep.Ok = Result.Success; + Ep.Reason = std::move(Result.Reason); + Ep.Latency += Result.ElapsedSeconds; + } + Ep.Latency /= double(SampleCount); + } + + std::sort(std::begin(m_Endpoints), std::end(m_Endpoints)); + + for (const auto& Ep : m_Endpoints) + { + ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? 
"OK" : Ep.Reason); + } + + return m_Endpoints.front(); + } + + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + std::vector<ZenEndpoint> m_Endpoints; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + RwLock m_ClientLock; + RefPtr<ZenStructuredCacheClient> m_Client; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +class UpstreamCacheImpl final : public UpstreamCache +{ +public: + UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) + : m_Log(logging::Get("upstream")) + , m_Options(Options) + , m_CacheStore(CacheStore) + , m_CidStore(CidStore) + { + } + + virtual ~UpstreamCacheImpl() { Shutdown(); } + + virtual void Initialize() override + { + for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++) + { + m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this); + } + + m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this); + m_RunState.IsRunning = true; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override + { + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + else + { + ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + + // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s) + std::unique_lock<std::shared_mutex> _(m_EndpointsMutex); + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override + { + std::shared_lock<std::shared_mutex> 
_(m_EndpointsMutex); + + for (auto& Ep : m_Endpoints) + { + if (!Fn(*Ep)) + { + break; + } + } + } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::GetCacheRecord"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheRecords"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheKeyRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + + Result = 
Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheKeyRequest* Request : RemainingKeys) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + } + + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheChunks"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheChunkRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheChunkGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + 
Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheChunkRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::GetCacheChunk"); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheValues"); + + std::shared_lock<std::shared_mutex> 
_(m_EndpointsMutex); + + std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheValueRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheValueGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheValueRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override + { + if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0) + { + if (!m_UpstreamThreads.empty()) + { + m_UpstreamQueue.Enqueue(std::move(CacheRecord)); + } + else + { + ProcessCacheRecord(std::move(CacheRecord)); + } + } + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "reading" << m_Options.ReadUpstream; + Status << "writing" << m_Options.WriteUpstream; + Status << 
"worker_threads" << m_Options.ThreadCount; + Status << "queue_count" << m_UpstreamQueue.Size(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo(); + const UpstreamEndpointStatus EpStatus = Ep->GetStatus(); + UpstreamEndpointStats& EpStats = Ep->Stats(); + + Status.BeginObject(); + Status << "name" << EpInfo.Name; + Status << "url" << EpInfo.Url; + Status << "state" << ToString(EpStatus.State); + Status << "reason" << EpStatus.Reason; + + Status.BeginObject("cache"sv); + { + const int64_t GetCount = EpStats.CacheGetCount.Value(); + const int64_t HitCount = EpStats.CacheHitCount.Value(); + const int64_t ErrorCount = EpStats.CacheErrorCount.Value(); + const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0; + const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0; + + metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status); + Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value(); + Status << "get_count" << GetCount; + Status << "hit_count" << HitCount; + Status << "hit_ratio" << HitRatio; + Status << "error_count" << ErrorCount; + Status << "error_ratio" << ErrorRatio; + metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status); + Status << "put_bytes" << EpStats.CachePutTotalBytes.Value(); + } + Status.EndObject(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + void ProcessCacheRecord(UpstreamCacheRecord CacheRecord) + { + ZEN_TRACE_CPU("Upstream::ProcessCacheRecord"); + + ZenCacheValue CacheValue; + std::vector<IoBuffer> Payloads; + + if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue)) + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash); + return; + } + + for (const IoHash& ValueContentId : 
CacheRecord.ValueContentIds) + { + if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId)) + { + Payloads.push_back(Payload); + } + else + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + ValueContentId); + return; + } + } + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + PutUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming); + Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads)); + } + + Stats.CachePutTotalBytes.Increment(Result.Bytes); + + if (!Result.Success) + { + ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Endpoint->GetEndpointInfo().Url, + Result.Reason); + } + } + } + + void ProcessUpstreamQueue() + { + for (;;) + { + UpstreamCacheRecord CacheRecord; + if (m_UpstreamQueue.WaitAndDequeue(CacheRecord)) + { + try + { + ProcessCacheRecord(std::move(CacheRecord)); + } + catch (std::exception& Err) + { + ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Err.what()); + } + } + + if (!m_RunState.IsRunning) + { + break; + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock lk(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + try + { + std::vector<UpstreamEndpoint*> Endpoints; + + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + UpstreamEndpointState State = 
Endpoint->GetState(); + if (State == UpstreamEndpointState::kError) + { + Endpoints.push_back(Endpoint.get()); + ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'", + Endpoint->GetEndpointInfo().Name, + Endpoint->GetEndpointInfo().Url, + Endpoint->GetStatus().Reason); + } + if (State == UpstreamEndpointState::kUnauthorized) + { + Endpoints.push_back(Endpoint.get()); + } + } + } + + for (auto& Endpoint : Endpoints) + { + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url); + } + else + { + const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason); + ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason); + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what()); + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_UpstreamQueue.CompleteAdding(); + for (std::thread& Thread : m_UpstreamThreads) + { + Thread.join(); + } + + m_EndpointMonitorThread.join(); + m_UpstreamThreads.clear(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>; + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::lock_guard _(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamCacheOptions m_Options; + ZenCacheStore& m_CacheStore; + CidStore& m_CidStore; + UpstreamQueue m_UpstreamQueue; + std::shared_mutex m_EndpointsMutex; + std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints; + std::vector<std::thread> 
m_UpstreamThreads; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options) +{ + return std::make_unique<detail::ZenUpstreamEndpoint>(Options); +} + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) +{ + return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr); +} + +std::unique_ptr<UpstreamCache> +UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) +{ + return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore); +} + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamcache.h b/src/zenserver/upstream/upstreamcache.h new file mode 100644 index 000000000..695c06b32 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.h @@ -0,0 +1,252 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/stats.h> +#include <zencore/zencore.h> +#include <zenutil/cache/cache.h> + +#include <atomic> +#include <chrono> +#include <functional> +#include <memory> +#include <vector> + +namespace zen { + +class CbObjectView; +class AuthMgr; +class CbObjectView; +class CbPackage; +class CbObjectWriter; +class CidStore; +class ZenCacheStore; +struct CloudCacheClientOptions; +class CloudCacheTokenProvider; +struct ZenStructuredCacheClientOptions; + +struct UpstreamCacheRecord +{ + ZenContentType Type = ZenContentType::kBinary; + std::string Namespace; + CacheKey Key; + std::vector<IoHash> ValueContentIds; +}; + +struct UpstreamCacheOptions +{ + std::chrono::seconds HealthCheckInterval{5}; + uint32_t ThreadCount = 4; + bool ReadUpstream = true; + bool WriteUpstream = true; +}; + +struct UpstreamError +{ + int32_t ErrorCode{}; + std::string Reason{}; + + explicit operator bool() const { return ErrorCode != 0; } +}; + +struct UpstreamEndpointInfo +{ + std::string Name; + std::string Url; +}; + +struct GetUpstreamCacheResult +{ + UpstreamError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + bool Success = false; +}; + +struct GetUpstreamCacheSingleResult +{ + GetUpstreamCacheResult Status; + IoBuffer Value; + const UpstreamEndpointInfo* Source = nullptr; +}; + +struct PutUpstreamCacheResult +{ + std::string Reason; + int64_t Bytes{}; + double ElapsedSeconds{}; + bool Success = false; +}; + +struct CacheRecordGetCompleteParams +{ + CacheKeyRequest& Request; + const CbObjectView& Record; + const CbPackage& Package; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>; + +struct CacheValueGetCompleteParams +{ + CacheValueRequest& Request; + IoHash RawHash; + uint64_t RawSize; + IoBuffer Value; + double 
/** Health state of an upstream endpoint. */
enum class UpstreamEndpointState : uint32_t
{
    kDisabled,
    kUnauthorized,
    kError,
    kOk
};

/**
 * Returns the display name of an endpoint state.
 *
 * @param State  State to describe.
 * @return "Disabled", "Unauthorized", "Error" or "Ok"; "Unknown" for any value
 *         outside the enumeration.
 */
inline std::string_view
ToString(UpstreamEndpointState State)
{
    using namespace std::literals;

    switch (State)
    {
        case UpstreamEndpointState::kDisabled:
            return "Disabled"sv;
        case UpstreamEndpointState::kUnauthorized:
            return "Unauthorized"sv;
        case UpstreamEndpointState::kError:
            return "Error"sv;
        case UpstreamEndpointState::kOk:
            return "Ok"sv;
        default:
            return "Unknown"sv;
    }
}

/**
 * Authentication settings passed to UpstreamEndpoint::CreateJupiterEndpoint.
 *
 * All members are non-owning views; the strings they refer to must outlive any use
 * of this struct.
 */
struct UpstreamAuthConfig
{
    std::string_view OAuthUrl;
    std::string_view OAuthClientId;
    std::string_view OAuthClientSecret;
    std::string_view OpenIdProvider;
    std::string_view AccessToken;
};

/** Snapshot of an endpoint's health: its state plus a human-readable reason. */
struct UpstreamEndpointStatus
{
    std::string Reason;
    // Explicit default so a default-initialized status never carries an indeterminate
    // state; kDisabled is the value that value-initialization already produced.
    UpstreamEndpointState State = UpstreamEndpointState::kDisabled;
};
+ */ +class UpstreamEndpoint +{ +public: + virtual ~UpstreamEndpoint() = default; + + virtual UpstreamEndpointStatus Initialize() = 0; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0; + + virtual UpstreamEndpointState GetState() = 0; + virtual UpstreamEndpointStatus GetStatus() = 0; + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0; + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, const CacheKey& CacheKey, const IoHash& PayloadId) = 0; + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) = 0; + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Payloads) = 0; + + virtual UpstreamEndpointStats& Stats() = 0; + + static std::unique_ptr<UpstreamEndpoint> CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options); + + static std::unique_ptr<UpstreamEndpoint> CreateJupiterEndpoint(const CloudCacheClientOptions& Options, + const UpstreamAuthConfig& AuthConfig, + AuthMgr& Mgr); +}; + +/** + * Manages one or more upstream cache endpoints. 
+ */ +class UpstreamCache +{ +public: + virtual ~UpstreamCache() = default; + + virtual void Initialize() = 0; + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0; + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0; + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) = 0; + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) = 0; + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0; + + virtual void GetStatus(CbObjectWriter& CbO) = 0; + + static std::unique_ptr<UpstreamCache> Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore); +}; + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamservice.cpp b/src/zenserver/upstream/upstreamservice.cpp new file mode 100644 index 000000000..6db1357c5 --- /dev/null +++ b/src/zenserver/upstream/upstreamservice.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+#include <upstream/upstreamservice.h> + +#include <auth/authmgr.h> +#include <upstream/upstreamcache.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> + +namespace zen { + +using namespace std::literals; + +HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr) +{ + m_Router.RegisterRoute( + "endpoints", + [this](HttpRouterRequest& Req) { + CbObjectWriter Writer; + Writer.BeginArray("Endpoints"sv); + m_Upstream.IterateEndpoints([&Writer](UpstreamEndpoint& Ep) { + UpstreamEndpointInfo Info = Ep.GetEndpointInfo(); + UpstreamEndpointStatus Status = Ep.GetStatus(); + + Writer.BeginObject(); + Writer << "Name"sv << Info.Name; + Writer << "Url"sv << Info.Url; + Writer << "State"sv << ToString(Status.State); + Writer << "Reason"sv << Status.Reason; + Writer.EndObject(); + + return true; + }); + Writer.EndArray(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet); +} + +HttpUpstreamService::~HttpUpstreamService() +{ +} + +const char* +HttpUpstreamService::BaseUri() const +{ + return "/upstream/"; +} + +void +HttpUpstreamService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamservice.h b/src/zenserver/upstream/upstreamservice.h new file mode 100644 index 000000000..f1da03c8c --- /dev/null +++ b/src/zenserver/upstream/upstreamservice.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +class AuthMgr; +class UpstreamCache; + +class HttpUpstreamService final : public zen::HttpService +{ +public: + HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr); + virtual ~HttpUpstreamService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + UpstreamCache& m_Upstream; + AuthMgr& m_AuthMgr; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/zen.cpp b/src/zenserver/upstream/zen.cpp new file mode 100644 index 000000000..9e1212834 --- /dev/null +++ b/src/zenserver/upstream/zen.cpp @@ -0,0 +1,326 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zen.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/httpcommon.h> +#include <zenhttp/httpshared.h> + +#include "cache/structuredcachestore.h" +#include "diag/formatters.h" +#include "diag/logging.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <xxhash.h> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +namespace detail { + struct ZenCacheSessionState + { + ZenCacheSessionState(ZenStructuredCacheClient& Client) : OwnerClient(Client) {} + ~ZenCacheSessionState() {} + + void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout) + { + Session.SetBody({}); + Session.SetHeader({}); + Session.SetConnectTimeout(ConnectTimeout); + Session.SetTimeout(Timeout); + } + + cpr::Session& GetSession() { return Session; } + + private: + ZenStructuredCacheClient& OwnerClient; + cpr::Session Session; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + 
+ZenStructuredCacheClient::ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options) +: m_Log(logging::Get(std::string_view("zenclient"))) +, m_ServiceUrl(Options.Url) +, m_ConnectTimeout(Options.ConnectTimeout) +, m_Timeout(Options.Timeout) +{ +} + +ZenStructuredCacheClient::~ZenStructuredCacheClient() +{ +} + +detail::ZenCacheSessionState* +ZenStructuredCacheClient::AllocSessionState() +{ + detail::ZenCacheSessionState* State = nullptr; + + if (RwLock::ExclusiveLockScope _(m_SessionStateLock); !m_SessionStateCache.empty()) + { + State = m_SessionStateCache.front(); + m_SessionStateCache.pop_front(); + } + + if (State == nullptr) + { + State = new detail::ZenCacheSessionState(*this); + } + + State->Reset(m_ConnectTimeout, m_Timeout); + + return State; +} + +void +ZenStructuredCacheClient::FreeSessionState(detail::ZenCacheSessionState* State) +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + m_SessionStateCache.push_front(State); +} + +////////////////////////////////////////////////////////////////////////// + +using namespace std::literals; + +ZenStructuredCacheSession::ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient) +: m_Log(OuterClient->Log()) +, m_Client(std::move(OuterClient)) +{ + m_SessionState = m_Client->AllocSessionState(); +} + +ZenStructuredCacheSession::~ZenStructuredCacheSession() +{ + m_Client->FreeSessionState(m_SessionState); +} + +ZenCacheResult +ZenStructuredCacheSession::CheckHealth() +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/health/check"; + + cpr::Session& Session = m_SessionState->GetSession(); + Session.SetOption(cpr::Url{Uri.c_str()}); + cpr::Response Response = Session.Get(); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + return {.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +ZenCacheResult 
+ZenStructuredCacheSession::GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Accept", std::string{MapContentTypeToString(Type)}}}); + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::GetCacheChunk(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Accept", "application/x-ue-comp"}}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, + .Bytes = Response.downloaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::PutCacheRecord(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + IoBuffer Value, + ZenContentType Type) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", + Type == ZenContentType::kCbPackage ? "application/x-ue-cbpkg" + : Type == ZenContentType::kCbObject ? "application/x-ue-cb" + : "application/octet-stream"}}); + Session.SetBody(cpr::Body{static_cast<const char*>(Value.Data()), Value.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200 || Response.status_code == 201; + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::PutCacheValue(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId, + IoBuffer Payload) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + 
Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-comp"}}); + Session.SetBody(cpr::Body{static_cast<const char*>(Payload.Data()), Payload.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200 || Response.status_code == 201; + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/$rpc"; + + BinaryWriter Body; + Request.CopyTo(Body); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? 
IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), + .Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbPackage& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/$rpc"; + + SharedBuffer Message = FormatPackageMessageBuffer(Request).Flatten(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Message.GetData()), Message.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), + .Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +} // namespace zen diff --git a/src/zenserver/upstream/zen.h b/src/zenserver/upstream/zen.h new file mode 100644 index 000000000..bfba8fa98 --- /dev/null +++ b/src/zenserver/upstream/zen.h @@ -0,0 +1,125 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/thread.h> +#include <zencore/uid.h> +#include <zencore/zencore.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <chrono> +#include <list> + +struct ZenCacheValue; + +namespace spdlog { +class logger; +} + +namespace zen { + +class CbObjectWriter; +class CbObjectView; +class CbPackage; +class ZenStructuredCacheClient; + +////////////////////////////////////////////////////////////////////////// + +namespace detail { + struct ZenCacheSessionState; +} + +struct ZenCacheResult +{ + IoBuffer Response; + int64_t Bytes = {}; + double ElapsedSeconds = {}; + int32_t ErrorCode = {}; + std::string Reason; + bool Success = false; +}; + +struct ZenStructuredCacheClientOptions +{ + std::string_view Name; + std::string_view Url; + std::span<std::string const> Urls; + std::chrono::milliseconds ConnectTimeout{}; + std::chrono::milliseconds Timeout{}; +}; + +/** Zen Structured Cache session + * + * This provides a context in which cache queries can be performed + * + * These are currently all synchronous. 
Will need to be made asynchronous + */ +class ZenStructuredCacheSession +{ +public: + ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient); + ~ZenStructuredCacheSession(); + + ZenCacheResult CheckHealth(); + ZenCacheResult GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type); + ZenCacheResult GetCacheChunk(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId); + ZenCacheResult PutCacheRecord(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + IoBuffer Value, + ZenContentType Type); + ZenCacheResult PutCacheValue(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId, + IoBuffer Payload); + ZenCacheResult InvokeRpc(const CbObjectView& Request); + ZenCacheResult InvokeRpc(const CbPackage& Package); + +private: + inline spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + Ref<ZenStructuredCacheClient> m_Client; + detail::ZenCacheSessionState* m_SessionState; +}; + +/** Zen Structured Cache client + * + * This represents an endpoint to query -- actual queries should be done via + * ZenStructuredCacheSession + */ +class ZenStructuredCacheClient : public RefCounted +{ +public: + ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options); + ~ZenStructuredCacheClient(); + + std::string_view ServiceUrl() const { return m_ServiceUrl; } + + inline spdlog::logger& Log() { return m_Log; } + +private: + spdlog::logger& m_Log; + std::string m_ServiceUrl; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + + RwLock m_SessionStateLock; + std::list<detail::ZenCacheSessionState*> m_SessionStateCache; + + detail::ZenCacheSessionState* AllocSessionState(); + void FreeSessionState(detail::ZenCacheSessionState*); + + friend class ZenStructuredCacheSession; +}; + +} // namespace zen diff --git 
a/src/zenserver/windows/service.cpp b/src/zenserver/windows/service.cpp new file mode 100644 index 000000000..89bacab0b --- /dev/null +++ b/src/zenserver/windows/service.cpp @@ -0,0 +1,646 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "service.h" + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS + +# include <zencore/except.h> +# include <zencore/zencore.h> + +# include <stdio.h> +# include <tchar.h> +# include <zencore/windows.h> + +# define SVCNAME L"Zen Store" + +SERVICE_STATUS gSvcStatus; +SERVICE_STATUS_HANDLE gSvcStatusHandle; +HANDLE ghSvcStopEvent = NULL; + +void SvcInstall(void); + +void ReportSvcStatus(DWORD, DWORD, DWORD); +void SvcReportEvent(LPTSTR); + +WindowsService::WindowsService() +{ +} + +WindowsService::~WindowsService() +{ +} + +// +// Purpose: +// Installs a service in the SCM database +// +// Parameters: +// None +// +// Return value: +// None +// +VOID +WindowsService::Install() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + TCHAR szPath[MAX_PATH]; + + if (!GetModuleFileName(NULL, szPath, MAX_PATH)) + { + printf("Cannot install service (%d)\n", GetLastError()); + return; + } + + // Get a handle to the SCM database. 
+ + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Create the service + + schService = CreateService(schSCManager, // SCM database + SVCNAME, // name of service + SVCNAME, // service name to display + SERVICE_ALL_ACCESS, // desired access + SERVICE_WIN32_OWN_PROCESS, // service type + SERVICE_DEMAND_START, // start type + SERVICE_ERROR_NORMAL, // error control type + szPath, // path to service's binary + NULL, // no load ordering group + NULL, // no tag identifier + NULL, // no dependencies + NULL, // LocalSystem account + NULL); // no password + + if (schService == NULL) + { + printf("CreateService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + else + printf("Service installed successfully\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +void +WindowsService::Delete() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + DELETE); // need delete access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Delete the service. 
+ + if (!DeleteService(schService)) + { + printf("DeleteService failed (%d)\n", GetLastError()); + } + else + printf("Service deleted successfully\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +WindowsService* gSvc; + +void WINAPI +CallMain(DWORD, LPSTR*) +{ + gSvc->SvcMain(); +} + +int +WindowsService::ServiceMain() +{ + gSvc = this; + + SERVICE_TABLE_ENTRY DispatchTable[] = {{(LPWSTR)SVCNAME, (LPSERVICE_MAIN_FUNCTION)&CallMain}, {NULL, NULL}}; + + // This call returns when the service has stopped. + // The process should simply terminate when the call returns. + + if (!StartServiceCtrlDispatcher(DispatchTable)) + { + const DWORD dwError = zen::GetLastError(); + + if (dwError == ERROR_FAILED_SERVICE_CONTROLLER_CONNECT) + { + // Not actually running as a service + gSvc = nullptr; + + zen::SetIsInteractiveSession(true); + + return Run(); + } + else + { + zen::ThrowSystemError(dwError, "StartServiceCtrlDispatcher failed"); + } + } + + zen::SetIsInteractiveSession(false); + + return 0; +} + +int +WindowsService::SvcMain() +{ + // Register the handler function for the service + + gSvcStatusHandle = RegisterServiceCtrlHandler(SVCNAME, SvcCtrlHandler); + + if (!gSvcStatusHandle) + { + SvcReportEvent((LPTSTR)TEXT("RegisterServiceCtrlHandler")); + + return 1; + } + + // These SERVICE_STATUS members remain as set here + + gSvcStatus.dwServiceType = SERVICE_WIN32_OWN_PROCESS; + gSvcStatus.dwServiceSpecificExitCode = 0; + + // Report initial status to the SCM + + ReportSvcStatus(SERVICE_START_PENDING, NO_ERROR, 3000); + + // Create an event. The control handler function, SvcCtrlHandler, + // signals this event when it receives the stop control code. 
+ + ghSvcStopEvent = CreateEvent(NULL, // default security attributes + TRUE, // manual reset event + FALSE, // not signaled + NULL); // no name + + if (ghSvcStopEvent == NULL) + { + ReportSvcStatus(SERVICE_STOPPED, GetLastError(), 0); + + return 1; + } + + // Report running status when initialization is complete. + + ReportSvcStatus(SERVICE_RUNNING, NO_ERROR, 0); + + int ReturnCode = Run(); + + ReportSvcStatus(SERVICE_STOPPED, NO_ERROR, 0); + + return ReturnCode; +} + +// +// Purpose: +// Retrieves and displays the current service configuration. +// +// Parameters: +// None +// +// Return value: +// None +// +void +DoQuerySvc() +{ + SC_HANDLE schSCManager{}; + SC_HANDLE schService{}; + LPQUERY_SERVICE_CONFIG lpsc{}; + LPSERVICE_DESCRIPTION lpsd{}; + DWORD dwBytesNeeded{}, cbBufSize{}, dwError{}; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_QUERY_CONFIG); // need query config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Get the configuration information. 
+ + if (!QueryServiceConfig(schService, NULL, 0, &dwBytesNeeded)) + { + dwError = GetLastError(); + if (ERROR_INSUFFICIENT_BUFFER == dwError) + { + cbBufSize = dwBytesNeeded; + lpsc = (LPQUERY_SERVICE_CONFIG)LocalAlloc(LMEM_FIXED, cbBufSize); + } + else + { + printf("QueryServiceConfig failed (%d)", dwError); + goto cleanup; + } + } + + if (!QueryServiceConfig(schService, lpsc, cbBufSize, &dwBytesNeeded)) + { + printf("QueryServiceConfig failed (%d)", GetLastError()); + goto cleanup; + } + + if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, NULL, 0, &dwBytesNeeded)) + { + dwError = GetLastError(); + if (ERROR_INSUFFICIENT_BUFFER == dwError) + { + cbBufSize = dwBytesNeeded; + lpsd = (LPSERVICE_DESCRIPTION)LocalAlloc(LMEM_FIXED, cbBufSize); + } + else + { + printf("QueryServiceConfig2 failed (%d)", dwError); + goto cleanup; + } + } + + if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, (LPBYTE)lpsd, cbBufSize, &dwBytesNeeded)) + { + printf("QueryServiceConfig2 failed (%d)", GetLastError()); + goto cleanup; + } + + // Print the configuration information. 
+ + _tprintf(TEXT("%s configuration: \n"), SVCNAME); + _tprintf(TEXT(" Type: 0x%x\n"), lpsc->dwServiceType); + _tprintf(TEXT(" Start Type: 0x%x\n"), lpsc->dwStartType); + _tprintf(TEXT(" Error Control: 0x%x\n"), lpsc->dwErrorControl); + _tprintf(TEXT(" Binary path: %s\n"), lpsc->lpBinaryPathName); + _tprintf(TEXT(" Account: %s\n"), lpsc->lpServiceStartName); + + if (lpsd->lpDescription != NULL && lstrcmp(lpsd->lpDescription, TEXT("")) != 0) + _tprintf(TEXT(" Description: %s\n"), lpsd->lpDescription); + if (lpsc->lpLoadOrderGroup != NULL && lstrcmp(lpsc->lpLoadOrderGroup, TEXT("")) != 0) + _tprintf(TEXT(" Load order group: %s\n"), lpsc->lpLoadOrderGroup); + if (lpsc->dwTagId != 0) + _tprintf(TEXT(" Tag ID: %d\n"), lpsc->dwTagId); + if (lpsc->lpDependencies != NULL && lstrcmp(lpsc->lpDependencies, TEXT("")) != 0) + _tprintf(TEXT(" Dependencies: %s\n"), lpsc->lpDependencies); + + LocalFree(lpsc); + LocalFree(lpsd); + +cleanup: + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +// +// Purpose: +// Disables the service. +// +// Parameters: +// None +// +// Return value: +// None +// +void +DoDisableSvc() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_CHANGE_CONFIG); // need change config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Change the service start type. 
+ + if (!ChangeServiceConfig(schService, // handle of service + SERVICE_NO_CHANGE, // service type: no change + SERVICE_DISABLED, // service start type + SERVICE_NO_CHANGE, // error control: no change + NULL, // binary path: no change + NULL, // load order group: no change + NULL, // tag ID: no change + NULL, // dependencies: no change + NULL, // account name: no change + NULL, // password: no change + NULL)) // display name: no change + { + printf("ChangeServiceConfig failed (%d)\n", GetLastError()); + } + else + printf("Service disabled successfully.\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +// +// Purpose: +// Enables the service. +// +// Parameters: +// None +// +// Return value: +// None +// +VOID __stdcall DoEnableSvc() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_CHANGE_CONFIG); // need change config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Change the service start type. 
+ + if (!ChangeServiceConfig(schService, // handle of service + SERVICE_NO_CHANGE, // service type: no change + SERVICE_DEMAND_START, // service start type + SERVICE_NO_CHANGE, // error control: no change + NULL, // binary path: no change + NULL, // load order group: no change + NULL, // tag ID: no change + NULL, // dependencies: no change + NULL, // account name: no change + NULL, // password: no change + NULL)) // display name: no change + { + printf("ChangeServiceConfig failed (%d)\n", GetLastError()); + } + else + printf("Service enabled successfully.\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} +// +// Purpose: +// Updates the service description to "This is a test description". +// +// Parameters: +// None +// +// Return value: +// None +// +void +DoUpdateSvcDesc() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + SERVICE_DESCRIPTION sd; + TCHAR szDesc[] = TEXT("This is a test description"); + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_CHANGE_CONFIG); // need change config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Change the service description. 
+ + sd.lpDescription = szDesc; + + if (!ChangeServiceConfig2(schService, // handle to service + SERVICE_CONFIG_DESCRIPTION, // change: description + &sd)) // new description + { + printf("ChangeServiceConfig2 failed\n"); + } + else + printf("Service description updated successfully.\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +// +// Purpose: +// Sets the current service status and reports it to the SCM. +// +// Parameters: +// dwCurrentState - The current state (see SERVICE_STATUS) +// dwWin32ExitCode - The system error code +// dwWaitHint - Estimated time for pending operation, +// in milliseconds +// +// Return value: +// None +// +VOID +ReportSvcStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode, DWORD dwWaitHint) +{ + static DWORD dwCheckPoint = 1; + + // Fill in the SERVICE_STATUS structure. + + gSvcStatus.dwCurrentState = dwCurrentState; + gSvcStatus.dwWin32ExitCode = dwWin32ExitCode; + gSvcStatus.dwWaitHint = dwWaitHint; + + if (dwCurrentState == SERVICE_START_PENDING) + gSvcStatus.dwControlsAccepted = 0; + else + gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP; + + if ((dwCurrentState == SERVICE_RUNNING) || (dwCurrentState == SERVICE_STOPPED)) + gSvcStatus.dwCheckPoint = 0; + else + gSvcStatus.dwCheckPoint = dwCheckPoint++; + + // Report the status of the service to the SCM. + SetServiceStatus(gSvcStatusHandle, &gSvcStatus); +} + +void +WindowsService::SvcCtrlHandler(DWORD dwCtrl) +{ + // Handle the requested control code. + // + // Called by SCM whenever a control code is sent to the service + // using the ControlService function. + + switch (dwCtrl) + { + case SERVICE_CONTROL_STOP: + ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 0); + + // Signal the service to stop. 
+ + SetEvent(ghSvcStopEvent); + zen::RequestApplicationExit(0); + + ReportSvcStatus(gSvcStatus.dwCurrentState, NO_ERROR, 0); + return; + + case SERVICE_CONTROL_INTERROGATE: + break; + + default: + break; + } +} + +// +// Purpose: +// Logs messages to the event log +// +// Parameters: +// szFunction - name of function that failed +// +// Return value: +// None +// +// Remarks: +// The service must have an entry in the Application event log. +// +VOID +SvcReportEvent(LPTSTR szFunction) +{ + ZEN_UNUSED(szFunction); + + // HANDLE hEventSource; + // LPCTSTR lpszStrings[2]; + // TCHAR Buffer[80]; + + // hEventSource = RegisterEventSource(NULL, SVCNAME); + + // if (NULL != hEventSource) + //{ + // StringCchPrintf(Buffer, 80, TEXT("%s failed with %d"), szFunction, GetLastError()); + + // lpszStrings[0] = SVCNAME; + // lpszStrings[1] = Buffer; + + // ReportEvent(hEventSource, // event log handle + // EVENTLOG_ERROR_TYPE, // event type + // 0, // event category + // SVC_ERROR, // event identifier + // NULL, // no security identifier + // 2, // size of lpszStrings array + // 0, // no binary data + // lpszStrings, // array of strings + // NULL); // no binary data + + // DeregisterEventSource(hEventSource); + //} +} + +#endif // ZEN_PLATFORM_WINDOWS diff --git a/src/zenserver/windows/service.h b/src/zenserver/windows/service.h new file mode 100644 index 000000000..7c9610983 --- /dev/null +++ b/src/zenserver/windows/service.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +class WindowsService +{ +public: + WindowsService(); + ~WindowsService(); + + virtual int Run() = 0; + + int ServiceMain(); + + static void Install(); + static void Delete(); + + int SvcMain(); + static void __stdcall SvcCtrlHandler(unsigned long); +}; diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua new file mode 100644 index 000000000..23bfb9535 --- /dev/null +++ b/src/zenserver/xmake.lua @@ -0,0 +1,60 @@ +-- Copyright Epic Games, Inc. 
All Rights Reserved. + +target("zenserver") + set_kind("binary") + add_deps("zencore", "zenhttp", "zenstore", "zenutil") + add_headerfiles("**.h") + add_files("**.cpp") + add_files("zenserver.cpp", {unity_ignored = true }) + add_includedirs(".") + set_symbols("debug") + + if is_mode("release") then + set_optimize("fastest") + end + + if is_plat("windows") then + add_ldflags("/subsystem:console,5.02") + add_ldflags("/MANIFEST:EMBED") + add_ldflags("/LTCG") + add_files("zenserver.rc") + add_cxxflags("/bigobj") + else + remove_files("windows/**") + end + + if is_plat("macosx") then + add_ldflags("-framework CoreFoundation") + add_ldflags("-framework CoreGraphics") + add_ldflags("-framework CoreText") + add_ldflags("-framework Foundation") + add_ldflags("-framework Security") + add_ldflags("-framework SystemConfiguration") + add_syslinks("bsm") + end + + add_options("compute") + add_options("exec") + + add_packages( + "vcpkg::asio", + "vcpkg::cxxopts", + "vcpkg::http-parser", + "vcpkg::json11", + "vcpkg::lua", + "vcpkg::mimalloc", + "vcpkg::rocksdb", + "vcpkg::sentry-native", + "vcpkg::sol2" + ) + + -- Only applicable to later versions of sentry-native + --[[ + if is_plat("linux") then + -- As sentry_native uses symbols from breakpad_client, the latter must + -- be specified after the former with GCC-like toolchains. xmake however + -- is unaware of this and simply globs files from vcpkg's output. The + -- line below forces breakpad_client to be to the right of sentry_native + add_syslinks("breakpad_client") + end + ]]-- diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp new file mode 100644 index 000000000..635fd04e0 --- /dev/null +++ b/src/zenserver/zenserver.cpp @@ -0,0 +1,1261 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/config.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/scopeguard.h> +#include <zencore/session.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <zenstore/cidstore.h> +#include <zenstore/scrubcontext.h> +#include <zenutil/basicfile.h> +#include <zenutil/zenserverprocess.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +#if ZEN_USE_MIMALLOC +ZEN_THIRD_PARTY_INCLUDES_START +# include <mimalloc-new-delete.h> +# include <mimalloc.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <asio.hpp> +#include <lua.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <exception> +#include <list> +#include <optional> +#include <regex> +#include <set> +#include <unordered_map> + +////////////////////////////////////////////////////////////////////////// +// We don't have any doctest code in this file but this is needed to bring +// in some shared code into the executable + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include <zencore/testing.h> +#endif + +////////////////////////////////////////////////////////////////////////// + +#include "config.h" +#include "diag/logging.h" + +#if ZEN_PLATFORM_WINDOWS +# include "windows/service.h" +#endif + +////////////////////////////////////////////////////////////////////////// +// Sentry +// + +#if !defined(ZEN_USE_SENTRY) +# if ZEN_PLATFORM_MAC && ZEN_ARCH_ARM64 +// vcpkg's sentry-native port does not support Arm on Mac. 
+# define ZEN_USE_SENTRY 0 +# else +# define ZEN_USE_SENTRY 1 +# endif +#endif + +#if ZEN_USE_SENTRY +# define SENTRY_BUILD_STATIC 1 +ZEN_THIRD_PARTY_INCLUDES_START +# include <sentry.h> +# include <spdlog/sinks/base_sink.h> +ZEN_THIRD_PARTY_INCLUDES_END + +// Sentry currently does not automatically add all required Windows +// libraries to the linker when consumed via vcpkg + +# if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "sentry.lib") +# pragma comment(lib, "dbghelp.lib") +# pragma comment(lib, "winhttp.lib") +# pragma comment(lib, "version.lib") +# endif +#endif + +////////////////////////////////////////////////////////////////////////// +// Services +// + +#include "admin/admin.h" +#include "auth/authmgr.h" +#include "auth/authservice.h" +#include "cache/structuredcache.h" +#include "cache/structuredcachestore.h" +#include "cidstore.h" +#include "compute/function.h" +#include "diag/diagsvcs.h" +#include "frontend/frontend.h" +#include "monitoring/httpstats.h" +#include "monitoring/httpstatus.h" +#include "objectstore/objectstore.h" +#include "projectstore/projectstore.h" +#include "testing/httptest.h" +#include "upstream/upstream.h" +#include "zenstore/gc.h" + +#define ZEN_APP_NAME "Zen store" + +namespace zen { + +using namespace std::literals; + +namespace utils { +#if ZEN_USE_SENTRY + class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex> + { + public: + sentry_sink() {} + + protected: + static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_INFO, + SENTRY_LEVEL_WARNING, + SENTRY_LEVEL_ERROR, + SENTRY_LEVEL_FATAL, + SENTRY_LEVEL_DEBUG}; + + void sink_it_(const spdlog::details::log_msg& msg) override + { + std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname); + sentry_value_t event = sentry_value_new_message_event( + /* level */ MapToSentryLevel[msg.level], + /* 
logger */ nullptr, + /* message */ Message.c_str()); + sentry_event_value_add_stacktrace(event, NULL, 0); + sentry_capture_event(event); + } + void flush_() override {} + }; +#endif + + asio::error_code ResolveHostname(asio::io_context& Ctx, + std::string_view Host, + std::string_view DefaultPort, + std::vector<std::string>& OutEndpoints) + { + std::string_view Port = DefaultPort; + + if (const size_t Idx = Host.find(":"); Idx != std::string_view::npos) + { + Port = Host.substr(Idx + 1); + Host = Host.substr(0, Idx); + } + + asio::ip::tcp::resolver Resolver(Ctx); + + asio::error_code ErrorCode; + asio::ip::tcp::resolver::results_type Endpoints = Resolver.resolve(Host, Port, ErrorCode); + + if (!ErrorCode) + { + for (const asio::ip::tcp::endpoint Ep : Endpoints) + { + OutEndpoints.push_back(fmt::format("http://{}:{}", Ep.address().to_string(), Ep.port())); + } + } + + return ErrorCode; + } +} // namespace utils + +class ZenServer : public IHttpStatusProvider +{ +public: + int Initialize(const ZenServerOptions& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry) + { + m_UseSentry = ServerOptions.NoSentry == false; + m_ServerEntry = ServerEntry; + m_DebugOptionForcedCrash = ServerOptions.ShouldCrash; + const int ParentPid = ServerOptions.OwnerPid; + + if (ParentPid) + { + zen::ProcessHandle OwnerProcess; + OwnerProcess.Initialize(ParentPid); + + if (!OwnerProcess.IsValid()) + { + ZEN_WARN("Unable to initialize process handle for specified parent pid #{}", ParentPid); + + // If the pid is not reachable should we just shut down immediately? 
the intended owner process + // could have been killed or somehow crashed already + } + else + { + ZEN_INFO("Using parent pid #{} to control process lifetime", ParentPid); + } + + m_ProcessMonitor.AddPid(ParentPid); + } + + // Initialize/check mutex based on base port + + std::string MutexName = fmt::format("zen_{}", ServerOptions.BasePort); + + if (zen::NamedMutex::Exists(MutexName) || ((m_ServerMutex.Create(MutexName) == false))) + { + throw std::runtime_error(fmt::format("Failed to create mutex '{}' - is another instance already running?", MutexName).c_str()); + } + + InitializeState(ServerOptions); + + m_HealthService.SetHealthInfo({.DataRoot = m_DataRoot, + .AbsLogPath = ServerOptions.AbsLogFile, + .HttpServerClass = std::string(ServerOptions.HttpServerClass), + .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)}); + + // Ok so now we're configured, let's kick things off + + m_Http = zen::CreateHttpServer(ServerOptions.HttpServerClass); + int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort); + + if (ServerOptions.WebSocketPort != 0) + { + const uint32 ThreadCount = + ServerOptions.WebSocketThreads > 0 ? 
uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency(); + + m_WebSocket = zen::WebSocketServer::Create( + {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))}); + } + + // Setup authentication manager + { + std::string EncryptionKey = ServerOptions.EncryptionKey; + + if (EncryptionKey.empty()) + { + EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456"; + + ZEN_WARN("using default encryption key"); + } + + std::string EncryptionIV = ServerOptions.EncryptionIV; + + if (EncryptionIV.empty()) + { + EncryptionIV = "0123456789abcdef"; + + ZEN_WARN("using default encryption initialization vector"); + } + + m_AuthMgr = AuthMgr::Create({.RootDirectory = m_DataRoot / "auth", + .EncryptionKey = AesKey256Bit::FromString(EncryptionKey), + .EncryptionIV = AesIV128Bit::FromString(EncryptionIV)}); + + for (const ZenOpenIdProviderConfig& OpenIdProvider : ServerOptions.AuthConfig.OpenIdProviders) + { + m_AuthMgr->AddOpenIdProvider({.Name = OpenIdProvider.Name, .Url = OpenIdProvider.Url, .ClientId = OpenIdProvider.ClientId}); + } + } + + m_AuthService = std::make_unique<zen::HttpAuthService>(*m_AuthMgr); + m_Http->RegisterService(*m_AuthService); + + m_Http->RegisterService(m_HealthService); + m_Http->RegisterService(m_StatsService); + m_Http->RegisterService(m_StatusService); + m_StatusService.RegisterHandler("status", *this); + + // Initialize storage and services + + ZEN_INFO("initializing storage"); + + zen::CidStoreConfiguration Config; + Config.RootDirectory = m_DataRoot / "cas"; + + m_CidStore = std::make_unique<zen::CidStore>(m_GcManager); + m_CidStore->Initialize(Config); + m_CidService.reset(new zen::HttpCidService{*m_CidStore}); + + ZEN_INFO("instantiating project service"); + + m_ProjectStore = new zen::ProjectStore(*m_CidStore, m_DataRoot / "projects", m_GcManager); + m_HttpProjectService.reset(new zen::HttpProjectService{*m_CidStore, m_ProjectStore, m_StatsService, *m_AuthMgr}); + +#if 
ZEN_WITH_COMPUTE_SERVICES + if (ServerOptions.ComputeServiceEnabled) + { + InitializeCompute(ServerOptions); + } + else + { + ZEN_INFO("NOT instantiating compute services"); + } +#endif // ZEN_WITH_COMPUTE_SERVICES + + if (ServerOptions.StructuredCacheEnabled) + { + InitializeStructuredCache(ServerOptions); + } + else + { + ZEN_INFO("NOT instantiating structured cache service"); + } + + m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics + m_Http->RegisterService(m_TestingService); + m_Http->RegisterService(m_AdminService); + + if (m_WebSocket) + { + m_WebSocket->RegisterService(m_TestingService); + } + + if (m_HttpProjectService) + { + m_Http->RegisterService(*m_HttpProjectService); + } + + m_Http->RegisterService(*m_CidService); + +#if ZEN_WITH_COMPUTE_SERVICES + if (ServerOptions.ComputeServiceEnabled) + { + if (m_HttpFunctionService != nullptr) + { + m_Http->RegisterService(*m_HttpFunctionService); + } + } +#endif // ZEN_WITH_COMPUTE_SERVICES + + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot); + + if (m_FrontendService) + { + m_Http->RegisterService(*m_FrontendService); + } + + if (ServerOptions.ObjectStoreEnabled) + { + ObjectStoreConfig ObjCfg; + ObjCfg.RootDirectory = m_DataRoot / "obj"; + ObjCfg.ServerPort = static_cast<uint16_t>(EffectiveBasePort); + + for (const auto& Bucket : ServerOptions.ObjectStoreConfig.Buckets) + { + ObjectStoreConfig::BucketConfig NewBucket{.Name = Bucket.Name}; + NewBucket.Directory = Bucket.Directory.empty() ? 
(ObjCfg.RootDirectory / Bucket.Name) : Bucket.Directory; + ObjCfg.Buckets.push_back(std::move(NewBucket)); + } + + m_ObjStoreService = std::make_unique<HttpObjectStoreService>(std::move(ObjCfg)); + m_Http->RegisterService(*m_ObjStoreService); + } + + ZEN_INFO("initializing GC, enabled '{}', interval {}s", ServerOptions.GcConfig.Enabled, ServerOptions.GcConfig.IntervalSeconds); + zen::GcSchedulerConfig GcConfig{.RootDirectory = m_DataRoot / "gc", + .MonitorInterval = std::chrono::seconds(ServerOptions.GcConfig.MonitorIntervalSeconds), + .Interval = std::chrono::seconds(ServerOptions.GcConfig.IntervalSeconds), + .MaxCacheDuration = std::chrono::seconds(ServerOptions.GcConfig.Cache.MaxDurationSeconds), + .CollectSmallObjects = ServerOptions.GcConfig.CollectSmallObjects, + .Enabled = ServerOptions.GcConfig.Enabled, + .DiskReserveSize = ServerOptions.GcConfig.DiskReserveSize, + .DiskSizeSoftLimit = ServerOptions.GcConfig.Cache.DiskSizeSoftLimit}; + m_GcScheduler.Initialize(GcConfig); + + return EffectiveBasePort; + } + + void InitializeState(const ZenServerOptions& ServerOptions); + void InitializeStructuredCache(const ZenServerOptions& ServerOptions); + void InitializeCompute(const ZenServerOptions& ServerOptions); + + void Run() + { + // This is disabled for now, awaiting better scheduling + // + // Scrub(); + + if (m_ProcessMonitor.IsActive()) + { + EnqueueTimer(); + } + + if (!m_TestMode) + { + ZEN_INFO("__________ _________ __ "); + ZEN_INFO("\\____ /____ ____ / _____// |_ ___________ ____ "); + ZEN_INFO(" / // __ \\ / \\ \\_____ \\\\ __\\/ _ \\_ __ \\_/ __ \\ "); + ZEN_INFO(" / /\\ ___/| | \\ / \\| | ( <_> ) | \\/\\ ___/ "); + ZEN_INFO("/_______ \\___ >___| / /_______ /|__| \\____/|__| \\___ >"); + ZEN_INFO(" \\/ \\/ \\/ \\/ \\/ "); + } + + ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", zen::GetCurrentProcessId()); + +#if ZEN_USE_SENTRY + ZEN_INFO("sentry crash handler {}", m_UseSentry ? 
"ENABLED" : "DISABLED"); + if (m_UseSentry) + { + sentry_clear_modulecache(); + } +#endif + + if (m_DebugOptionForcedCrash) + { + ZEN_DEBUG_BREAK(); + } + + const bool IsInteractiveMode = zen::IsInteractiveSession() && !m_TestMode; + + SetNewState(kRunning); + + OnReady(); + + if (m_WebSocket) + { + m_WebSocket->Run(); + } + + m_Http->Run(IsInteractiveMode); + + SetNewState(kShuttingDown); + + ZEN_INFO(ZEN_APP_NAME " exiting"); + + m_IoContext.stop(); + if (m_IoRunner.joinable()) + { + m_IoRunner.join(); + } + + Flush(); + } + + void RequestExit(int ExitCode) + { + RequestApplicationExit(ExitCode); + m_Http->RequestExit(); + } + + void Cleanup() + { + ZEN_INFO(ZEN_APP_NAME " cleaning up"); + m_GcScheduler.Shutdown(); + } + + void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; } + void SetTestMode(bool State) { m_TestMode = State; } + void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; } + void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; } + + std::function<void()> m_IsReadyFunc; + void SetIsReadyFunc(std::function<void()>&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); } + void OnReady(); + + void EnsureIoRunner() + { + if (!m_IoRunner.joinable()) + { + m_IoRunner = std::thread{[this] { m_IoContext.run(); }}; + } + } + + void EnqueueTimer() + { + m_PidCheckTimer.expires_after(std::chrono::seconds(1)); + m_PidCheckTimer.async_wait([this](const asio::error_code&) { CheckOwnerPid(); }); + + EnsureIoRunner(); + } + + void CheckOwnerPid() + { + // Pick up any new "owner" processes + + std::set<uint32_t> AddedPids; + + for (auto& PidEntry : m_ServerEntry->SponsorPids) + { + if (uint32_t ThisPid = PidEntry.load(std::memory_order_relaxed)) + { + if (PidEntry.compare_exchange_strong(ThisPid, 0)) + { + if (AddedPids.insert(ThisPid).second) + { + m_ProcessMonitor.AddPid(ThisPid); + + ZEN_INFO("added process with pid #{} as a sponsor process", ThisPid); + } + } + } + } + + if (m_ProcessMonitor.IsRunning()) + { + 
EnqueueTimer(); + } + else + { + ZEN_INFO(ZEN_APP_NAME " exiting since sponsor processes are all gone"); + + RequestExit(0); + } + } + + void Scrub() + { + Stopwatch Timer; + ZEN_INFO("Storage validation STARTING"); + + ScrubContext Ctx; + m_CidStore->Scrub(Ctx); + m_ProjectStore->Scrub(Ctx); + m_StructuredCacheService->Scrub(Ctx); + + const uint64_t ElapsedTimeMs = Timer.GetElapsedTimeMs(); + + ZEN_INFO("Storage validation DONE in {}, ({} in {} chunks - {})", + NiceTimeSpanMs(ElapsedTimeMs), + NiceBytes(Ctx.ScrubbedBytes()), + Ctx.ScrubbedChunks(), + NiceByteRate(Ctx.ScrubbedBytes(), ElapsedTimeMs)); + } + + void Flush() + { + if (m_CidStore) + m_CidStore->Flush(); + + if (m_StructuredCacheService) + m_StructuredCacheService->Flush(); + + if (m_ProjectStore) + m_ProjectStore->Flush(); + } + + virtual void HandleStatusRequest(HttpServerRequest& Request) override + { + CbObjectWriter Cbo; + Cbo << "ok" << true; + Cbo << "state" << ToString(m_CurrentState); + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + +private: + ZenServerState::ZenServerEntry* m_ServerEntry = nullptr; + bool m_IsDedicatedMode = false; + bool m_TestMode = false; + CbObject m_RootManifest; + std::filesystem::path m_DataRoot; + std::filesystem::path m_ContentRoot; + std::thread m_IoRunner; + asio::io_context m_IoContext; + asio::steady_timer m_PidCheckTimer{m_IoContext}; + zen::ProcessMonitor m_ProcessMonitor; + zen::NamedMutex m_ServerMutex; + + enum ServerState + { + kInitializing, + kRunning, + kShuttingDown + } m_CurrentState = kInitializing; + + inline void SetNewState(ServerState NewState) { m_CurrentState = NewState; } + + std::string_view ToString(ServerState Value) + { + switch (Value) + { + case kInitializing: + return "initializing"sv; + case kRunning: + return "running"sv; + case kShuttingDown: + return "shutdown"sv; + default: + return "unknown"sv; + } + } + + zen::Ref<zen::HttpServer> m_Http; + std::unique_ptr<zen::WebSocketServer> m_WebSocket; + 
std::unique_ptr<zen::AuthMgr> m_AuthMgr; + std::unique_ptr<zen::HttpAuthService> m_AuthService; + zen::HttpStatusService m_StatusService; + zen::HttpStatsService m_StatsService; + zen::GcManager m_GcManager; + zen::GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<zen::CidStore> m_CidStore; + std::unique_ptr<zen::ZenCacheStore> m_CacheStore; + zen::HttpTestService m_TestService; + zen::HttpTestingService m_TestingService; + std::unique_ptr<zen::HttpCidService> m_CidService; + zen::RefPtr<zen::ProjectStore> m_ProjectStore; + std::unique_ptr<zen::HttpProjectService> m_HttpProjectService; + std::unique_ptr<zen::UpstreamCache> m_UpstreamCache; + std::unique_ptr<zen::HttpUpstreamService> m_UpstreamService; + std::unique_ptr<zen::HttpStructuredCacheService> m_StructuredCacheService; + zen::HttpAdminService m_AdminService{m_GcScheduler}; + zen::HttpHealthService m_HealthService; +#if ZEN_WITH_COMPUTE_SERVICES + std::unique_ptr<zen::HttpFunctionService> m_HttpFunctionService; +#endif // ZEN_WITH_COMPUTE_SERVICES + std::unique_ptr<zen::HttpFrontendService> m_FrontendService; + std::unique_ptr<zen::HttpObjectStoreService> m_ObjStoreService; + + bool m_DebugOptionForcedCrash = false; + bool m_UseSentry = false; +}; + +void +ZenServer::OnReady() +{ + m_ServerEntry->SignalReady(); + + if (m_IsReadyFunc) + { + m_IsReadyFunc(); + } +} + +void +ZenServer::InitializeState(const ZenServerOptions& ServerOptions) +{ + // Check root manifest to deal with schema versioning + + bool WipeState = false; + std::string WipeReason = "Unspecified"; + + bool UpdateManifest = false; + std::filesystem::path ManifestPath = m_DataRoot / "root_manifest"; + FileContents ManifestData = zen::ReadFile(ManifestPath); + + if (ManifestData.ErrorCode) + { + if (ServerOptions.IsFirstRun) + { + ZEN_INFO("Initializing state at '{}'", m_DataRoot); + + UpdateManifest = true; + } + else + { + WipeState = true; + WipeReason = fmt::format("No manifest present at '{}'", ManifestPath); + } + } + else + { + 
IoBuffer Manifest = ManifestData.Flatten(); + + if (CbValidateError ValidationResult = ValidateCompactBinary(Manifest, CbValidateMode::All); + ValidationResult != CbValidateError::None) + { + ZEN_WARN("Manifest validation failed: {}, state will be wiped", uint32_t(ValidationResult)); + + WipeState = true; + WipeReason = fmt::format("Validation of manifest at '{}' failed: {}", ManifestPath, uint32_t(ValidationResult)); + } + else + { + m_RootManifest = LoadCompactBinaryObject(Manifest); + + const int32_t ManifestVersion = m_RootManifest["schema_version"].AsInt32(0); + + if (ManifestVersion != ZEN_CFG_SCHEMA_VERSION) + { + WipeState = true; + WipeReason = fmt::format("Manifest schema version: {}, differs from required: {}", ManifestVersion, ZEN_CFG_SCHEMA_VERSION); + } + } + } + + // Release any open handles so we can overwrite the manifest + ManifestData = {}; + + // Handle any state wipe + + if (WipeState) + { + ZEN_WARN("Wiping state at '{}' - reason: '{}'", m_DataRoot, WipeReason); + + std::error_code Ec; + for (const std::filesystem::directory_entry& DirEntry : std::filesystem::directory_iterator{m_DataRoot, Ec}) + { + if (DirEntry.is_directory() && (DirEntry.path().filename() != "logs")) + { + ZEN_INFO("Deleting '{}'", DirEntry.path()); + + std::filesystem::remove_all(DirEntry.path(), Ec); + + if (Ec) + { + ZEN_WARN("Delete of '{}' returned error: '{}'", DirEntry.path(), Ec.message()); + } + } + } + + ZEN_INFO("Wiped all directories in data root"); + + UpdateManifest = true; + } + + if (UpdateManifest) + { + // Write new manifest + + const DateTime Now = DateTime::Now(); + + CbObjectWriter Cbo; + Cbo << "schema_version" << ZEN_CFG_SCHEMA_VERSION << "created" << Now << "updated" << Now << "state_id" << Oid::NewOid(); + + m_RootManifest = Cbo.Save(); + + WriteFile(ManifestPath, m_RootManifest.GetBuffer().AsIoBuffer()); + } +} + +void +ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions) +{ + using namespace std::literals; + + 
ZEN_INFO("instantiating structured cache service"); + m_CacheStore = std::make_unique<ZenCacheStore>( + m_GcManager, + ZenCacheStore::Configuration{.BasePath = m_DataRoot / "cache", .AllowAutomaticCreationOfNamespaces = true}); + + const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig; + + zen::UpstreamCacheOptions UpstreamOptions; + UpstreamOptions.ReadUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Read)) != 0; + UpstreamOptions.WriteUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Write)) != 0; + + if (UpstreamConfig.UpstreamThreadCount < 32) + { + UpstreamOptions.ThreadCount = static_cast<uint32_t>(UpstreamConfig.UpstreamThreadCount); + } + + m_UpstreamCache = zen::UpstreamCache::Create(UpstreamOptions, *m_CacheStore, *m_CidStore); + m_UpstreamService = std::make_unique<HttpUpstreamService>(*m_UpstreamCache, *m_AuthMgr); + m_UpstreamCache->Initialize(); + + if (ServerOptions.UpstreamCacheConfig.CachePolicy != UpstreamCachePolicy::Disabled) + { + // Zen upstream + { + std::vector<std::string> ZenUrls = UpstreamConfig.ZenConfig.Urls; + if (!UpstreamConfig.ZenConfig.Dns.empty()) + { + for (const std::string& Dns : UpstreamConfig.ZenConfig.Dns) + { + if (!Dns.empty()) + { + const asio::error_code Err = zen::utils::ResolveHostname(m_IoContext, Dns, "1337"sv, ZenUrls); + if (Err) + { + ZEN_ERROR("resolve FAILED, reason '{}'", Err.message()); + } + } + } + } + + std::erase_if(ZenUrls, [](const auto& Url) { return Url.empty(); }); + + if (!ZenUrls.empty()) + { + const auto ZenEndpointName = UpstreamConfig.ZenConfig.Name.empty() ? 
"Zen"sv : UpstreamConfig.ZenConfig.Name; + + std::unique_ptr<zen::UpstreamEndpoint> ZenEndpoint = zen::UpstreamEndpoint::CreateZenEndpoint( + {.Name = ZenEndpointName, + .Urls = ZenUrls, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}); + + m_UpstreamCache->RegisterEndpoint(std::move(ZenEndpoint)); + } + } + + // Jupiter upstream + if (UpstreamConfig.JupiterConfig.Url.empty() == false) + { + std::string_view EndpointName = UpstreamConfig.JupiterConfig.Name.empty() ? "Jupiter"sv : UpstreamConfig.JupiterConfig.Name; + + auto Options = + zen::CloudCacheClientOptions{.Name = EndpointName, + .ServiceUrl = UpstreamConfig.JupiterConfig.Url, + .DdcNamespace = UpstreamConfig.JupiterConfig.DdcNamespace, + .BlobStoreNamespace = UpstreamConfig.JupiterConfig.Namespace, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}; + + auto AuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.JupiterConfig.OAuthUrl, + .OAuthClientId = UpstreamConfig.JupiterConfig.OAuthClientId, + .OAuthClientSecret = UpstreamConfig.JupiterConfig.OAuthClientSecret, + .OpenIdProvider = UpstreamConfig.JupiterConfig.OpenIdProvider, + .AccessToken = UpstreamConfig.JupiterConfig.AccessToken}; + + std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint = + zen::UpstreamEndpoint::CreateJupiterEndpoint(Options, AuthConfig, *m_AuthMgr); + + m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint)); + } + } + + m_StructuredCacheService = + std::make_unique<HttpStructuredCacheService>(*m_CacheStore, *m_CidStore, m_StatsService, m_StatusService, *m_UpstreamCache); + + m_Http->RegisterService(*m_StructuredCacheService); + m_Http->RegisterService(*m_UpstreamService); +} + +#if ZEN_WITH_COMPUTE_SERVICES +void +ZenServer::InitializeCompute(const ZenServerOptions& ServerOptions) +{ 
+ ServerOptions; + const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig; + + // Horde compute upstream + if (UpstreamConfig.HordeConfig.Url.empty() == false && UpstreamConfig.HordeConfig.StorageUrl.empty() == false) + { + ZEN_INFO("instantiating compute service"); + + std::string_view EndpointName = UpstreamConfig.HordeConfig.Name.empty() ? "Horde"sv : UpstreamConfig.HordeConfig.Name; + + auto ComputeOptions = + zen::CloudCacheClientOptions{.Name = EndpointName, + .ServiceUrl = UpstreamConfig.HordeConfig.Url, + .ComputeCluster = UpstreamConfig.HordeConfig.Cluster, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}; + + auto ComputeAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.OAuthUrl, + .OAuthClientId = UpstreamConfig.HordeConfig.OAuthClientId, + .OAuthClientSecret = UpstreamConfig.HordeConfig.OAuthClientSecret, + .OpenIdProvider = UpstreamConfig.HordeConfig.OpenIdProvider, + .AccessToken = UpstreamConfig.HordeConfig.AccessToken}; + + auto StorageOptions = + zen::CloudCacheClientOptions{.Name = EndpointName, + .ServiceUrl = UpstreamConfig.HordeConfig.StorageUrl, + .BlobStoreNamespace = UpstreamConfig.HordeConfig.Namespace, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}; + + auto StorageAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.StorageOAuthUrl, + .OAuthClientId = UpstreamConfig.HordeConfig.StorageOAuthClientId, + .OAuthClientSecret = UpstreamConfig.HordeConfig.StorageOAuthClientSecret, + .OpenIdProvider = UpstreamConfig.HordeConfig.StorageOpenIdProvider, + .AccessToken = UpstreamConfig.HordeConfig.StorageAccessToken}; + + m_HttpFunctionService = std::make_unique<zen::HttpFunctionService>(*m_CidStore, + ComputeOptions, + StorageOptions, + 
ComputeAuthConfig, + StorageAuthConfig, + *m_AuthMgr); + } + else + { + ZEN_INFO("NOT instantiating compute service (missing Horde or Storage config)"); + } +} +#endif // ZEN_WITH_COMPUTE_SERVICES + +//////////////////////////////////////////////////////////////////////////////// + +class ZenEntryPoint +{ +public: + ZenEntryPoint(ZenServerOptions& ServerOptions); + ZenEntryPoint(const ZenEntryPoint&) = delete; + ZenEntryPoint& operator=(const ZenEntryPoint&) = delete; + int Run(); + +private: + ZenServerOptions& m_ServerOptions; + zen::LockFile m_LockFile; +}; + +ZenEntryPoint::ZenEntryPoint(ZenServerOptions& ServerOptions) : m_ServerOptions(ServerOptions) +{ +} + +#if ZEN_USE_SENTRY +static void +SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata) +{ + char LogMessageBuffer[160]; + std::string LogMessage; + const char* MessagePtr = LogMessageBuffer; + + int n = vsnprintf(LogMessageBuffer, sizeof LogMessageBuffer, Message, Args); + + if (n >= int(sizeof LogMessageBuffer)) + { + LogMessage.resize(n + 1); + + n = vsnprintf(LogMessage.data(), LogMessage.size(), Message, Args); + + MessagePtr = LogMessage.c_str(); + } + + switch (Level) + { + case SENTRY_LEVEL_DEBUG: + ConsoleLog().debug("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_INFO: + ConsoleLog().info("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_WARNING: + ConsoleLog().warn("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_ERROR: + ConsoleLog().error("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_FATAL: + ConsoleLog().critical("sentry: {}", MessagePtr); + break; + } +} +#endif + +int +ZenEntryPoint::Run() +{ +#if ZEN_USE_SENTRY + std::string SentryDatabasePath = PathToUtf8(m_ServerOptions.DataDir / ".sentry-native"); + int SentryErrorCode = 0; + if (m_ServerOptions.NoSentry == false) + { + sentry_options_t* SentryOptions = sentry_options_new(); + sentry_options_set_dsn(SentryOptions, "https://[email 
protected]/5919284"); + if (SentryDatabasePath.starts_with("\\\\?\\")) + { + SentryDatabasePath = SentryDatabasePath.substr(4); + } + sentry_options_set_database_path(SentryOptions, SentryDatabasePath.c_str()); + sentry_options_set_logger(SentryOptions, SentryLogFunction, this); + std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string(); + if (SentryAttachmentPath.starts_with("\\\\?\\")) + { + SentryAttachmentPath = SentryAttachmentPath.substr(4); + } + sentry_options_add_attachment(SentryOptions, SentryAttachmentPath.c_str()); + sentry_options_set_release(SentryOptions, ZEN_CFG_VERSION); + // sentry_options_set_debug(SentryOptions, 1); + + SentryErrorCode = sentry_init(SentryOptions); + + auto SentrySink = spdlog::create<utils::sentry_sink>("sentry"); + zen::logging::SetErrorLog(std::move(SentrySink)); + } + + auto _ = zen::MakeGuard([] { + zen::logging::SetErrorLog(std::shared_ptr<spdlog::logger>()); + sentry_close(); + }); +#endif + + auto& ServerOptions = m_ServerOptions; + + try + { + // Mutual exclusion and synchronization + ZenServerState ServerState; + ServerState.Initialize(); + ServerState.Sweep(); + + ZenServerState::ZenServerEntry* Entry = ServerState.Lookup(ServerOptions.BasePort); + + if (Entry) + { + if (ServerOptions.OwnerPid) + { + ConsoleLog().info( + "Looks like there is already a process listening to this port {} (pid: {}), attaching owner pid {} to running instance", + ServerOptions.BasePort, + Entry->Pid, + ServerOptions.OwnerPid); + + Entry->AddSponsorProcess(ServerOptions.OwnerPid); + + std::exit(0); + } + else + { + ConsoleLog().warn("Exiting since there is already a process listening to port {} (pid: {})", + ServerOptions.BasePort, + Entry->Pid); + std::exit(1); + } + } + + std::error_code Ec; + + std::filesystem::path LockFilePath = ServerOptions.DataDir / ".lock"; + + bool IsReady = false; + + auto MakeLockData = [&] { + CbObjectWriter Cbo; + Cbo << "pid" << zen::GetCurrentProcessId() << "data" << 
PathToUtf8(ServerOptions.DataDir) << "port" << ServerOptions.BasePort + << "session_id" << GetSessionId() << "ready" << IsReady; + return Cbo.Save(); + }; + + m_LockFile.Create(LockFilePath, MakeLockData(), Ec); + + if (Ec) + { + ConsoleLog().warn("ERROR: Unable to grab lock at '{}' (error: '{}')", LockFilePath, Ec.message()); + + std::exit(99); + } + + InitializeLogging(ServerOptions); + +#if ZEN_USE_SENTRY + if (m_ServerOptions.NoSentry == false) + { + if (SentryErrorCode == 0) + { + ZEN_INFO("sentry initialized"); + } + else + { + ZEN_WARN("sentry_init returned failure! (error code: {})", SentryErrorCode); + } + } +#endif + + MaximizeOpenFileCount(); + + ZEN_INFO(ZEN_APP_NAME " - using lock file at '{}'", LockFilePath); + + ZEN_INFO(ZEN_APP_NAME " - starting on port {}, version '{}'", ServerOptions.BasePort, ZEN_CFG_VERSION_BUILD_STRING_FULL); + + Entry = ServerState.Register(ServerOptions.BasePort); + + if (ServerOptions.OwnerPid) + { + Entry->AddSponsorProcess(ServerOptions.OwnerPid); + } + + ZenServer Server; + Server.SetDataRoot(ServerOptions.DataDir); + Server.SetContentRoot(ServerOptions.ContentDir); + Server.SetTestMode(ServerOptions.IsTest); + Server.SetDedicatedMode(ServerOptions.IsDedicated); + + int EffectiveBasePort = Server.Initialize(ServerOptions, Entry); + + Entry->EffectiveListenPort = uint16_t(EffectiveBasePort); + if (EffectiveBasePort != ServerOptions.BasePort) + { + ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort); + ServerOptions.BasePort = EffectiveBasePort; + } + + std::unique_ptr<std::thread> ShutdownThread; + std::unique_ptr<zen::NamedEvent> ShutdownEvent; + + zen::ExtendableStringBuilder<64> ShutdownEventName; + ShutdownEventName << "Zen_" << ServerOptions.BasePort << "_Shutdown"; + ShutdownEvent.reset(new zen::NamedEvent{ShutdownEventName}); + + // Monitor shutdown signals + + ShutdownThread.reset(new std::thread{[&] { + ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}'", ShutdownEventName); + 
if (ShutdownEvent->Wait()) + { + ZEN_INFO("shutdown signal received"); + Server.RequestExit(0); + } + else + { + ZEN_INFO("shutdown signal wait() failed"); + } + }}); + + // If we have a parent process, establish the mechanisms we need + // to be able to communicate readiness with the parent + + Server.SetIsReadyFunc([&] { + IsReady = true; + + m_LockFile.Update(MakeLockData(), Ec); + + if (!ServerOptions.ChildId.empty()) + { + zen::NamedEvent ParentEvent{ServerOptions.ChildId}; + ParentEvent.Set(); + } + }); + + Server.Run(); + Server.Cleanup(); + + ShutdownEvent->Set(); + ShutdownThread->join(); + } + catch (std::exception& e) + { + SPDLOG_CRITICAL("Caught exception in main: {}", e.what()); + } + + ShutdownLogging(); + + return 0; +} + +} // namespace zen + +//////////////////////////////////////////////////////////////////////////////// + +#if ZEN_PLATFORM_WINDOWS + +class ZenWindowsService : public WindowsService +{ +public: + ZenWindowsService(ZenServerOptions& ServerOptions) : m_EntryPoint(ServerOptions) {} + + ZenWindowsService(const ZenWindowsService&) = delete; + ZenWindowsService& operator=(const ZenWindowsService&) = delete; + + virtual int Run() override; + +private: + zen::ZenEntryPoint m_EntryPoint; +}; + +int +ZenWindowsService::Run() +{ + return m_EntryPoint.Run(); +} + +#endif // ZEN_PLATFORM_WINDOWS + +//////////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS +int +test_main(int argc, char** argv) +{ + zen::zencore_forcelinktests(); + zen::zenhttp_forcelinktests(); + zen::zenstore_forcelinktests(); + zen::z$_forcelink(); + zen::z$service_forcelink(); + + zen::logging::InitializeLogging(); + spdlog::set_level(spdlog::level::debug); + + zen::MaximizeOpenFileCount(); + + return ZEN_RUN_TESTS(argc, argv); +} +#endif + +int +main(int argc, char* argv[]) +{ + using namespace zen; + +#if ZEN_USE_MIMALLOC + mi_version(); +#endif + +#if ZEN_WITH_TESTS + if (argc >= 2) + { + if (argv[1] == "test"sv) + { + return 
test_main(argc, argv); + } + } +#endif + + try + { + ZenServerOptions ServerOptions; + ParseCliOptions(argc, argv, ServerOptions); + + if (!std::filesystem::exists(ServerOptions.DataDir)) + { + ServerOptions.IsFirstRun = true; + std::filesystem::create_directories(ServerOptions.DataDir); + } + +#if ZEN_WITH_TRACE + if (ServerOptions.TraceHost.size()) + { + TraceInit(ServerOptions.TraceHost.c_str(), TraceType::Network); + } + else if (ServerOptions.TraceFile.size()) + { + TraceInit(ServerOptions.TraceFile.c_str(), TraceType::File); + } + else + { + TraceInit(nullptr, TraceType::None); + } +#endif // ZEN_WITH_TRACE + +#if ZEN_PLATFORM_WINDOWS + if (ServerOptions.InstallService) + { + WindowsService::Install(); + + std::exit(0); + } + + if (ServerOptions.UninstallService) + { + WindowsService::Delete(); + + std::exit(0); + } + + ZenWindowsService App(ServerOptions); + return App.ServiceMain(); +#else + if (ServerOptions.InstallService || ServerOptions.UninstallService) + { + throw std::runtime_error("Service mode is not supported on this platform"); + } + + ZenEntryPoint App(ServerOptions); + return App.Run(); +#endif // ZEN_PLATFORM_WINDOWS + } + catch (std::exception& Ex) + { + fprintf(stderr, "ERROR: Caught exception in main: '%s'", Ex.what()); + + return 1; + } +} diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc new file mode 100644 index 000000000..6d31e2c6e --- /dev/null +++ b/src/zenserver/zenserver.rc @@ -0,0 +1,105 @@ +// Microsoft Visual C++ generated resource script. +// +#include "resource.h" + +#include "zencore/config.h" + +#define APSTUDIO_READONLY_SYMBOLS +///////////////////////////////////////////////////////////////////////////// +// +// Generated from the TEXTINCLUDE 2 resource. 
+// +#include "winres.h" + +///////////////////////////////////////////////////////////////////////////// +#undef APSTUDIO_READONLY_SYMBOLS + +///////////////////////////////////////////////////////////////////////////// +// English (United States) resources + +#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENU) +LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US +#pragma code_page(1252) + +///////////////////////////////////////////////////////////////////////////// +// +// Icon +// + +// Icon with lowest ID value placed first to ensure application icon +// remains consistent on all systems. +IDI_ICON1 ICON "..\\UnrealEngine.ico" + +#endif // English (United States) resources +///////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////// +// English (United Kingdom) resources + +#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENG) +LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_UK +#pragma code_page(1252) + +#ifdef APSTUDIO_INVOKED +///////////////////////////////////////////////////////////////////////////// +// +// TEXTINCLUDE +// + +1 TEXTINCLUDE +BEGIN + "resource.h\0" +END + +2 TEXTINCLUDE +BEGIN + "#include ""winres.h""\r\n" + "\0" +END + +3 TEXTINCLUDE +BEGIN + "\r\n" + "\0" +END + +#endif // APSTUDIO_INVOKED + +#endif // English (United Kingdom) resources +///////////////////////////////////////////////////////////////////////////// + + + +#ifndef APSTUDIO_INVOKED +///////////////////////////////////////////////////////////////////////////// +// +// Generated from the TEXTINCLUDE 3 resource. 
+// + + +///////////////////////////////////////////////////////////////////////////// +#endif // not APSTUDIO_INVOKED + +VS_VERSION_INFO VERSIONINFO +FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 +PRODUCTVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 +{ + BLOCK "StringFileInfo" + { + BLOCK "040904b0" + { + VALUE "CompanyName", "Epic Games Inc\0" + VALUE "FileDescription", "Local Storage Service for Unreal Engine\0" + VALUE "FileVersion", ZEN_CFG_VERSION "\0" + VALUE "LegalCopyright", "Copyright Epic Games Inc. All Rights Reserved\0" + VALUE "OriginalFilename", "zenserver.exe\0" + VALUE "ProductName", "Zen Storage Server\0" + VALUE "ProductVersion", ZEN_CFG_VERSION_BUILD_STRING_FULL "\0" + } + } + BLOCK "VarFileInfo" + { + VALUE "Translation", 0x409, 1200 + } +} |