aboutsummaryrefslogtreecommitdiff
path: root/src/zenserver
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src/zenserver
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src/zenserver')
-rw-r--r--src/zenserver/admin/admin.cpp101
-rw-r--r--src/zenserver/admin/admin.h26
-rw-r--r--src/zenserver/auth/authmgr.cpp506
-rw-r--r--src/zenserver/auth/authmgr.h56
-rw-r--r--src/zenserver/auth/authservice.cpp91
-rw-r--r--src/zenserver/auth/authservice.h25
-rw-r--r--src/zenserver/auth/oidc.cpp127
-rw-r--r--src/zenserver/auth/oidc.h76
-rw-r--r--src/zenserver/cache/cachetracking.cpp376
-rw-r--r--src/zenserver/cache/cachetracking.h41
-rw-r--r--src/zenserver/cache/structuredcache.cpp3159
-rw-r--r--src/zenserver/cache/structuredcache.h187
-rw-r--r--src/zenserver/cache/structuredcachestore.cpp3648
-rw-r--r--src/zenserver/cache/structuredcachestore.h535
-rw-r--r--src/zenserver/cidstore.cpp124
-rw-r--r--src/zenserver/cidstore.h35
-rw-r--r--src/zenserver/compute/function.cpp629
-rw-r--r--src/zenserver/compute/function.h73
-rw-r--r--src/zenserver/config.cpp902
-rw-r--r--src/zenserver/config.h158
-rw-r--r--src/zenserver/diag/diagsvcs.cpp127
-rw-r--r--src/zenserver/diag/diagsvcs.h111
-rw-r--r--src/zenserver/diag/formatters.h71
-rw-r--r--src/zenserver/diag/logging.cpp467
-rw-r--r--src/zenserver/diag/logging.h10
-rw-r--r--src/zenserver/frontend/frontend.cpp128
-rw-r--r--src/zenserver/frontend/frontend.h25
-rw-r--r--src/zenserver/frontend/html/index.html59
-rw-r--r--src/zenserver/frontend/zipfs.cpp169
-rw-r--r--src/zenserver/frontend/zipfs.h26
-rw-r--r--src/zenserver/monitoring/httpstats.cpp62
-rw-r--r--src/zenserver/monitoring/httpstats.h38
-rw-r--r--src/zenserver/monitoring/httpstatus.cpp62
-rw-r--r--src/zenserver/monitoring/httpstatus.h38
-rw-r--r--src/zenserver/objectstore/objectstore.cpp232
-rw-r--r--src/zenserver/objectstore/objectstore.h48
-rw-r--r--src/zenserver/projectstore/fileremoteprojectstore.cpp235
-rw-r--r--src/zenserver/projectstore/fileremoteprojectstore.h19
-rw-r--r--src/zenserver/projectstore/jupiterremoteprojectstore.cpp244
-rw-r--r--src/zenserver/projectstore/jupiterremoteprojectstore.h26
-rw-r--r--src/zenserver/projectstore/projectstore.cpp4082
-rw-r--r--src/zenserver/projectstore/projectstore.h372
-rw-r--r--src/zenserver/projectstore/remoteprojectstore.cpp1036
-rw-r--r--src/zenserver/projectstore/remoteprojectstore.h111
-rw-r--r--src/zenserver/projectstore/zenremoteprojectstore.cpp341
-rw-r--r--src/zenserver/projectstore/zenremoteprojectstore.h18
-rw-r--r--src/zenserver/resource.h18
-rw-r--r--src/zenserver/targetver.h10
-rw-r--r--src/zenserver/testing/httptest.cpp207
-rw-r--r--src/zenserver/testing/httptest.h55
-rw-r--r--src/zenserver/upstream/hordecompute.cpp1457
-rw-r--r--src/zenserver/upstream/jupiter.cpp965
-rw-r--r--src/zenserver/upstream/jupiter.h217
-rw-r--r--src/zenserver/upstream/upstream.h8
-rw-r--r--src/zenserver/upstream/upstreamapply.cpp459
-rw-r--r--src/zenserver/upstream/upstreamapply.h192
-rw-r--r--src/zenserver/upstream/upstreamcache.cpp2112
-rw-r--r--src/zenserver/upstream/upstreamcache.h252
-rw-r--r--src/zenserver/upstream/upstreamservice.cpp56
-rw-r--r--src/zenserver/upstream/upstreamservice.h27
-rw-r--r--src/zenserver/upstream/zen.cpp326
-rw-r--r--src/zenserver/upstream/zen.h125
-rw-r--r--src/zenserver/windows/service.cpp646
-rw-r--r--src/zenserver/windows/service.h20
-rw-r--r--src/zenserver/xmake.lua60
-rw-r--r--src/zenserver/zenserver.cpp1261
-rw-r--r--src/zenserver/zenserver.rc105
67 files changed, 27610 insertions, 0 deletions
diff --git a/src/zenserver/admin/admin.cpp b/src/zenserver/admin/admin.cpp
new file mode 100644
index 000000000..7aa1b48d1
--- /dev/null
+++ b/src/zenserver/admin/admin.cpp
@@ -0,0 +1,101 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "admin.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+#include <zenstore/gc.h>
+
+#include <chrono>
+
+namespace zen {
+
+HttpAdminService::HttpAdminService(GcScheduler& Scheduler) : m_GcScheduler(Scheduler)
+{
+ using namespace std::literals;
+
+ m_Router.RegisterRoute(
+ "health",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "gc",
+ [this](HttpRouterRequest& Req) {
+ const GcSchedulerStatus Status = m_GcScheduler.Status();
+
+ CbObjectWriter Response;
+ Response << "Status"sv << (GcSchedulerStatus::kIdle == Status ? "Idle"sv : "Running"sv);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "gc",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ GcScheduler::TriggerParams GcParams;
+
+ if (auto Param = Params.GetValue("smallobjects"); Param.empty() == false)
+ {
+ GcParams.CollectSmallObjects = Param == "true"sv;
+ }
+
+ if (auto Param = Params.GetValue("maxcacheduration"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<uint64_t>(Param))
+ {
+ GcParams.MaxCacheDuration = std::chrono::seconds(Value.value());
+ }
+ }
+
+ if (auto Param = Params.GetValue("disksizesoftlimit"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<uint64_t>(Param))
+ {
+ GcParams.DiskSizeSoftLimit = Value.value();
+ }
+ }
+
+ const bool Started = m_GcScheduler.Trigger(GcParams);
+
+ CbObjectWriter Response;
+ Response << "Status"sv << (Started ? "Started"sv : "Running"sv);
+ HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "",
+ [](HttpRouterRequest& Req) {
+ CbObject Payload = Req.ServerRequest().ReadPayloadObject();
+
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kPost);
+}
+
+HttpAdminService::~HttpAdminService()
+{
+}
+
+const char*
+HttpAdminService::BaseUri() const
+{
+ return "/admin/";
+}
+
+void
+HttpAdminService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+} // namespace zen
diff --git a/src/zenserver/admin/admin.h b/src/zenserver/admin/admin.h
new file mode 100644
index 000000000..9463ffbb3
--- /dev/null
+++ b/src/zenserver/admin/admin.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class GcScheduler;
+
+class HttpAdminService : public zen::HttpService
+{
+public:
+ HttpAdminService(GcScheduler& Scheduler);
+ ~HttpAdminService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ HttpRequestRouter m_Router;
+ GcScheduler& m_GcScheduler;
+};
+
+} // namespace zen
diff --git a/src/zenserver/auth/authmgr.cpp b/src/zenserver/auth/authmgr.cpp
new file mode 100644
index 000000000..4cd6b3362
--- /dev/null
+++ b/src/zenserver/auth/authmgr.cpp
@@ -0,0 +1,506 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/authmgr.h>
+#include <auth/oidc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/crypto.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+
+#include <condition_variable>
+#include <memory>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+#include <fmt/format.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace details {
+ IoBuffer ReadEncryptedFile(std::filesystem::path Path,
+ const AesKey256Bit& Key,
+ const AesIV128Bit& IV,
+ std::optional<std::string>& Reason)
+ {
+ FileContents Result = ReadFile(Path);
+
+ if (Result.ErrorCode)
+ {
+ return IoBuffer();
+ }
+
+ IoBuffer EncryptedBuffer = Result.Flatten();
+
+ if (EncryptedBuffer.GetSize() == 0)
+ {
+ return IoBuffer();
+ }
+
+ std::vector<uint8_t> DecryptionBuffer;
+ DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize);
+
+ MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason);
+
+ if (DecryptedView.IsEmpty())
+ {
+ return IoBuffer();
+ }
+
+ return IoBufferBuilder::MakeCloneFromMemory(DecryptedView);
+ }
+
+ void WriteEncryptedFile(std::filesystem::path Path,
+ IoBuffer FileData,
+ const AesKey256Bit& Key,
+ const AesIV128Bit& IV,
+ std::optional<std::string>& Reason)
+ {
+ if (FileData.GetSize() == 0)
+ {
+ return;
+ }
+
+ std::vector<uint8_t> EncryptionBuffer;
+ EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize);
+
+ MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason);
+
+ if (EncryptedView.IsEmpty())
+ {
+ return;
+ }
+
+ WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize()));
+ }
+} // namespace details
+
+class AuthMgrImpl final : public AuthMgr
+{
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+ using Seconds = std::chrono::seconds;
+
+public:
+ AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth"))
+ {
+ LoadState();
+
+ m_BackgroundThread.Interval = Config.UpdateInterval;
+ m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this);
+ }
+
+ virtual ~AuthMgrImpl() { Shutdown(); }
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
+ {
+ if (OpenIdProviderExist(Params.Name))
+ {
+ ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
+ return;
+ }
+
+ if (Params.Name.empty())
+ {
+ ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
+ return;
+ }
+
+ std::unique_ptr<OidcClient> Client =
+ std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+
+ if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
+ {
+ ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
+ return;
+ }
+
+ std::string NewProviderName = std::string(Params.Name);
+
+ OpenIdProvider* NewProvider = nullptr;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.contains(NewProviderName))
+ {
+ return;
+ }
+
+ auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>());
+ NewProvider = InsertResult.first->second.get();
+ }
+
+ NewProvider->Name = std::string(Params.Name);
+ NewProvider->Url = std::string(Params.Url);
+ NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->HttpClient = std::move(Client);
+
+ ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
+ }
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
+ {
+ if (Params.ProviderName.empty())
+ {
+ ZEN_WARN("trying add OpenID token with invalid provider name");
+ return false;
+ }
+
+ if (Params.RefreshToken.empty())
+ {
+ ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'");
+ return false;
+ }
+
+ auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
+
+ if (RefreshResult.Ok == false)
+ {
+ ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
+ return false;
+ }
+
+ bool IsNew = false;
+
+ {
+ auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+
+ std::unique_lock _(m_TokenMutex);
+
+ const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token));
+
+ IsNew = InsertResult.second;
+ }
+
+ if (IsNew)
+ {
+ ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName);
+ }
+ else
+ {
+ ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName);
+ }
+
+ return true;
+ }
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end())
+ {
+ const OpenIdToken& Token = It->second;
+
+ return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ }
+
+ return {};
+ }
+
+private:
+ bool OpenIdProviderExist(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ return m_OpenIdProviders.contains(std::string(ProviderName));
+ }
+
+ OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+ return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ }
+
+ OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
+ {
+ if (OpenIdProviderExist(ProviderName) == false)
+ {
+ return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
+ }
+
+ OidcClient& Client = GetOpenIdClient(ProviderName);
+
+ return Client.RefreshToken(RefreshToken);
+ }
+
+ void Shutdown()
+ {
+ BackgroundThread::Stop(m_BackgroundThread);
+ SaveState();
+ }
+
+ void LoadState()
+ {
+ try
+ {
+ std::optional<std::string> Reason;
+
+ IoBuffer Buffer =
+ details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason);
+
+ if (!Buffer)
+ {
+ if (Reason)
+ {
+ ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value());
+ }
+
+ return;
+ }
+
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
+ }
+
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ {
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
+
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
+
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
+
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ {
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
+ }
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what());
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+ m_OpenIdProviders.clear();
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+ m_OpenIdTokens.clear();
+ }
+ }
+ }
+
+ void SaveState()
+ {
+ try
+ {
+ CbObjectWriter AuthState;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
+ {
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
+ }
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
+ {
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState.EndObject();
+ }
+ }
+ AuthState.EndArray();
+ }
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+
+ std::optional<std::string> Reason;
+
+ details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv,
+ AuthState.Save().GetBuffer().AsIoBuffer(),
+ m_Config.EncryptionKey,
+ m_Config.EncryptionIV,
+ Reason);
+
+ if (Reason)
+ {
+ ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value());
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what());
+ }
+ }
+
+ void BackgroundThreadEntry()
+ {
+ for (;;)
+ {
+ std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread);
+
+ if (m_BackgroundThread.Running.load() == false)
+ {
+ break;
+ }
+
+ if (SignalStatus != std::cv_status::timeout)
+ {
+ continue;
+ }
+
+ {
+ // Refresh Open ID token(s)
+
+ std::vector<OpenIdTokenMap::value_type> ExpiredTokens;
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now());
+ const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2);
+
+ if (Expired)
+ {
+ ExpiredTokens.push_back(Kv);
+ }
+ }
+ }
+
+ ZEN_DEBUG("refreshing '{}' OpenID token(s)", ExpiredTokens.size());
+
+ for (const auto& Kv : ExpiredTokens)
+ {
+ OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken);
+
+ if (RefreshResult.Ok)
+ {
+ ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first);
+
+ auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+
+ {
+ std::unique_lock _(m_TokenMutex);
+ m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token));
+ }
+ }
+ else
+ {
+ ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason);
+ }
+ }
+ }
+ }
+ }
+
+ struct BackgroundThread
+ {
+ std::chrono::seconds Interval{10};
+ std::mutex Mutex;
+ std::condition_variable Signal;
+ std::atomic_bool Running{true};
+ std::thread Thread;
+
+ static void Stop(BackgroundThread& State)
+ {
+ if (State.Running.load())
+ {
+ State.Running.store(false);
+ State.Signal.notify_one();
+ }
+
+ if (State.Thread.joinable())
+ {
+ State.Thread.join();
+ }
+ }
+
+ static std::cv_status WaitForSignal(BackgroundThread& State)
+ {
+ std::unique_lock Lock(State.Mutex);
+ return State.Signal.wait_for(Lock, State.Interval);
+ }
+ };
+
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ std::unique_ptr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ TimePoint ExpireTime{};
+ };
+
+ using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
+ using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthConfig m_Config;
+ spdlog::logger& m_Log;
+ BackgroundThread m_BackgroundThread;
+ OpenIdProviderMap m_OpenIdProviders;
+ OpenIdTokenMap m_OpenIdTokens;
+ std::mutex m_ProviderMutex;
+ std::shared_mutex m_TokenMutex;
+};
+
+std::unique_ptr<AuthMgr>
+AuthMgr::Create(const AuthConfig& Config)
+{
+ return std::make_unique<AuthMgrImpl>(Config);
+}
+
+} // namespace zen
diff --git a/src/zenserver/auth/authmgr.h b/src/zenserver/auth/authmgr.h
new file mode 100644
index 000000000..054588ab9
--- /dev/null
+++ b/src/zenserver/auth/authmgr.h
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/crypto.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+
+#include <chrono>
+#include <filesystem>
+#include <memory>
+
+namespace zen {
+
+struct AuthConfig
+{
+ std::filesystem::path RootDirectory;
+ std::chrono::seconds UpdateInterval{30};
+ AesKey256Bit EncryptionKey;
+ AesIV128Bit EncryptionIV;
+};
+
+class AuthMgr
+{
+public:
+ virtual ~AuthMgr() = default;
+
+ struct AddOpenIdProviderParams
+ {
+ std::string_view Name;
+ std::string_view Url;
+ std::string_view ClientId;
+ };
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) = 0;
+
+ struct AddOpenIdTokenParams
+ {
+ std::string_view ProviderName;
+ std::string_view RefreshToken;
+ };
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0;
+
+ struct OpenIdAccessToken
+ {
+ std::string AccessToken;
+ std::chrono::system_clock::time_point ExpireTime{};
+ };
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0;
+
+ static std::unique_ptr<AuthMgr> Create(const AuthConfig& Config);
+};
+
+} // namespace zen
diff --git a/src/zenserver/auth/authservice.cpp b/src/zenserver/auth/authservice.cpp
new file mode 100644
index 000000000..1cc679540
--- /dev/null
+++ b/src/zenserver/auth/authservice.cpp
@@ -0,0 +1,91 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/authservice.h>
+
+#include <auth/authmgr.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
+{
+ m_Router.RegisterRoute(
+ "oidc/refreshtoken",
+ [this](HttpRouterRequest& RouterRequest) {
+ HttpServerRequest& ServerRequest = RouterRequest.ServerRequest();
+
+ const HttpContentType ContentType = ServerRequest.RequestContentType();
+
+ if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ const IoBuffer Body = ServerRequest.ReadPayload();
+
+ std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize());
+ std::string JsonError;
+ json11::Json TokenInfo = json11::Json::parse(JsonText, JsonError);
+
+ if (!JsonError.empty())
+ {
+ CbObjectWriter Response;
+ Response << "Result"sv << false;
+ Response << "Error"sv << JsonError;
+
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save());
+ }
+
+ const std::string RefreshToken = TokenInfo["RefreshToken"].string_value();
+ std::string ProviderName = TokenInfo["ProviderName"].string_value();
+
+ if (ProviderName.empty())
+ {
+ ProviderName = "Default"sv;
+ }
+
+ const bool Ok =
+ m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = ProviderName, .RefreshToken = RefreshToken});
+
+ if (Ok)
+ {
+ ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest);
+ }
+ else
+ {
+ CbObjectWriter Response;
+ Response << "Result"sv << false;
+ Response << "Error"sv
+ << "Invalid token"sv;
+
+ ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save());
+ }
+ },
+ HttpVerb::kPost);
+}
+
+HttpAuthService::~HttpAuthService()
+{
+}
+
+const char*
+HttpAuthService::BaseUri() const
+{
+ return "/auth/";
+}
+
+void
+HttpAuthService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+} // namespace zen
diff --git a/src/zenserver/auth/authservice.h b/src/zenserver/auth/authservice.h
new file mode 100644
index 000000000..64b86e21f
--- /dev/null
+++ b/src/zenserver/auth/authservice.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class AuthMgr;
+
+class HttpAuthService final : public zen::HttpService
+{
+public:
+ HttpAuthService(AuthMgr& AuthMgr);
+ virtual ~HttpAuthService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ AuthMgr& m_AuthMgr;
+ HttpRequestRouter m_Router;
+};
+
+} // namespace zen
diff --git a/src/zenserver/auth/oidc.cpp b/src/zenserver/auth/oidc.cpp
new file mode 100644
index 000000000..d2265c22f
--- /dev/null
+++ b/src/zenserver/auth/oidc.cpp
@@ -0,0 +1,127 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/oidc.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace details {
+
+ using StringArray = std::vector<std::string>;
+
+ StringArray ToStringArray(const json11::Json JsonArray)
+ {
+ StringArray Result;
+
+ const auto& Items = JsonArray.array_items();
+
+ for (const auto& Item : Items)
+ {
+ Result.push_back(Item.string_value());
+ }
+
+ return Result;
+ }
+
+} // namespace details
+
+using namespace std::literals;
+
+OidcClient::OidcClient(const OidcClient::Options& Options)
+{
+ m_BaseUrl = std::string(Options.BaseUrl);
+ m_ClientId = std::string(Options.ClientId);
+}
+
+OidcClient::InitResult
+OidcClient::Initialize()
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_BaseUrl << "/.well-known/openid-configuration"sv;
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+
+ cpr::Response Response = Session.Get();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = std::move(Response.reason)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ m_Config = {.Issuer = Json["issuer"].string_value(),
+ .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(),
+ .TokenEndpoint = Json["token_endpoint"].string_value(),
+ .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
+ .RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .JwksUri = Json["jwks_uri"].string_value(),
+ .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
+ .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
+ .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]),
+ .SupportedScopes = details::ToStringArray(Json["scopes_supported"]),
+ .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]),
+ .SupportedClaims = details::ToStringArray(Json["claims_supported"])};
+
+ return {.Ok = true};
+}
+
+OidcClient::RefreshTokenResult
+OidcClient::RefreshToken(std::string_view RefreshToken)
+{
+ const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()});
+ Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}});
+ Session.SetBody(cpr::Body{Body.data(), Body.size()});
+
+ cpr::Response Response = Session.Post();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = fmt::format("{} ({})", Response.reason, Response.text)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ return {.TokenType = Json["token_type"].string_value(),
+ .AccessToken = Json["access_token"].string_value(),
+ .RefreshToken = Json["refresh_token"].string_value(),
+ .IdentityToken = Json["id_token"].string_value(),
+ .Scope = Json["scope"].string_value(),
+ .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()),
+ .Ok = true};
+}
+
+} // namespace zen
diff --git a/src/zenserver/auth/oidc.h b/src/zenserver/auth/oidc.h
new file mode 100644
index 000000000..f43ae3cd7
--- /dev/null
+++ b/src/zenserver/auth/oidc.h
@@ -0,0 +1,76 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/string.h>
+
+#include <vector>
+
+namespace zen {
+
// Minimal OpenID Connect client: fetches a provider's discovery document and
// performs refresh-token grants against its token endpoint.
class OidcClient
{
public:
    struct Options
    {
        std::string_view BaseUrl;   // provider base URL (no trailing slash)
        std::string_view ClientId;  // OAuth2 client identifier
    };

    // explicit: a bag of options should not implicitly convert to a client.
    explicit OidcClient(const Options& Options);
    ~OidcClient() = default;

    // Non-copyable: holds per-provider configuration state.
    OidcClient(const OidcClient&) = delete;
    OidcClient& operator=(const OidcClient&) = delete;

    // Generic outcome; Reason is meaningful when Ok == false.
    struct Result
    {
        std::string Reason;
        bool Ok = false;
    };

    using InitResult = Result;

    // Fetches /.well-known/openid-configuration and caches the endpoints.
    // Must succeed before RefreshToken() is used.
    InitResult Initialize();

    struct RefreshTokenResult
    {
        std::string TokenType;
        std::string AccessToken;
        std::string RefreshToken;
        std::string IdentityToken;
        std::string Scope;
        std::string Reason;          // failure description when Ok == false
        int64_t ExpiresInSeconds{};  // access-token lifetime reported by the provider
        bool Ok = false;
    };

    // Exchanges a refresh token for a fresh token set (OAuth2 refresh_token
    // grant).
    RefreshTokenResult RefreshToken(std::string_view RefreshToken);

private:
    using StringArray = std::vector<std::string>;

    // Subset of the OpenID discovery document cached by Initialize().
    struct OpenIdConfiguration
    {
        std::string Issuer;
        std::string AuthorizationEndpoint;
        std::string TokenEndpoint;
        std::string UserInfoEndpoint;
        std::string RegistrationEndpoint;
        std::string EndSessionEndpoint;         // NOTE: not populated by Initialize()
        std::string DeviceAuthorizationEndpoint;  // NOTE: not populated by Initialize()
        std::string JwksUri;
        StringArray SupportedResponseTypes;
        StringArray SupportedResponseModes;
        StringArray SupportedGrantTypes;
        StringArray SupportedScopes;
        StringArray SupportedTokenEndpointAuthMethods;
        StringArray SupportedClaims;
    };

    std::string m_BaseUrl;
    std::string m_ClientId;
    OpenIdConfiguration m_Config;
};
+
+} // namespace zen
diff --git a/src/zenserver/cache/cachetracking.cpp b/src/zenserver/cache/cachetracking.cpp
new file mode 100644
index 000000000..9119e3122
--- /dev/null
+++ b/src/zenserver/cache/cachetracking.cpp
@@ -0,0 +1,376 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "cachetracking.h"
+
+#if ZEN_USE_CACHE_TRACKER
+
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinaryvalue.h>
+# include <zencore/endian.h>
+# include <zencore/filesystem.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/string.h>
+
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# pragma comment(lib, "Rpcrt4.lib") // RocksDB made me do this
+# include <fmt/format.h>
+# include <rocksdb/db.h>
+# include <tsl/robin_map.h>
+# include <tsl/robin_set.h>
+# include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace rocksdb = ROCKSDB_NAMESPACE;
+
+static constinit auto Epoch = std::chrono::time_point<std::chrono::system_clock>{};
+
+static uint64_t
+GetCurrentCacheTimeStamp()
+{
+ auto Duration = std::chrono::system_clock::now() - Epoch;
+ uint64_t Millis = std::chrono::duration_cast<std::chrono::milliseconds>(Duration).count();
+
+ return Millis;
+}
+
+struct CacheAccessSnapshot
+{
+public:
+ void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey)
+ {
+ BucketTracker* Tracker = GetBucket(std::string(BucketSegment));
+
+ Tracker->Track(HashKey);
+ }
+
+ bool SerializeSnapshot(CbObjectWriter& Cbo)
+ {
+ bool Serialized = false;
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (const auto& Kv : m_BucketMapping)
+ {
+ if (m_Buckets[Kv.second]->Size())
+ {
+ Cbo.BeginArray(Kv.first);
+ m_Buckets[Kv.second]->SerializeSnapshotAndClear(Cbo);
+ Cbo.EndArray();
+ Serialized = true;
+ }
+ }
+
+ return Serialized;
+ }
+
+private:
+ struct BucketTracker
+ {
+ mutable RwLock Lock;
+ tsl::robin_set<IoHash> AccessedKeys;
+
+ void Track(const IoHash& HashKey)
+ {
+ if (RwLock::SharedLockScope _(Lock); AccessedKeys.contains(HashKey))
+ {
+ return;
+ }
+
+ RwLock::ExclusiveLockScope _(Lock);
+
+ AccessedKeys.insert(HashKey);
+ }
+
+ void SerializeSnapshotAndClear(CbObjectWriter& Cbo)
+ {
+ RwLock::ExclusiveLockScope _(Lock);
+
+ for (const IoHash& Hash : AccessedKeys)
+ {
+ Cbo.AddHash(Hash);
+ }
+
+ AccessedKeys.clear();
+ }
+
+ size_t Size() const
+ {
+ RwLock::SharedLockScope _(Lock);
+ return AccessedKeys.size();
+ }
+ };
+
+ BucketTracker* GetBucket(const std::string& BucketName)
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end())
+ {
+ _.ReleaseNow();
+
+ return AddNewBucket(BucketName);
+ }
+ else
+ {
+ return m_Buckets[It->second].get();
+ }
+ }
+
+ BucketTracker* AddNewBucket(const std::string& BucketName)
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end())
+ {
+ const uint32_t BucketIndex = gsl::narrow<uint32_t>(m_Buckets.size());
+ m_Buckets.emplace_back(std::make_unique<BucketTracker>());
+ m_BucketMapping[BucketName] = BucketIndex;
+
+ return m_Buckets[BucketIndex].get();
+ }
+ else
+ {
+ return m_Buckets[It->second].get();
+ }
+ }
+
+ RwLock m_Lock;
+ std::vector<std::unique_ptr<BucketTracker>> m_Buckets;
+ tsl::robin_map<std::string, uint32_t> m_BucketMapping;
+};
+
+struct ZenCacheTracker::Impl
+{
+ Impl(std::filesystem::path StateDirectory)
+ {
+ std::filesystem::path StatsDbPath{StateDirectory / ".zdb"};
+
+ std::string RocksdbPath = StatsDbPath.string();
+
+ ZEN_DEBUG("opening tracker db at '{}'", RocksdbPath);
+
+ rocksdb::DB* Db = nullptr;
+ rocksdb::DBOptions Options;
+ Options.create_if_missing = true;
+
+ std::vector<std::string> ExistingColumnFamilies;
+ rocksdb::Status Status = rocksdb::DB::ListColumnFamilies(Options, RocksdbPath, &ExistingColumnFamilies);
+
+ std::vector<rocksdb::ColumnFamilyDescriptor> ColumnDescriptors;
+
+ if (Status.IsPathNotFound())
+ {
+ ColumnDescriptors.emplace_back(rocksdb::ColumnFamilyDescriptor{rocksdb::kDefaultColumnFamilyName, {}});
+ }
+ else if (Status.ok())
+ {
+ for (const std::string& Column : ExistingColumnFamilies)
+ {
+ rocksdb::ColumnFamilyDescriptor ColumnFamily;
+ ColumnFamily.name = Column;
+ ColumnDescriptors.push_back(ColumnFamily);
+ }
+ }
+ else
+ {
+ throw std::runtime_error(fmt::format("column family iteration failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str());
+ }
+
+ Status = rocksdb::DB::Open(Options, RocksdbPath, ColumnDescriptors, &m_RocksDbColumnHandles, &Db);
+
+ if (!Status.ok())
+ {
+ throw std::runtime_error(fmt::format("database open failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str());
+ }
+
+ m_RocksDb.reset(Db);
+ }
+
+ ~Impl()
+ {
+ for (auto* Column : m_RocksDbColumnHandles)
+ {
+ delete Column;
+ }
+
+ m_RocksDbColumnHandles.clear();
+ }
+
+ struct KeyStruct
+ {
+ uint64_t TimestampLittleEndian;
+ };
+
+ void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) { m_CurrentSnapshot.TrackAccess(BucketSegment, HashKey); }
+
+ void SaveSnapshot()
+ {
+ CbObjectWriter Cbo;
+
+ if (m_CurrentSnapshot.SerializeSnapshot(Cbo))
+ {
+ IoBuffer SnapshotBuffer = Cbo.Save().GetBuffer().AsIoBuffer();
+
+ const KeyStruct Key{.TimestampLittleEndian = ToNetworkOrder(GetCurrentCacheTimeStamp())};
+ rocksdb::Slice KeySlice{(const char*)&Key, sizeof Key};
+ rocksdb::Slice ValueSlice{(char*)SnapshotBuffer.Data(), SnapshotBuffer.Size()};
+
+ rocksdb::WriteOptions Wo;
+ m_RocksDb->Put(Wo, KeySlice, ValueSlice);
+ }
+ }
+
+ void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback)
+ {
+ rocksdb::ManagedSnapshot Snap(m_RocksDb.get());
+
+ rocksdb::ReadOptions Ro;
+ Ro.snapshot = Snap.snapshot();
+
+ std::unique_ptr<rocksdb::Iterator> It{m_RocksDb->NewIterator(Ro)};
+
+ const KeyStruct ZeroKey{.TimestampLittleEndian = 0};
+ rocksdb::Slice ZeroKeySlice{(const char*)&ZeroKey, sizeof ZeroKey};
+
+ It->Seek(ZeroKeySlice);
+
+ while (It->Valid())
+ {
+ rocksdb::Slice KeySlice = It->key();
+ rocksdb::Slice ValueSlice = It->value();
+
+ if (KeySlice.size() == sizeof(KeyStruct))
+ {
+ IoBuffer ValueBuffer(IoBuffer::Wrap, ValueSlice.data(), ValueSlice.size());
+
+ CbObject Value = LoadCompactBinaryObject(ValueBuffer);
+
+ uint64_t Key = FromNetworkOrder(*reinterpret_cast<const uint64_t*>(KeySlice.data()));
+
+ Callback(Key, Value);
+ }
+
+ It->Next();
+ }
+ }
+
+ std::unique_ptr<rocksdb::DB> m_RocksDb;
+ std::vector<rocksdb::ColumnFamilyHandle*> m_RocksDbColumnHandles;
+ CacheAccessSnapshot m_CurrentSnapshot;
+};
+
+ZenCacheTracker::ZenCacheTracker(std::filesystem::path StateDirectory) : m_Impl(new Impl(StateDirectory))
+{
+}
+
+ZenCacheTracker::~ZenCacheTracker()
+{
+ delete m_Impl;
+}
+
+void
+ZenCacheTracker::TrackAccess(std::string_view BucketSegment, const IoHash& HashKey)
+{
+ m_Impl->TrackAccess(BucketSegment, HashKey);
+}
+
+void
+ZenCacheTracker::SaveSnapshot()
+{
+ m_Impl->SaveSnapshot();
+}
+
+void
+ZenCacheTracker::IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback)
+{
+ m_Impl->IterateSnapshots(std::move(Callback));
+}
+
+# if ZEN_WITH_TESTS
+
+TEST_CASE("z$.tracker")
+{
+ using namespace std::literals;
+
+ const uint64_t t0 = GetCurrentCacheTimeStamp();
+
+ ScopedTemporaryDirectory TempDir;
+
+ ZenCacheTracker Zcs(TempDir.Path());
+
+ tsl::robin_set<IoHash> KeyHashes;
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i);
+
+ KeyHashes.insert(KeyHash);
+
+ Zcs.TrackAccess("foo"sv, KeyHash);
+ }
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i);
+
+ Zcs.TrackAccess("foo"sv, KeyHash);
+ }
+
+ Zcs.SaveSnapshot();
+
+ for (int n = 0; n < 10; ++n)
+ {
+ for (int i = 0; i < 1000; ++i)
+ {
+ const int Index = i + n * 1000;
+ IoHash KeyHash = IoHash::HashBuffer(&Index, sizeof Index);
+
+ Zcs.TrackAccess("foo"sv, KeyHash);
+ }
+
+ Zcs.SaveSnapshot();
+ }
+
+ Zcs.SaveSnapshot();
+
+ const uint64_t t1 = GetCurrentCacheTimeStamp();
+
+ int SnapshotCount = 0;
+
+ Zcs.IterateSnapshots([&](uint64_t TimeStamp, CbObject Snapshot) {
+ CHECK(TimeStamp >= t0);
+ CHECK(TimeStamp <= t1);
+
+ for (auto& Field : Snapshot)
+ {
+ CHECK_EQ(Field.GetName(), "foo"sv);
+
+ const CbArray& Array = Field.AsArray();
+
+ for (const auto& Element : Array)
+ {
+ CHECK(KeyHashes.contains(Element.GetValue().AsHash()));
+ }
+ }
+
+ ++SnapshotCount;
+ });
+
+ CHECK_EQ(SnapshotCount, 11);
+}
+
+# endif
+
+void
+cachetracker_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_USE_CACHE_TRACKER
diff --git a/src/zenserver/cache/cachetracking.h b/src/zenserver/cache/cachetracking.h
new file mode 100644
index 000000000..fdfe1a4c7
--- /dev/null
+++ b/src/zenserver/cache/cachetracking.h
@@ -0,0 +1,41 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+
+#include <stdint.h>
+#include <filesystem>
+#include <functional>
+
+namespace zen {
+
+#define ZEN_USE_CACHE_TRACKER 0
+#if ZEN_USE_CACHE_TRACKER
+
+class CbObject;
+
+/** Records cache key accesses per bucket and persists timestamped
+ * snapshots to an on-disk database for later iteration. */
+
+class ZenCacheTracker
+{
+public:
+ ZenCacheTracker(std::filesystem::path StateDirectory);
+ ~ZenCacheTracker();
+
+ void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey);
+ void SaveSnapshot();
+ void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback);
+
+private:
+ struct Impl;
+
+ Impl* m_Impl = nullptr;
+};
+
+void cachetracker_forcelink();
+
+#endif // ZEN_USE_CACHE_TRACKER
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcache.cpp b/src/zenserver/cache/structuredcache.cpp
new file mode 100644
index 000000000..90e905bf6
--- /dev/null
+++ b/src/zenserver/cache/structuredcache.cpp
@@ -0,0 +1,3159 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "structuredcache.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/compress.h>
+#include <zencore/enumflags.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zencore/workthreadpool.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/httpshared.h>
+#include <zenutil/cache/cache.h>
+#include <zenutil/cache/rpcrecording.h>
+
+#include "monitoring/httpstats.h"
+#include "structuredcachestore.h"
+#include "upstream/jupiter.h"
+#include "upstream/upstreamcache.h"
+#include "upstream/zen.h"
+#include "zenstore/cidstore.h"
+#include "zenstore/scrubcontext.h"
+
+#include <algorithm>
+#include <atomic>
+#include <filesystem>
+#include <queue>
+#include <thread>
+
+#include <cpr/cpr.h>
+#include <gsl/gsl-lite.hpp>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+CachePolicy
+ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams)
+{
+ std::string_view PolicyText = QueryParams.GetValue("Policy"sv);
+ return !PolicyText.empty() ? zen::ParseCachePolicy(PolicyText) : CachePolicy::Default;
+}
+
+CacheRecordPolicy
+LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default)
+{
+ OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object);
+ return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy);
+}
+
+struct AttachmentCount
+{
+ uint32_t New = 0;
+ uint32_t Valid = 0;
+ uint32_t Invalid = 0;
+ uint32_t Total = 0;
+};
+
+struct PutRequestData
+{
+ std::string Namespace;
+ CacheKey Key;
+ CbObjectView RecordObject;
+ CacheRecordPolicy Policy;
+};
+
+namespace {
+ static constinit std::string_view HttpZCacheRPCPrefix = "$rpc"sv;
+ static constinit std::string_view HttpZCacheUtilStartRecording = "exec$/start-recording"sv;
+ static constinit std::string_view HttpZCacheUtilStopRecording = "exec$/stop-recording"sv;
+ static constinit std::string_view HttpZCacheUtilReplayRecording = "exec$/replay-recording"sv;
+ static constinit std::string_view HttpZCacheDetailsPrefix = "details$"sv;
+
+ struct HttpRequestData
+ {
+ std::optional<std::string> Namespace;
+ std::optional<std::string> Bucket;
+ std::optional<IoHash> HashKey;
+ std::optional<IoHash> ValueContentId;
+ };
+
+ constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+
+ std::optional<std::string> GetValidNamespaceName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Namespace is invalid, empty namespace is not allowed");
+ return {};
+ }
+
+ if (Name.length() > 64)
+ {
+ ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name);
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet))
+ {
+ ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<std::string> GetValidBucketName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed");
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet))
+ {
+ ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<IoHash> GetValidIoHash(std::string_view Hash)
+ {
+ if (Hash.length() != IoHash::StringLength)
+ {
+ return {};
+ }
+
+ IoHash KeyHash;
+ if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash))
+ {
+ return {};
+ }
+ return KeyHash;
+ }
+
+ bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data)
+ {
+ std::vector<std::string_view> Tokens;
+ uint32_t TokenCount = ForEachStrTok(Key, '/', [&](const std::string_view& Token) {
+ Tokens.push_back(Token);
+ return true;
+ });
+
+ switch (TokenCount)
+ {
+ case 0:
+ return true;
+ case 1:
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ return Data.Namespace.has_value();
+ case 2:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ Data.Namespace = ZenCacheStore::DefaultNamespace;
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ case 3:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key/valueid request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ Data.ValueContentId = GetValidIoHash(Tokens[2]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ Data.Namespace = ZenCacheStore::DefaultNamespace;
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey)
+ {
+ return false;
+ }
+ return true;
+ }
+ case 4:
+ {
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey.has_value())
+ {
+ return false;
+ }
+
+ Data.ValueContentId = GetValidIoHash(Tokens[3]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ default:
+ return false;
+ }
+ }
+
+ std::optional<std::string> GetRpcRequestNamespace(const CbObjectView Params)
+ {
+ CbFieldView NamespaceField = Params["Namespace"sv];
+ if (!NamespaceField)
+ {
+ return std::string(ZenCacheStore::DefaultNamespace);
+ }
+
+ if (NamespaceField.HasError())
+ {
+ return {};
+ }
+ if (!NamespaceField.IsString())
+ {
+ return {};
+ }
+ return GetValidNamespaceName(NamespaceField.AsString());
+ }
+
+ bool GetRpcRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key)
+ {
+ CbFieldView BucketField = KeyView["Bucket"sv];
+ if (BucketField.HasError())
+ {
+ return false;
+ }
+ if (!BucketField.IsString())
+ {
+ return false;
+ }
+ std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString());
+ if (!Bucket.has_value())
+ {
+ return false;
+ }
+ CbFieldView HashField = KeyView["Hash"sv];
+ if (HashField.HasError())
+ {
+ return false;
+ }
+ if (!HashField.IsHash())
+ {
+ return false;
+ }
+ IoHash Hash = HashField.AsHash();
+ Key = CacheKey::Create(*Bucket, Hash);
+ return true;
+ }
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCacheStore,
+ CidStore& InCidStore,
+ HttpStatsService& StatsService,
+ HttpStatusService& StatusService,
+ UpstreamCache& UpstreamCache)
+: m_Log(logging::Get("cache"))
+, m_CacheStore(InCacheStore)
+, m_StatsService(StatsService)
+, m_StatusService(StatusService)
+, m_CidStore(InCidStore)
+, m_UpstreamCache(UpstreamCache)
+{
+ m_StatsService.RegisterHandler("z$", *this);
+ m_StatusService.RegisterHandler("z$", *this);
+}
+
+HttpStructuredCacheService::~HttpStructuredCacheService()
+{
+ ZEN_INFO("closing structured cache");
+ m_RequestRecorder.reset();
+
+ m_StatsService.UnregisterHandler("z$", *this);
+ m_StatusService.UnregisterHandler("z$", *this);
+}
+
+const char*
+HttpStructuredCacheService::BaseUri() const
+{
+ return "/z$/";
+}
+
+void
+HttpStructuredCacheService::Flush()
+{
+ m_CacheStore.Flush();
+}
+
+void
+HttpStructuredCacheService::Scrub(ScrubContext& Ctx)
+{
+ if (m_LastScrubTime == Ctx.ScrubTimestamp())
+ {
+ return;
+ }
+
+ m_LastScrubTime = Ctx.ScrubTimestamp();
+
+ m_CidStore.Scrub(Ctx);
+ m_CacheStore.Scrub(Ctx);
+}
+
+void
+HttpStructuredCacheService::HandleDetailsRequest(HttpServerRequest& Request)
+{
+ std::string_view Key = Request.RelativeUri();
+ std::vector<std::string> Tokens;
+ uint32_t TokenCount = ForEachStrTok(Key, '/', [&Tokens](std::string_view Token) {
+ Tokens.push_back(std::string(Token));
+ return true;
+ });
+ std::string FilterNamespace;
+ std::string FilterBucket;
+ std::string FilterValue;
+ switch (TokenCount)
+ {
+ case 1:
+ break;
+ case 2:
+ {
+ FilterNamespace = Tokens[1];
+ if (FilterNamespace.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ }
+ break;
+ case 3:
+ {
+ FilterNamespace = Tokens[1];
+ if (FilterNamespace.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ FilterBucket = Tokens[2];
+ if (FilterBucket.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ }
+ break;
+ case 4:
+ {
+ FilterNamespace = Tokens[1];
+ if (FilterNamespace.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ FilterBucket = Tokens[2];
+ if (FilterBucket.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ FilterValue = Tokens[3];
+ if (FilterValue.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ }
+ break;
+ default:
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+
+ HttpServerRequest::QueryParams Params = Request.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ std::chrono::seconds NowSeconds = std::chrono::duration_cast<std::chrono::seconds>(GcClock::Now().time_since_epoch());
+ CacheValueDetails ValueDetails = m_CacheStore.GetValueDetails(FilterNamespace, FilterBucket, FilterValue);
+
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ if (AttachmentDetails)
+ {
+ CSVWriter << "Namespace, Bucket, Key, Cid, Size";
+ }
+ else if (Details)
+ {
+ CSVWriter << "Namespace, Bucket, Key, Size, RawSize, RawHash, ContentType, Age, AttachmentsCount, AttachmentsSize";
+ }
+ else
+ {
+ CSVWriter << "Namespace, Bucket, Key";
+ }
+ for (const auto& NamespaceIt : ValueDetails.Namespaces)
+ {
+ const std::string& Namespace = NamespaceIt.first;
+ for (const auto& BucketIt : NamespaceIt.second.Buckets)
+ {
+ const std::string& Bucket = BucketIt.first;
+ for (const auto& ValueIt : BucketIt.second.Values)
+ {
+ if (AttachmentDetails)
+ {
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ CSVWriter << "\r\n"
+ << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << Hash.ToHexString()
+ << ", " << gsl::narrow<uint64_t>(Payload.GetSize());
+ }
+ }
+ else if (Details)
+ {
+ std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>(
+ GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch());
+ CSVWriter << "\r\n"
+ << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << ValueIt.second.Size << ","
+ << ValueIt.second.RawSize << "," << ValueIt.second.RawHash.ToHexString() << ", "
+ << ToString(ValueIt.second.ContentType) << ", " << (NowSeconds.count() - LastAccessedSeconds.count())
+ << ", " << gsl::narrow<uint64_t>(ValueIt.second.Attachments.size());
+ size_t AttachmentsSize = 0;
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ AttachmentsSize += Payload.GetSize();
+ }
+ CSVWriter << ", " << gsl::narrow<uint64_t>(AttachmentsSize);
+ }
+ else
+ {
+ CSVWriter << "\r\n" << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString();
+ }
+ }
+ }
+ }
+ return Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("namespaces");
+ {
+ for (const auto& NamespaceIt : ValueDetails.Namespaces)
+ {
+ const std::string& Namespace = NamespaceIt.first;
+ Cbo.BeginObject();
+ {
+ Cbo.AddString("name", Namespace);
+ Cbo.BeginArray("buckets");
+ {
+ for (const auto& BucketIt : NamespaceIt.second.Buckets)
+ {
+ const std::string& Bucket = BucketIt.first;
+ Cbo.BeginObject();
+ {
+ Cbo.AddString("name", Bucket);
+ Cbo.BeginArray("values");
+ {
+ for (const auto& ValueIt : BucketIt.second.Values)
+ {
+ std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>(
+ GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch());
+ Cbo.BeginObject();
+ {
+ Cbo.AddHash("key", ValueIt.first);
+ if (Details)
+ {
+ Cbo.AddInteger("size", ValueIt.second.Size);
+ if (ValueIt.second.Size > 0 && ValueIt.second.RawSize != 0 &&
+ ValueIt.second.RawSize != ValueIt.second.Size)
+ {
+ Cbo.AddInteger("rawsize", ValueIt.second.RawSize);
+ Cbo.AddHash("rawhash", ValueIt.second.RawHash);
+ }
+ Cbo.AddString("contenttype", ToString(ValueIt.second.ContentType));
+ Cbo.AddInteger("age", NowSeconds.count() - LastAccessedSeconds.count());
+ if (ValueIt.second.Attachments.size() > 0)
+ {
+ if (AttachmentDetails)
+ {
+ Cbo.BeginArray("attachments");
+ {
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ Cbo.BeginObject();
+ Cbo.AddHash("cid", Hash);
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ Cbo.AddInteger("size", gsl::narrow<uint64_t>(Payload.GetSize()));
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ }
+ else
+ {
+ Cbo.AddInteger("attachmentcount",
+ gsl::narrow<uint64_t>(ValueIt.second.Attachments.size()));
+ size_t AttachmentsSize = 0;
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ AttachmentsSize += Payload.GetSize();
+ }
+ Cbo.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize));
+ }
+ }
+ }
+ }
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ }
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ }
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+}
+
+void
+HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
+{
+ metrics::OperationTiming::Scope $(m_HttpRequests);
+
+ std::string_view Key = Request.RelativeUri();
+ if (Key == HttpZCacheRPCPrefix)
+ {
+ return HandleRpcRequest(Request);
+ }
+
+ if (Key == HttpZCacheUtilStartRecording)
+ {
+ m_RequestRecorder.reset();
+ HttpServerRequest::QueryParams Params = Request.GetQueryParams();
+ std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path")));
+ m_RequestRecorder = cache::MakeDiskRequestRecorder(RecordPath);
+ Request.WriteResponse(HttpResponseCode::OK);
+ return;
+ }
+ if (Key == HttpZCacheUtilStopRecording)
+ {
+ m_RequestRecorder.reset();
+ Request.WriteResponse(HttpResponseCode::OK);
+ return;
+ }
+ if (Key == HttpZCacheUtilReplayRecording)
+ {
+ m_RequestRecorder.reset();
+ HttpServerRequest::QueryParams Params = Request.GetQueryParams();
+ std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path")));
+ uint32_t ThreadCount = std::thread::hardware_concurrency();
+ if (auto Param = Params.GetValue("thread_count"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<uint64_t>(Param))
+ {
+ ThreadCount = gsl::narrow<uint32_t>(Value.value());
+ }
+ }
+ std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(RecordPath, false));
+ ReplayRequestRecorder(*Replayer, ThreadCount < 1 ? 1 : ThreadCount);
+ Request.WriteResponse(HttpResponseCode::OK);
+ return;
+ }
+ if (Key.starts_with(HttpZCacheDetailsPrefix))
+ {
+ HandleDetailsRequest(Request);
+ return;
+ }
+
+ HttpRequestData RequestData;
+ if (!HttpRequestParseRelativeUri(Key, RequestData))
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+
+ if (RequestData.ValueContentId.has_value())
+ {
+ ZEN_ASSERT(RequestData.Namespace.has_value());
+ ZEN_ASSERT(RequestData.Bucket.has_value());
+ ZEN_ASSERT(RequestData.HashKey.has_value());
+ CacheRef Ref = {.Namespace = RequestData.Namespace.value(),
+ .BucketSegment = RequestData.Bucket.value(),
+ .HashKey = RequestData.HashKey.value(),
+ .ValueContentId = RequestData.ValueContentId.value()};
+ return HandleCacheChunkRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams()));
+ }
+
+ if (RequestData.HashKey.has_value())
+ {
+ ZEN_ASSERT(RequestData.Namespace.has_value());
+ ZEN_ASSERT(RequestData.Bucket.has_value());
+ CacheRef Ref = {.Namespace = RequestData.Namespace.value(),
+ .BucketSegment = RequestData.Bucket.value(),
+ .HashKey = RequestData.HashKey.value(),
+ .ValueContentId = IoHash::Zero};
+ return HandleCacheRecordRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams()));
+ }
+
+ if (RequestData.Bucket.has_value())
+ {
+ ZEN_ASSERT(RequestData.Namespace.has_value());
+ return HandleCacheBucketRequest(Request, RequestData.Namespace.value(), RequestData.Bucket.value());
+ }
+
+ if (RequestData.Namespace.has_value())
+ {
+ return HandleCacheNamespaceRequest(Request, RequestData.Namespace.value());
+ }
+ return HandleCacheRequest(Request);
+}
+
+void
+HttpStructuredCacheService::HandleCacheRequest(HttpServerRequest& Request)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ ZenCacheStore::Info Info = m_CacheStore.GetInfo();
+
+ CbObjectWriter ResponseWriter;
+
+ ResponseWriter.BeginObject("Configuration");
+ {
+ ExtendableStringBuilder<128> BasePathString;
+ BasePathString << Info.Config.BasePath.u8string();
+ ResponseWriter.AddString("BasePath"sv, BasePathString.ToView());
+ ResponseWriter.AddBool("AllowAutomaticCreationOfNamespaces", Info.Config.AllowAutomaticCreationOfNamespaces);
+ }
+ ResponseWriter.EndObject();
+
+ std::sort(begin(Info.NamespaceNames), end(Info.NamespaceNames), [](std::string_view L, std::string_view R) {
+ return L.compare(R) < 0;
+ });
+ ResponseWriter.BeginArray("Namespaces");
+ for (const std::string& NamespaceName : Info.NamespaceNames)
+ {
+ ResponseWriter.AddString(NamespaceName);
+ }
+ ResponseWriter.EndArray();
+ ResponseWriter.BeginObject("StorageSize");
+ {
+ ResponseWriter.AddInteger("DiskSize", Info.StorageSize.DiskSize);
+ ResponseWriter.AddInteger("MemorySize", Info.StorageSize.MemorySize);
+ }
+
+ ResponseWriter.EndObject();
+
+ ResponseWriter.AddInteger("DiskEntryCount", Info.DiskEntryCount);
+ ResponseWriter.AddInteger("MemoryEntryCount", Info.MemoryEntryCount);
+
+ return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view NamespaceName)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ std::optional<ZenCacheNamespace::Info> Info = m_CacheStore.GetNamespaceInfo(NamespaceName);
+ if (!Info.has_value())
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter ResponseWriter;
+
+ ResponseWriter.BeginObject("Configuration");
+ {
+ ExtendableStringBuilder<128> BasePathString;
+ BasePathString << Info->Config.RootDir.u8string();
+ ResponseWriter.AddString("RootDir"sv, BasePathString.ToView());
+ ResponseWriter.AddInteger("DiskLayerThreshold"sv, Info->Config.DiskLayerThreshold);
+ }
+ ResponseWriter.EndObject();
+
+ std::sort(begin(Info->BucketNames), end(Info->BucketNames), [](std::string_view L, std::string_view R) {
+ return L.compare(R) < 0;
+ });
+
+ ResponseWriter.BeginArray("Buckets"sv);
+ for (const std::string& BucketName : Info->BucketNames)
+ {
+ ResponseWriter.AddString(BucketName);
+ }
+ ResponseWriter.EndArray();
+
+ ResponseWriter.BeginObject("StorageSize"sv);
+ {
+ ResponseWriter.AddInteger("DiskSize"sv, Info->DiskLayerInfo.TotalSize);
+ ResponseWriter.AddInteger("MemorySize"sv, Info->MemoryLayerInfo.TotalSize);
+ }
+ ResponseWriter.EndObject();
+
+ ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount);
+ ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount);
+
+ return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ // Drop namespace
+ {
+ if (m_CacheStore.DropNamespace(NamespaceName))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request,
+ std::string_view NamespaceName,
+ std::string_view BucketName)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ std::optional<ZenCacheNamespace::BucketInfo> Info = m_CacheStore.GetBucketInfo(NamespaceName, BucketName);
+ if (!Info.has_value())
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter ResponseWriter;
+
+ ResponseWriter.BeginObject("StorageSize");
+ {
+ ResponseWriter.AddInteger("DiskSize", Info->DiskLayerInfo.TotalSize);
+ ResponseWriter.AddInteger("MemorySize", Info->MemoryLayerInfo.TotalSize);
+ }
+ ResponseWriter.EndObject();
+
+ ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount);
+ ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount);
+
+ return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ // Drop bucket
+ {
+ if (m_CacheStore.DropBucket(NamespaceName, BucketName))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+}
+
+// Verb dispatcher for a single cache-record URL:
+// HEAD/GET read a record, PUT stores one; every other verb is ignored here
+// (no response is written), exactly as before.
+void
+HttpStructuredCacheService::HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+    const HttpVerb Verb = Request.RequestVerb();
+
+    if (Verb == HttpVerb::kHead || Verb == HttpVerb::kGet)
+    {
+        HandleGetCacheRecord(Request, Ref, PolicyFromUrl);
+    }
+    else if (Verb == HttpVerb::kPut)
+    {
+        HandlePutCacheRecord(Request, Ref, PolicyFromUrl);
+    }
+}
+
+// Serves HEAD/GET for a cache record.
+//
+// Flow: (1) if the policy has no Query flag at all, answer 200 immediately;
+// (2) try the local structured cache when QueryLocal is set, converting the
+// stored CbObject into a CbPackage (attachments resolved from the cid store)
+// when the client accepts kCbPackage; (3) on a local miss, either report 404
+// (no QueryRemote) or hand off to an async upstream query so I/O threads are
+// not blocked. PartialRecord tolerates missing attachments; SkipData omits
+// attachment payloads and only checks for their existence.
+void
+HttpStructuredCacheService::HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+    const ZenContentType AcceptType = Request.AcceptContentType();
+    const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData);
+    const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord);
+
+    bool Success = false;
+    ZenCacheValue ClientResultValue;
+    // No query requested (neither local nor remote) -> nothing to look up.
+    if (!EnumHasAnyFlags(PolicyFromUrl, CachePolicy::Query))
+    {
+        return Request.WriteResponse(HttpResponseCode::OK);
+    }
+
+    Stopwatch Timer;
+
+    if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal) &&
+        m_CacheStore.Get(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue))
+    {
+        Success = true;
+        ZenContentType ContentType = ClientResultValue.Value.GetContentType();
+
+        if (AcceptType == ZenContentType::kCbPackage)
+        {
+            // Records are stored as bare CbObjects; rebuild the package by
+            // pulling each referenced attachment out of the cid store.
+            if (ContentType == ZenContentType::kCbObject)
+            {
+                CbPackage Package;
+                uint32_t MissingCount = 0;
+
+                CbObjectView CacheRecord(ClientResultValue.Value.Data());
+                CacheRecord.IterateAttachments([this, &MissingCount, &Package, SkipData](CbFieldView AttachmentHash) {
+                    if (SkipData)
+                    {
+                        // Only verify presence; payloads are not returned.
+                        if (!m_CidStore.ContainsChunk(AttachmentHash.AsHash()))
+                        {
+                            MissingCount++;
+                        }
+                    }
+                    else
+                    {
+                        if (IoBuffer Chunk = m_CidStore.FindChunkByCid(AttachmentHash.AsHash()))
+                        {
+                            CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+                            Package.AddAttachment(CbAttachment(Compressed, AttachmentHash.AsHash()));
+                        }
+                        else
+                        {
+                            MissingCount++;
+                        }
+                    }
+                });
+
+                // A record with missing attachments only counts as a hit when
+                // the client explicitly allows partial records.
+                Success = MissingCount == 0 || PartialRecord;
+
+                if (Success)
+                {
+                    Package.SetObject(LoadCompactBinaryObject(ClientResultValue.Value));
+
+                    BinaryWriter MemStream;
+                    Package.Save(MemStream);
+
+                    ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+                    ClientResultValue.Value.SetContentType(HttpContentType::kCbPackage);
+                }
+            }
+            else
+            {
+                // Stored value is not a CbObject, so it cannot be packaged.
+                Success = false;
+            }
+        }
+        else if (AcceptType != ClientResultValue.Value.GetContentType() && AcceptType != ZenContentType::kUnknownContentType &&
+                 AcceptType != ZenContentType::kBinary)
+        {
+            // Accept type is incompatible with what we have stored.
+            Success = false;
+        }
+    }
+
+    if (Success)
+    {
+        ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (LOCAL) in {}",
+                  Ref.Namespace,
+                  Ref.BucketSegment,
+                  Ref.HashKey,
+                  NiceBytes(ClientResultValue.Value.Size()),
+                  ToString(ClientResultValue.Value.GetContentType()),
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+        m_CacheStats.HitCount++;
+        if (SkipData && AcceptType != ZenContentType::kCbPackage && AcceptType != ZenContentType::kCbObject)
+        {
+            return Request.WriteResponse(HttpResponseCode::OK);
+        }
+        else
+        {
+            // kCbPackage handled SkipData when constructing the ClientResultValue, kcbObject ignores SkipData
+            return Request.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value);
+        }
+    }
+    else if (!EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryRemote))
+    {
+        // Local miss and upstream querying is not allowed -> definitive miss.
+        ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}",
+                  Ref.Namespace,
+                  Ref.BucketSegment,
+                  Ref.HashKey,
+                  ToString(AcceptType),
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+        m_CacheStats.MissCount++;
+        return Request.WriteResponse(HttpResponseCode::NotFound);
+    }
+
+    // Issue upstream query asynchronously in order to keep requests flowing without
+    // hogging I/O servicing threads with blocking work
+
+    uint64_t LocalElapsedTimeUs = Timer.GetElapsedTimeUs();
+
+    Request.WriteResponseAsync([this, AcceptType, PolicyFromUrl, Ref, LocalElapsedTimeUs](HttpServerRequest& AsyncRequest) {
+        Stopwatch Timer;
+        bool Success = false;
+        const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord);
+        const bool QueryLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal);
+        const bool StoreLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreLocal);
+        const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData);
+        ZenCacheValue ClientResultValue;
+
+        metrics::OperationTiming::Scope $(m_UpstreamGetRequestTiming);
+
+        if (GetUpstreamCacheSingleResult UpstreamResult =
+                m_UpstreamCache.GetCacheRecord(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, AcceptType);
+            UpstreamResult.Status.Success)
+        {
+            Success = true;
+
+            ClientResultValue.Value = UpstreamResult.Value;
+            ClientResultValue.Value.SetContentType(AcceptType);
+
+            if (AcceptType == ZenContentType::kBinary || AcceptType == ZenContentType::kCbObject)
+            {
+                if (AcceptType == ZenContentType::kCbObject)
+                {
+                    // Never trust upstream blobs: validate before storing/serving.
+                    const CbValidateError ValidationResult = ValidateCompactBinary(UpstreamResult.Value, CbValidateMode::All);
+                    if (ValidationResult != CbValidateError::None)
+                    {
+                        Success = false;
+                        ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid compact binary object from upstream",
+                                 Ref.Namespace,
+                                 Ref.BucketSegment,
+                                 Ref.HashKey,
+                                 ToString(AcceptType));
+                    }
+
+                    // We do not do anything to the returned object for SkipData, only package attachments are cut when skipping data
+                }
+
+                if (Success && StoreLocal)
+                {
+                    m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue);
+                }
+            }
+            else if (AcceptType == ZenContentType::kCbPackage)
+            {
+                CbPackage Package;
+                if (Package.TryLoad(ClientResultValue.Value))
+                {
+                    CbObject CacheRecord = Package.GetObject();
+                    AttachmentCount Count;
+                    size_t NumAttachments = Package.GetAttachments().size();
+                    std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+                    AttachmentsToStoreLocally.reserve(NumAttachments);
+
+                    // Classify each referenced attachment: supplied by upstream,
+                    // already present locally, or missing.
+                    CacheRecord.IterateAttachments(
+                        [this, &Package, &Ref, &AttachmentsToStoreLocally, &Count, QueryLocal, StoreLocal, SkipData](CbFieldView HashView) {
+                            IoHash Hash = HashView.AsHash();
+                            if (const CbAttachment* Attachment = Package.FindAttachment(Hash))
+                            {
+                                if (Attachment->IsCompressedBinary())
+                                {
+                                    if (StoreLocal)
+                                    {
+                                        AttachmentsToStoreLocally.emplace_back(Attachment);
+                                    }
+                                    Count.Valid++;
+                                }
+                                else
+                                {
+                                    ZEN_WARN("Uncompressed value '{}' from upstream cache record '{}/{}'",
+                                             Hash,
+                                             Ref.BucketSegment,
+                                             Ref.HashKey);
+                                    Count.Invalid++;
+                                }
+                            }
+                            else if (QueryLocal)
+                            {
+                                // Upstream did not carry it; try the local cid store.
+                                if (SkipData)
+                                {
+                                    if (m_CidStore.ContainsChunk(Hash))
+                                    {
+                                        Count.Valid++;
+                                    }
+                                }
+                                else if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Hash))
+                                {
+                                    CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+                                    if (Compressed)
+                                    {
+                                        Package.AddAttachment(CbAttachment(Compressed, Hash));
+                                        Count.Valid++;
+                                    }
+                                    else
+                                    {
+                                        ZEN_WARN("Uncompressed value '{}' stored in local cache '{}/{}'",
+                                                 Hash,
+                                                 Ref.BucketSegment,
+                                                 Ref.HashKey);
+                                        Count.Invalid++;
+                                    }
+                                }
+                            }
+                            Count.Total++;
+                        });
+
+                    if ((Count.Valid == Count.Total) || PartialRecord)
+                    {
+                        ZenCacheValue CacheValue;
+                        CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer();
+                        CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+
+                        if (StoreLocal)
+                        {
+                            m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue);
+                        }
+
+                        // Persist the upstream-delivered attachments locally.
+                        for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+                        {
+                            CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+                            CidStore::InsertResult InsertResult =
+                                m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+                            if (InsertResult.New)
+                            {
+                                Count.New++;
+                            }
+                        }
+
+                        BinaryWriter MemStream;
+                        if (SkipData)
+                        {
+                            // Save a package containing only the object.
+                            CbPackage(Package.GetObject()).Save(MemStream);
+                        }
+                        else
+                        {
+                            Package.Save(MemStream);
+                        }
+
+                        ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+                        ClientResultValue.Value.SetContentType(ZenContentType::kCbPackage);
+                    }
+                    else
+                    {
+                        Success = false;
+                        ZEN_WARN("Get - '{}/{}' '{}' FAILED, attachments missing in upstream package",
+                                 Ref.BucketSegment,
+                                 Ref.HashKey,
+                                 ToString(AcceptType));
+                    }
+                }
+                else
+                {
+                    Success = false;
+                    ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid upstream package",
+                             Ref.Namespace,
+                             Ref.BucketSegment,
+                             Ref.HashKey,
+                             ToString(AcceptType));
+                }
+            }
+        }
+
+        if (Success)
+        {
+            ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (UPSTREAM) in {}",
+                      Ref.Namespace,
+                      Ref.BucketSegment,
+                      Ref.HashKey,
+                      NiceBytes(ClientResultValue.Value.Size()),
+                      ToString(ClientResultValue.Value.GetContentType()),
+                      NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000));
+
+            m_CacheStats.HitCount++;
+            m_CacheStats.UpstreamHitCount++;
+
+            if (SkipData && AcceptType == ZenContentType::kBinary)
+            {
+                AsyncRequest.WriteResponse(HttpResponseCode::OK);
+            }
+            else
+            {
+                // Other methods modify ClientResultValue to a version that has skipped the data but keeps the Object and optionally
+                // metadata.
+                AsyncRequest.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value);
+            }
+        }
+        else
+        {
+            ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}",
+                      Ref.Namespace,
+                      Ref.BucketSegment,
+                      Ref.HashKey,
+                      ToString(AcceptType),
+                      NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000));
+            m_CacheStats.MissCount++;
+            AsyncRequest.WriteResponse(HttpResponseCode::NotFound);
+        }
+    });
+}
+
+// Stores a cache record uploaded via PUT.
+//
+// Accepted payload content types:
+//  - kBinary / kCompressedBinary: stored verbatim (compressed payloads must
+//    carry a valid compressed-buffer header).
+//  - kCbObject: validated compact binary; attachments are checked for local
+//    existence only.
+//  - kCbPackage: record object plus inline attachments; attachments must be
+//    compressed, and inline ones are stored into the cid store.
+// Records missing attachments are kept locally but never propagated upstream
+// (StoreRemote is skipped for partial records). Anything else -> 400.
+void
+HttpStructuredCacheService::HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+    IoBuffer Body = Request.ReadPayload();
+
+    if (!Body || Body.Size() == 0)
+    {
+        return Request.WriteResponse(HttpResponseCode::BadRequest);
+    }
+
+    const HttpContentType ContentType = Request.RequestContentType();
+
+    Body.SetContentType(ContentType);
+
+    Stopwatch Timer;
+
+    if (ContentType == HttpContentType::kBinary || ContentType == HttpContentType::kCompressedBinary)
+    {
+        IoHash RawHash = IoHash::Zero;
+        uint64_t RawSize = Body.GetSize();
+        if (ContentType == HttpContentType::kCompressedBinary)
+        {
+            // The header carries the raw (uncompressed) hash and size.
+            if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize))
+            {
+                return Request.WriteResponse(HttpResponseCode::BadRequest,
+                                             HttpContentType::kText,
+                                             "Payload is not a valid compressed binary"sv);
+            }
+        }
+        else
+        {
+            RawHash = IoHash::HashBuffer(SharedBuffer(Body));
+        }
+        m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body, .RawSize = RawSize, .RawHash = RawHash});
+
+        if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreRemote))
+        {
+            m_UpstreamCache.EnqueueUpstream({.Type = ContentType, .Namespace = Ref.Namespace, .Key = {Ref.BucketSegment, Ref.HashKey}});
+        }
+
+        ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' in {}",
+                  Ref.Namespace,
+                  Ref.BucketSegment,
+                  Ref.HashKey,
+                  NiceBytes(Body.Size()),
+                  ToString(ContentType),
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+        Request.WriteResponse(HttpResponseCode::Created);
+    }
+    else if (ContentType == HttpContentType::kCbObject)
+    {
+        const CbValidateError ValidationResult = ValidateCompactBinary(MemoryView(Body.GetData(), Body.GetSize()), CbValidateMode::All);
+
+        if (ValidationResult != CbValidateError::None)
+        {
+            ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid compact binary",
+                     Ref.Namespace,
+                     Ref.BucketSegment,
+                     Ref.HashKey,
+                     ToString(ContentType));
+            return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Compact binary validation failed"sv);
+        }
+
+        Body.SetContentType(ZenContentType::kCbObject);
+        m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body});
+
+        // No attachments are carried inline here; just check which referenced
+        // attachments already exist locally.
+        CbObjectView CacheRecord(Body.Data());
+        std::vector<IoHash> ValidAttachments;
+        int32_t TotalCount = 0;
+
+        CacheRecord.IterateAttachments([this, &TotalCount, &ValidAttachments](CbFieldView AttachmentHash) {
+            const IoHash Hash = AttachmentHash.AsHash();
+            if (m_CidStore.ContainsChunk(Hash))
+            {
+                ValidAttachments.emplace_back(Hash);
+            }
+            TotalCount++;
+        });
+
+        // BUGFIX: arguments were passed as (total, valid) while the format
+        // string documents '(valid/total)' -- swap to match.
+        ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' attachments '{}/{}' (valid/total) in {}",
+                  Ref.Namespace,
+                  Ref.BucketSegment,
+                  Ref.HashKey,
+                  NiceBytes(Body.Size()),
+                  ToString(ContentType),
+                  ValidAttachments.size(),
+                  TotalCount,
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+        const bool IsPartialRecord = TotalCount != static_cast<int32_t>(ValidAttachments.size());
+
+        CachePolicy Policy = PolicyFromUrl;
+        if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord)
+        {
+            m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbObject,
+                                             .Namespace = Ref.Namespace,
+                                             .Key = {Ref.BucketSegment, Ref.HashKey},
+                                             .ValueContentIds = std::move(ValidAttachments)});
+        }
+
+        Request.WriteResponse(HttpResponseCode::Created);
+    }
+    else if (ContentType == HttpContentType::kCbPackage)
+    {
+        CbPackage Package;
+
+        if (!Package.TryLoad(Body))
+        {
+            ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid package",
+                     Ref.Namespace,
+                     Ref.BucketSegment,
+                     Ref.HashKey,
+                     ToString(ContentType));
+            return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"sv);
+        }
+        CachePolicy Policy = PolicyFromUrl;
+
+        CbObject CacheRecord = Package.GetObject();
+
+        AttachmentCount Count;
+        size_t NumAttachments = Package.GetAttachments().size();
+        std::vector<IoHash> ValidAttachments;
+        std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+        ValidAttachments.reserve(NumAttachments);
+        AttachmentsToStoreLocally.reserve(NumAttachments);
+
+        // Valid attachments are either carried inline (and compressed) or
+        // already present in the local cid store.
+        CacheRecord.IterateAttachments([this, &Ref, &Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count](CbFieldView HashView) {
+            const IoHash Hash = HashView.AsHash();
+            if (const CbAttachment* Attachment = Package.FindAttachment(Hash))
+            {
+                if (Attachment->IsCompressedBinary())
+                {
+                    AttachmentsToStoreLocally.emplace_back(Attachment);
+                    ValidAttachments.emplace_back(Hash);
+                    Count.Valid++;
+                }
+                else
+                {
+                    ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed",
+                             Ref.Namespace,
+                             Ref.BucketSegment,
+                             Ref.HashKey,
+                             ToString(HttpContentType::kCbPackage),
+                             Hash);
+                    Count.Invalid++;
+                }
+            }
+            else if (m_CidStore.ContainsChunk(Hash))
+            {
+                ValidAttachments.emplace_back(Hash);
+                Count.Valid++;
+            }
+            Count.Total++;
+        });
+
+        if (Count.Invalid > 0)
+        {
+            return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid attachment(s)"sv);
+        }
+
+        // Store the record object itself...
+        ZenCacheValue CacheValue;
+        CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer();
+        CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+        m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue);
+
+        // ...and every inline attachment.
+        for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+        {
+            CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+            CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+            if (InsertResult.New)
+            {
+                Count.New++;
+            }
+        }
+
+        ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}', attachments '{}/{}/{}' (new/valid/total) in {}",
+                  Ref.Namespace,
+                  Ref.BucketSegment,
+                  Ref.HashKey,
+                  NiceBytes(Body.GetSize()),
+                  ToString(ContentType),
+                  Count.New,
+                  Count.Valid,
+                  Count.Total,
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+        const bool IsPartialRecord = Count.Valid != Count.Total;
+
+        if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord)
+        {
+            m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage,
+                                             .Namespace = Ref.Namespace,
+                                             .Key = {Ref.BucketSegment, Ref.HashKey},
+                                             .ValueContentIds = std::move(ValidAttachments)});
+        }
+
+        Request.WriteResponse(HttpResponseCode::Created);
+    }
+    else
+    {
+        return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Content-Type invalid"sv);
+    }
+}
+
+// Verb dispatcher for a single cache-chunk (attachment) URL:
+// HEAD/GET fetch a chunk, PUT uploads one; other verbs are ignored
+// (no response is written), exactly as before.
+void
+HttpStructuredCacheService::HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+    const HttpVerb Verb = Request.RequestVerb();
+
+    if (Verb == HttpVerb::kHead || Verb == HttpVerb::kGet)
+    {
+        HandleGetCacheChunk(Request, Ref, PolicyFromUrl);
+    }
+    else if (Verb == HttpVerb::kPut)
+    {
+        HandlePutCacheChunk(Request, Ref, PolicyFromUrl);
+    }
+}
+
+// Serves HEAD/GET for a single cache chunk (attachment).
+//
+// Looks the chunk up in the local cid store first; on a miss, and when the
+// policy allows QueryRemote, fetches it from the upstream cache, validates the
+// compressed header and hash, and stores it locally. SkipData answers 200
+// without a body when the chunk exists.
+void
+HttpStructuredCacheService::HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+    Stopwatch Timer;
+
+    IoBuffer Value = m_CidStore.FindChunkByCid(Ref.ValueContentId);
+    const UpstreamEndpointInfo* Source = nullptr;
+    CachePolicy Policy = PolicyFromUrl;
+    {
+        const bool QueryUpstream = !Value && EnumHasAllFlags(Policy, CachePolicy::QueryRemote);
+
+        if (QueryUpstream)
+        {
+            if (GetUpstreamCacheSingleResult UpstreamResult =
+                    m_UpstreamCache.GetCacheChunk(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, Ref.ValueContentId);
+                UpstreamResult.Status.Success)
+            {
+                IoHash RawHash;
+                uint64_t RawSize;
+                if (CompressedBuffer::ValidateCompressedHeader(UpstreamResult.Value, RawHash, RawSize))
+                {
+                    if (RawHash == Ref.ValueContentId)
+                    {
+                        m_CidStore.AddChunk(UpstreamResult.Value, RawHash);
+                        // BUGFIX: Value was never assigned here, so a successful
+                        // upstream fetch still reported a MISS / 404 below.
+                        Value = std::move(UpstreamResult.Value);
+                        Source = UpstreamResult.Source;
+                    }
+                    else
+                    {
+                        ZEN_WARN("got mismatching upstream cache value");
+                    }
+                }
+                else
+                {
+                    ZEN_WARN("got uncompressed upstream cache value");
+                }
+            }
+        }
+    }
+
+    if (!Value)
+    {
+        ZEN_DEBUG("GETCACHECHUNK MISS - '{}/{}/{}/{}' '{}' in {}",
+                  Ref.Namespace,
+                  Ref.BucketSegment,
+                  Ref.HashKey,
+                  Ref.ValueContentId,
+                  ToString(Request.AcceptContentType()),
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+        m_CacheStats.MissCount++;
+        return Request.WriteResponse(HttpResponseCode::NotFound);
+    }
+
+    ZEN_DEBUG("GETCACHECHUNK HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+              Ref.Namespace,
+              Ref.BucketSegment,
+              Ref.HashKey,
+              Ref.ValueContentId,
+              NiceBytes(Value.Size()),
+              ToString(Value.GetContentType()),
+              Source ? Source->Url : "LOCAL"sv,
+              NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+    m_CacheStats.HitCount++;
+    if (Source)
+    {
+        m_CacheStats.UpstreamHitCount++;
+    }
+
+    if (EnumHasAllFlags(Policy, CachePolicy::SkipData))
+    {
+        Request.WriteResponse(HttpResponseCode::OK);
+    }
+    else
+    {
+        Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value);
+    }
+}
+
+// Accepts a PUT of a single compressed cache chunk (attachment).
+// The payload must carry a valid compressed-buffer header whose raw hash
+// matches the ValueContentId from the URL; otherwise 400 is returned.
+// Responds 201 for a newly stored chunk and 200 for an already-known one.
+void
+HttpStructuredCacheService::HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+    // Note: Individual cacherecord values are not propagated upstream until a valid cache record has been stored
+    ZEN_UNUSED(PolicyFromUrl);
+
+    Stopwatch RequestTimer;
+
+    IoBuffer Payload = Request.ReadPayload();
+    if (!Payload || Payload.Size() == 0)
+    {
+        return Request.WriteResponse(HttpResponseCode::BadRequest);
+    }
+    Payload.SetContentType(Request.RequestContentType());
+
+    // Decode the compressed header; rejects anything that is not a compressed binary.
+    IoHash DecodedHash;
+    uint64_t DecodedRawSize;
+    if (!CompressedBuffer::ValidateCompressedHeader(Payload, DecodedHash, DecodedRawSize))
+    {
+        return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Attachments must be compressed"sv);
+    }
+    if (DecodedHash != Ref.ValueContentId)
+    {
+        return Request.WriteResponse(HttpResponseCode::BadRequest,
+                                     HttpContentType::kText,
+                                     "ValueContentId does not match attachment hash"sv);
+    }
+
+    const CidStore::InsertResult Inserted = m_CidStore.AddChunk(Payload, DecodedHash);
+
+    ZEN_DEBUG("PUTCACHECHUNK - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+              Ref.Namespace,
+              Ref.BucketSegment,
+              Ref.HashKey,
+              Ref.ValueContentId,
+              NiceBytes(Payload.Size()),
+              ToString(Payload.GetContentType()),
+              Inserted.New ? "NEW" : "OLD",
+              NiceLatencyNs(RequestTimer.GetElapsedTimeUs() * 1000));
+
+    if (Inserted.New)
+    {
+        Request.WriteResponse(HttpResponseCode::Created);
+    }
+    else
+    {
+        Request.WriteResponse(HttpResponseCode::OK);
+    }
+}
+
+// Decodes an RPC envelope (either a bare CbObject or a CbPackage), extracts
+// the Accept magic/flags and target pid, and dispatches on the "Method"
+// field. Returns an empty package for unknown methods.
+CbPackage
+HttpStructuredCacheService::HandleRpcRequest(const ZenContentType ContentType,
+                                             IoBuffer&&           Body,
+                                             uint32_t&            OutAcceptMagic,
+                                             RpcAcceptOptions&    OutAcceptFlags,
+                                             int&                 OutTargetProcessId)
+{
+    CbPackage   RequestPackage;
+    CbObject    OwnedObject;
+    CbObjectView Envelope;
+
+    if (ContentType == ZenContentType::kCbObject)
+    {
+        // Bare object: take ownership of the decoded buffer.
+        OwnedObject = LoadCompactBinaryObject(std::move(Body));
+        Envelope = OwnedObject;
+    }
+    else
+    {
+        // Package: the envelope is the package's root object.
+        RequestPackage = ParsePackageMessage(Body);
+        Envelope = RequestPackage.GetObject();
+    }
+
+    OutAcceptMagic     = Envelope["Accept"sv].AsUInt32();
+    OutAcceptFlags     = static_cast<RpcAcceptOptions>(Envelope["AcceptFlags"sv].AsUInt16(0u));
+    OutTargetProcessId = Envelope["Pid"sv].AsInt32(0);
+
+    const std::string_view Method = Envelope["Method"sv].AsString();
+
+    if (Method == "PutCacheRecords"sv)
+    {
+        return HandleRpcPutCacheRecords(RequestPackage);
+    }
+    if (Method == "GetCacheRecords"sv)
+    {
+        return HandleRpcGetCacheRecords(Envelope);
+    }
+    if (Method == "PutCacheValues"sv)
+    {
+        return HandleRpcPutCacheValues(RequestPackage);
+    }
+    if (Method == "GetCacheValues"sv)
+    {
+        return HandleRpcGetCacheValues(Envelope);
+    }
+    if (Method == "GetCacheChunks"sv)
+    {
+        return HandleRpcGetCacheChunks(Envelope);
+    }
+
+    // Unknown method.
+    return CbPackage{};
+}
+
+// Replays every recorded RPC request through HandleRpcRequest on a worker
+// pool of ThreadCount threads, formatting (but discarding) each response.
+// Used to drive the cache with a previously captured request stream; progress
+// is logged every 10 seconds until the latch drains.
+void
+HttpStructuredCacheService::ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount)
+{
+    WorkerThreadPool WorkerPool(ThreadCount);
+    uint64_t RequestCount = Replayer.GetRequestCount();
+    Stopwatch Timer;
+    // Scope guard: emit the summary line even on early exit.
+    auto _ = MakeGuard([&]() { ZEN_INFO("Replayed {} requests in {}", RequestCount, NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); });
+    // One latch count per scheduled request; Wait() below blocks until all jobs ran.
+    Latch JobLatch(RequestCount);
+    ZEN_INFO("Replaying {} requests", RequestCount);
+    for (uint64_t RequestIndex = 0; RequestIndex < RequestCount; ++RequestIndex)
+    {
+        // JobLatch/Replayer are captured by reference: both outlive the pool,
+        // which is joined before this function returns.
+        WorkerPool.ScheduleWork([this, &JobLatch, &Replayer, RequestIndex]() {
+            IoBuffer Body;
+            std::pair<ZenContentType, ZenContentType> ContentType = Replayer.GetRequest(RequestIndex, Body);
+            if (Body)
+            {
+                uint32_t AcceptMagic = 0;
+                RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone;
+                int TargetPid = 0;
+                CbPackage RpcResult = HandleRpcRequest(ContentType.first, std::move(Body), AcceptMagic, AcceptFlags, TargetPid);
+                // Format the response exactly like the HTTP path would, so the
+                // replay exercises serialization too; the result is discarded.
+                if (AcceptMagic == kCbPkgMagic)
+                {
+                    FormatFlags Flags = FormatFlags::kDefault;
+                    if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences))
+                    {
+                        Flags |= FormatFlags::kAllowLocalReferences;
+                        if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences))
+                        {
+                            Flags |= FormatFlags::kDenyPartialLocalReferences;
+                        }
+                    }
+                    CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetPid);
+                    ZEN_ASSERT(RpcResponseBuffer.GetSize() > 0);
+                }
+                else
+                {
+                    BinaryWriter MemStream;
+                    RpcResult.Save(MemStream);
+                    // Wrap (not Clone): the buffer only borrows MemStream's
+                    // memory, which is valid for the assert below.
+                    IoBuffer RpcResponseBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize());
+                    ZEN_ASSERT(RpcResponseBuffer.Size() > 0);
+                }
+            }
+            JobLatch.CountDown();
+        });
+    }
+    // Poll with a 10 s timeout purely to emit periodic progress logging.
+    while (!JobLatch.Wait(10000))
+    {
+        ZEN_INFO("Replayed {} of {} requests, elapsed {}",
+                 RequestCount - JobLatch.Remaining(),
+                 RequestCount,
+                 NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+    }
+}
+
+// HTTP entry point for the batched RPC API (POST only).
+// Validates content types (body must be CbObject or CbPackage, client must
+// accept CbPackage), then processes the request asynchronously: optionally
+// records request/response via m_RequestRecorder, dispatches through the
+// content-type HandleRpcRequest overload, and serializes the reply either as
+// a framed package message (when the client sent the kCbPkgMagic accept
+// magic) or as a plain saved CbPackage.
+void
+HttpStructuredCacheService::HandleRpcRequest(HttpServerRequest& Request)
+{
+    switch (Request.RequestVerb())
+    {
+        case HttpVerb::kPost:
+            {
+                const HttpContentType ContentType = Request.RequestContentType();
+                const HttpContentType AcceptType = Request.AcceptContentType();
+
+                if ((ContentType != HttpContentType::kCbObject && ContentType != HttpContentType::kCbPackage) ||
+                    AcceptType != HttpContentType::kCbPackage)
+                {
+                    return Request.WriteResponse(HttpResponseCode::BadRequest);
+                }
+
+                // Payload is read here (on the I/O thread) and moved into the
+                // async lambda; ~0ull marks "recording disabled".
+                Request.WriteResponseAsync(
+                    [this, Body = Request.ReadPayload(), ContentType, AcceptType](HttpServerRequest& AsyncRequest) mutable {
+                        std::uint64_t RequestIndex =
+                            m_RequestRecorder ? m_RequestRecorder->RecordRequest(ContentType, AcceptType, Body) : ~0ull;
+                        uint32_t AcceptMagic = 0;
+                        RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone;
+                        int TargetProcessId = 0;
+                        CbPackage RpcResult = HandleRpcRequest(ContentType, std::move(Body), AcceptMagic, AcceptFlags, TargetProcessId);
+                        if (RpcResult.IsNull())
+                        {
+                            // Unknown method / malformed request.
+                            AsyncRequest.WriteResponse(HttpResponseCode::BadRequest);
+                            return;
+                        }
+                        if (AcceptMagic == kCbPkgMagic)
+                        {
+                            // Client speaks the framed package protocol.
+                            FormatFlags Flags = FormatFlags::kDefault;
+                            if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences))
+                            {
+                                Flags |= FormatFlags::kAllowLocalReferences;
+                                if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences))
+                                {
+                                    Flags |= FormatFlags::kDenyPartialLocalReferences;
+                                }
+                            }
+                            CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetProcessId);
+                            if (RequestIndex != ~0ull)
+                            {
+                                ZEN_ASSERT(m_RequestRecorder);
+                                m_RequestRecorder->RecordResponse(RequestIndex, HttpContentType::kCbPackage, RpcResponseBuffer);
+                            }
+                            AsyncRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer);
+                        }
+                        else
+                        {
+                            // Plain CbPackage serialization.
+                            BinaryWriter MemStream;
+                            RpcResult.Save(MemStream);
+
+                            if (RequestIndex != ~0ull)
+                            {
+                                ZEN_ASSERT(m_RequestRecorder);
+                                m_RequestRecorder->RecordResponse(RequestIndex,
+                                                                  HttpContentType::kCbPackage,
+                                                                  IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+                            }
+                            AsyncRequest.WriteResponse(HttpResponseCode::OK,
+                                                       HttpContentType::kCbPackage,
+                                                       IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+                        }
+                    });
+            }
+            break;
+        default:
+            Request.WriteResponse(HttpResponseCode::BadRequest);
+            break;
+    }
+}
+
+// RPC handler for the "PutCacheRecords" batch method.
+// Parses the default policy and namespace from Params, stores each record in
+// the batch via PutCacheRecord, and returns a package whose object carries a
+// "Result" bool array (one entry per record). Any invalid record (or an empty
+// batch) yields an empty package, which the caller turns into a 400.
+CbPackage
+HttpStructuredCacheService::HandleRpcPutCacheRecords(const CbPackage& BatchRequest)
+{
+    ZEN_TRACE_CPU("Z$::RpcPutCacheRecords");
+    CbObjectView BatchObject = BatchRequest.GetObject();
+    ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv);
+
+    CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+
+    // ""sv key for consistency with HandleRpcGetCacheRecords.
+    std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+    std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+    if (!Namespace)
+    {
+        return CbPackage{};
+    }
+    // Initialize at declaration instead of leaving the policy uninitialized.
+    const CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+    std::vector<bool> Results;
+    for (CbFieldView RequestField : Params["Requests"sv])
+    {
+        CbObjectView RequestObject = RequestField.AsObjectView();
+        CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView();
+        CbObjectView KeyView = RecordObject["Key"sv].AsObjectView();
+
+        CacheKey Key;
+        if (!GetRpcRequestCacheKey(KeyView, Key))
+        {
+            return CbPackage{};
+        }
+        CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+        PutRequestData PutRequest{*Namespace, std::move(Key), RecordObject, std::move(Policy)};
+
+        PutResult Result = PutCacheRecord(PutRequest, &BatchRequest);
+
+        if (Result == PutResult::Invalid)
+        {
+            return CbPackage{};
+        }
+        Results.push_back(Result == PutResult::Success);
+    }
+    if (Results.empty())
+    {
+        return CbPackage{};
+    }
+
+    CbObjectWriter ResponseObject;
+    ResponseObject.BeginArray("Result"sv);
+    for (bool Value : Results)
+    {
+        ResponseObject.AddBool(Value);
+    }
+    ResponseObject.EndArray();
+
+    CbPackage RpcResponse;
+    RpcResponse.SetObject(ResponseObject.Save());
+    return RpcResponse;
+}
+
+// Stores one cache record (and its inline attachments) from an RPC batch.
+//
+// Package may be null (FindAttachment is skipped in that case); attachments
+// are valid when carried inline as compressed binaries or already present in
+// the local cid store. Returns Invalid on any uncompressed inline attachment,
+// Success otherwise; complete records with StoreRemote policy are queued for
+// upstream propagation.
+HttpStructuredCacheService::PutResult
+HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPackage* Package)
+{
+    CbObjectView Record = Request.RecordObject;
+    uint64_t RecordObjectSize = Record.GetSize();
+    uint64_t TransferredSize = RecordObjectSize;
+
+    AttachmentCount Count;
+    // BUGFIX: guard the dereference -- Package is checked for null everywhere
+    // else in this function, but was dereferenced unconditionally here.
+    size_t NumAttachments = Package ? Package->GetAttachments().size() : 0;
+    std::vector<IoHash> ValidAttachments;
+    std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+    ValidAttachments.reserve(NumAttachments);
+    AttachmentsToStoreLocally.reserve(NumAttachments);
+
+    Stopwatch Timer;
+
+    Request.RecordObject.IterateAttachments(
+        [this, &Request, Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count, &TransferredSize](CbFieldView HashView) {
+            const IoHash ValueHash = HashView.AsHash();
+            if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr)
+            {
+                if (Attachment->IsCompressedBinary())
+                {
+                    AttachmentsToStoreLocally.emplace_back(Attachment);
+                    ValidAttachments.emplace_back(ValueHash);
+                    Count.Valid++;
+                }
+                else
+                {
+                    // Fixed log tag typo ("PUTCACEHRECORD").
+                    ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed",
+                             Request.Namespace,
+                             Request.Key.Bucket,
+                             Request.Key.Hash,
+                             ToString(HttpContentType::kCbPackage),
+                             ValueHash);
+                    Count.Invalid++;
+                }
+            }
+            else if (m_CidStore.ContainsChunk(ValueHash))
+            {
+                ValidAttachments.emplace_back(ValueHash);
+                Count.Valid++;
+            }
+            Count.Total++;
+        });
+
+    if (Count.Invalid > 0)
+    {
+        return PutResult::Invalid;
+    }
+
+    // Persist the record object itself.
+    ZenCacheValue CacheValue;
+    CacheValue.Value = IoBuffer(Record.GetSize());
+    Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize()));
+    CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+    m_CacheStore.Put(Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue);
+
+    // Persist the inline attachments.
+    for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+    {
+        CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+        CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+        if (InsertResult.New)
+        {
+            Count.New++;
+        }
+        TransferredSize += Chunk.GetCompressedSize();
+    }
+
+    // Fixed log tag typo ("PUTCACEHRECORD").
+    ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {}, attachments '{}/{}/{}' (new/valid/total) in {}",
+              Request.Namespace,
+              Request.Key.Bucket,
+              Request.Key.Hash,
+              NiceBytes(TransferredSize),
+              Count.New,
+              Count.Valid,
+              Count.Total,
+              NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+    const bool IsPartialRecord = Count.Valid != Count.Total;
+
+    if (EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord)
+    {
+        m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage,
+                                         .Namespace = Request.Namespace,
+                                         .Key = Request.Key,
+                                         .ValueContentIds = std::move(ValidAttachments)});
+    }
+    return PutResult::Success;
+}
+
+CbPackage
+HttpStructuredCacheService::HandleRpcGetCacheRecords(CbObjectView RpcRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcGetCacheRecords");
+
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+
+ struct ValueRequestData
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ CompressedBuffer Payload;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool ReadFromUpstream = false;
+ };
+ struct RecordRequestData
+ {
+ CacheKeyRequest Upstream;
+ CbObjectView RecordObject;
+ IoBuffer RecordCacheValue;
+ CacheRecordPolicy DownstreamPolicy;
+ std::vector<ValueRequestData> Values;
+ bool Complete = false;
+ const UpstreamEndpointInfo* Source = nullptr;
+ uint64_t ElapsedTimeUs;
+ };
+
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+ std::vector<RecordRequestData> Requests;
+ std::vector<size_t> UpstreamIndexes;
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+
+ auto ParseValues = [](RecordRequestData& Request) {
+ CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView();
+ Request.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ Request.Values.push_back({ValueId, RawHash});
+ Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId);
+ }
+ }
+ };
+
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ Stopwatch Timer;
+ RecordRequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ CacheKey& Key = Request.Upstream.Key;
+ if (!GetRpcRequestCacheKey(KeyObject, Key))
+ {
+ return CbPackage{};
+ }
+
+ Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ const CacheRecordPolicy& Policy = Request.DownstreamPolicy;
+
+ ZenCacheValue CacheValue;
+ bool NeedUpstreamAttachment = false;
+ bool FoundLocalInvalid = false;
+ ZenCacheValue RecordCacheValue;
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) &&
+ m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ Request.RecordCacheValue = std::move(RecordCacheValue.Value);
+ if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject)
+ {
+ FoundLocalInvalid = true;
+ }
+ else
+ {
+ Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData());
+ ParseValues(Request);
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal))
+ {
+ // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we
+ // didn't ask for it and thus the record is complete in its absence.
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ Request.Complete = false;
+ }
+ }
+ else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ if (m_CidStore.ContainsChunk(Value.ContentId))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
+ }
+ else
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId))
+ {
+ ZEN_ASSERT(Chunk.GetSize() > 0);
+ Value.Payload = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
+ }
+ }
+ }
+ }
+ if (!Request.Complete)
+ {
+ bool NeedUpstreamRecord =
+ !Request.RecordObject && !FoundLocalInvalid && EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote);
+ if (NeedUpstreamRecord || NeedUpstreamAttachment)
+ {
+ UpstreamIndexes.push_back(Requests.size() - 1);
+ }
+ }
+ Request.ElapsedTimeUs = Timer.GetElapsedTimeUs();
+ }
+ if (Requests.empty())
+ {
+ return CbPackage{};
+ }
+
+ if (!UpstreamIndexes.empty())
+ {
+ std::vector<CacheKeyRequest*> UpstreamRequests;
+ UpstreamRequests.reserve(UpstreamIndexes.size());
+ for (size_t Index : UpstreamIndexes)
+ {
+ RecordRequestData& Request = Requests[Index];
+ UpstreamRequests.push_back(&Request.Upstream);
+
+ if (Request.Values.size())
+ {
+ // We will be returning the local object and know all the value Ids that exist in it
+ // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have.
+ CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta;
+ CacheRecordPolicyBuilder Builder(UpstreamBasePolicy);
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy);
+ UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None;
+ Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy);
+ }
+ Request.Upstream.Policy = Builder.Build();
+ }
+ else
+ {
+ // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants,
+ // and convert the CacheRecordPolicy to an upstream policy
+ Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream();
+ }
+ }
+
+ const auto OnCacheRecordGetComplete = [this, Namespace, &ParseValues](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
+
+ RecordRequestData& Request =
+ *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream));
+ Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ const CacheKey& Key = Request.Upstream.Key;
+ Stopwatch Timer;
+ auto TimeGuard = MakeGuard([&Timer, &Request]() { Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); });
+ if (!Request.RecordObject)
+ {
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject);
+ Request.RecordObject = ObjectBuffer;
+ if (EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal))
+ {
+ m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}});
+ }
+ ParseValues(Request);
+ Request.Source = Params.Source;
+ }
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (Value.Exists)
+ {
+ continue;
+ }
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Request.Complete = false;
+ continue;
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId))
+ {
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
+ {
+ Request.Source = Params.Source;
+ Value.Exists = true;
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Value.Payload = Compressed;
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}/{}'",
+ Value.ContentId,
+ *Namespace,
+ Key.Bucket,
+ Key.Hash);
+ }
+ }
+ if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Request.Complete = false;
+ }
+ // Request.Complete does not need to be set to false for upstream SkipData attachments.
+ // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment
+ // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of
+ // any missing SkipData attachments.
+ }
+ Request.ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+ };
+
+ m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete));
+ }
+
+ CbPackage ResponsePackage;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginArray("Result"sv);
+ for (RecordRequestData& Request : Requests)
+ {
+ const CacheKey& Key = Request.Upstream.Key;
+ if (Request.Complete ||
+ (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)))
+ {
+ ResponseObject << Request.RecordObject;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload)
+ {
+ ResponsePackage.AddAttachment(CbAttachment(Value.Payload, Value.ContentId));
+ }
+ }
+
+ ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {}{} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(Request.RecordCacheValue.Size()),
+ Request.Complete ? ""sv : " (PARTIAL)"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount += Request.Source ? 1 : 0;
+ }
+ else
+ {
+ ResponseObject.AddNull();
+
+ if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("GETCACHERECORD DISABLEDQUERY - '{}/{}/{}' in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}'{} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ Request.RecordObject ? ""sv : " (PARTIAL)"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+ }
+ ResponseObject.EndArray();
+ ResponsePackage.SetObject(ResponseObject.Save());
+ return ResponsePackage;
+}
+
+// Handles the "PutCacheValues" RPC.
+//
+// Stores (or probes for) standalone compressed-binary cache values. For each
+// request in Params["Requests"]: if the batch carries an attachment for the
+// request's RawHash, the value is stored locally (subject to policy) and the
+// put succeeds; if no attachment is present, the put degrades to an existence
+// probe against the local store. Returns an empty CbPackage on malformed input
+// (bad namespace/key, or an uncompressed attachment); otherwise returns a
+// package whose object holds a "Result" array of per-request booleans.
+CbPackage
+HttpStructuredCacheService::HandleRpcPutCacheValues(const CbPackage& BatchRequest)
+{
+    CbObjectView BatchObject = BatchRequest.GetObject();
+    CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+
+    // Per-batch default policy, overridable per request below.
+    std::string_view PolicyText = Params["DefaultPolicy"].AsString();
+    CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+    std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+    if (!Namespace)
+    {
+        return CbPackage{};
+    }
+    std::vector<bool> Results;
+    for (CbFieldView RequestField : Params["Requests"sv])
+    {
+        Stopwatch Timer;
+
+        CbObjectView RequestObject = RequestField.AsObjectView();
+        CbObjectView KeyView = RequestObject["Key"sv].AsObjectView();
+
+        CacheKey Key;
+        if (!GetRpcRequestCacheKey(KeyView, Key))
+        {
+            // Any malformed key aborts the whole batch.
+            return CbPackage{};
+        }
+
+        PolicyText = RequestObject["Policy"sv].AsString();
+        CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+        IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment();
+        uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64();
+        bool Succeeded = false;
+        uint64_t TransferredSize = 0;
+
+        if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash))
+        {
+            if (Attachment->IsCompressedBinary())
+            {
+                CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+                if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+                {
+                    // TODO: Implement upstream puts of CacheValues with StoreLocal == false.
+                    // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream.
+                    Policy |= CachePolicy::StoreLocal;
+                }
+
+                if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
+                {
+                    IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer();
+                    Value.SetContentType(ZenContentType::kCompressedBinary);
+                    if (RawSize == 0)
+                    {
+                        // Client did not supply the raw size; recover it from the compressed header.
+                        RawSize = Chunk.DecodeRawSize();
+                    }
+                    m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = Value, .RawSize = RawSize, .RawHash = RawHash});
+                    TransferredSize = Chunk.GetCompressedSize();
+                }
+                Succeeded = true;
+            }
+            else
+            {
+                ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}/{}' FAILED, value is not compressed", *Namespace, Key.Bucket, Key.Hash, RawHash);
+                return CbPackage{};
+            }
+        }
+        else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+        {
+            // No payload attached: treat the put as a probe for local existence.
+            ZenCacheValue ExistingValue;
+            if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, ExistingValue) &&
+                IsCompressedBinary(ExistingValue.Value.GetContentType()))
+            {
+                Succeeded = true;
+            }
+        }
+        // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put.
+        // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream.
+
+        if (Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+        {
+            // Propagation to the upstream is asynchronous via the upstream queue.
+            m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCompressedBinary, .Namespace = *Namespace, .Key = Key});
+        }
+        Results.push_back(Succeeded);
+        ZEN_DEBUG("PUTCACHEVALUES - '{}/{}/{}' {}, '{}' in {}",
+                  *Namespace,
+                  Key.Bucket,
+                  Key.Hash,
+                  NiceBytes(TransferredSize),
+                  Succeeded ? "Added"sv : "Invalid",
+                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+    }
+    if (Results.empty())
+    {
+        return CbPackage{};
+    }
+
+    CbObjectWriter ResponseObject;
+    ResponseObject.BeginArray("Result"sv);
+    for (bool Value : Results)
+    {
+        ResponseObject.AddBool(Value);
+    }
+    ResponseObject.EndArray();
+
+    CbPackage RpcResponse;
+    RpcResponse.SetObject(ResponseObject.Save());
+
+    return RpcResponse;
+}
+
+// Handles the "GetCacheValues" RPC.
+//
+// Resolves each requested key against the local cache store first, then issues
+// a single batched upstream query for any misses whose policy permits
+// QueryRemote. The response object holds a "Result" array with one entry per
+// request: for a hit, "RawHash" plus either an attachment with the compressed
+// payload or (for SkipData) "RawSize"; for a miss, an empty object.
+// Returns an empty CbPackage on malformed input or when there are no requests.
+CbPackage
+HttpStructuredCacheService::HandleRpcGetCacheValues(CbObjectView RpcRequest)
+{
+    ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv);
+
+    CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+    std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+    CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+    std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+    if (!Namespace)
+    {
+        return CbPackage{};
+    }
+
+    // Per-request scratch and result state.
+    struct RequestData
+    {
+        CacheKey Key;
+        CachePolicy Policy;
+        IoHash RawHash = IoHash::Zero;  // zero means "not found"
+        uint64_t RawSize = 0;
+        CompressedBuffer Result;        // empty when missing or SkipData
+    };
+    std::vector<RequestData> Requests;
+
+    // Indexes into Requests that still need an upstream lookup.
+    std::vector<size_t> RemoteRequestIndexes;
+
+    for (CbFieldView RequestField : Params["Requests"sv])
+    {
+        Stopwatch Timer;
+
+        RequestData& Request = Requests.emplace_back();
+        CbObjectView RequestObject = RequestField.AsObjectView();
+        CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+        if (!GetRpcRequestCacheKey(KeyObject, Request.Key))
+        {
+            // Any malformed key aborts the whole batch.
+            return CbPackage{};
+        }
+
+        PolicyText = RequestObject["Policy"sv].AsString();
+        Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+
+        CacheKey& Key = Request.Key;
+        CachePolicy Policy = Request.Policy;
+
+        ZenCacheValue CacheValue;
+        if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+        {
+            if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, CacheValue) && IsCompressedBinary(CacheValue.Value.GetContentType()))
+            {
+                Request.RawHash = CacheValue.RawHash;
+                Request.RawSize = CacheValue.RawSize;
+                Request.Result = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value));
+            }
+        }
+        if (Request.Result)
+        {
+            ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
+                      *Namespace,
+                      Key.Bucket,
+                      Key.Hash,
+                      NiceBytes(Request.Result.GetCompressed().GetSize()),
+                      "LOCAL"sv,
+                      NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+            m_CacheStats.HitCount++;
+        }
+        else if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+        {
+            // Defer to a single batched upstream query after the parse loop.
+            RemoteRequestIndexes.push_back(Requests.size() - 1);
+        }
+        else if (!EnumHasAnyFlags(Policy, CachePolicy::Query))
+        {
+            // If they requested no query, do not record this as a miss
+            ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}/{}'", *Namespace, Key.Bucket, Key.Hash);
+        }
+        else
+        {
+            ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}",
+                      *Namespace,
+                      Key.Bucket,
+                      Key.Hash,
+                      "LOCAL"sv,
+                      NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+            m_CacheStats.MissCount++;
+        }
+    }
+
+    if (!RemoteRequestIndexes.empty())
+    {
+        // Note: RequestedRecordsData is reserved up-front so that pointers taken
+        // into it below remain stable while it is populated.
+        std::vector<CacheValueRequest> RequestedRecordsData;
+        std::vector<CacheValueRequest*> CacheValueRequests;
+        RequestedRecordsData.reserve(RemoteRequestIndexes.size());
+        CacheValueRequests.reserve(RemoteRequestIndexes.size());
+        for (size_t Index : RemoteRequestIndexes)
+        {
+            RequestData& Request = Requests[Index];
+            RequestedRecordsData.push_back({.Key = {Request.Key.Bucket, Request.Key.Hash}, .Policy = ConvertToUpstream(Request.Policy)});
+            CacheValueRequests.push_back(&RequestedRecordsData.back());
+        }
+        Stopwatch Timer;
+        // NOTE(review): the completion lambda captures stack-local vectors and the
+        // timer by reference, so GetCacheValues is assumed to invoke it before
+        // returning — confirm the upstream API completes synchronously.
+        m_UpstreamCache.GetCacheValues(
+            *Namespace,
+            CacheValueRequests,
+            [this, Namespace, &RequestedRecordsData, &Requests, &RemoteRequestIndexes, &Timer](CacheValueGetCompleteParams&& Params) {
+                CacheValueRequest& ChunkRequest = Params.Request;
+                if (Params.RawHash != IoHash::Zero)
+                {
+                    // Map the completed upstream request back to its originating
+                    // RequestData via its offset in RequestedRecordsData.
+                    size_t RequestOffset = std::distance(RequestedRecordsData.data(), &ChunkRequest);
+                    size_t RequestIndex = RemoteRequestIndexes[RequestOffset];
+                    RequestData& Request = Requests[RequestIndex];
+                    Request.RawHash = Params.RawHash;
+                    Request.RawSize = Params.RawSize;
+                    const bool HasData = IsCompressedBinary(Params.Value.GetContentType());
+                    const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData);
+                    const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal);
+                    // A SkipData request is a hit on existence alone; otherwise we need payload.
+                    const bool IsHit = SkipData || HasData;
+                    if (IsHit)
+                    {
+                        if (HasData && !SkipData)
+                        {
+                            Request.Result = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value));
+                        }
+
+                        if (HasData && StoreData)
+                        {
+                            // Populate the local store so subsequent gets hit locally.
+                            m_CacheStore.Put(*Namespace,
+                                             Request.Key.Bucket,
+                                             Request.Key.Hash,
+                                             ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash});
+                        }
+
+                        ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
+                                  *Namespace,
+                                  ChunkRequest.Key.Bucket,
+                                  ChunkRequest.Key.Hash,
+                                  NiceBytes(Request.Result.GetCompressed().GetSize()),
+                                  Params.Source ? Params.Source->Url : "UPSTREAM",
+                                  NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+                        m_CacheStats.HitCount++;
+                        m_CacheStats.UpstreamHitCount++;
+                        return;
+                    }
+                }
+                ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}",
+                          *Namespace,
+                          ChunkRequest.Key.Bucket,
+                          ChunkRequest.Key.Hash,
+                          Params.Source ? Params.Source->Url : "UPSTREAM",
+                          NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+                m_CacheStats.MissCount++;
+            });
+    }
+
+    if (Requests.empty())
+    {
+        return CbPackage{};
+    }
+
+    CbPackage RpcResponse;
+    CbObjectWriter ResponseObject;
+    ResponseObject.BeginArray("Result"sv);
+    for (const RequestData& Request : Requests)
+    {
+        ResponseObject.BeginObject();
+        {
+            const CompressedBuffer& Result = Request.Result;
+            if (Result)
+            {
+                // Payload hit: hash always, payload as attachment unless SkipData.
+                ResponseObject.AddHash("RawHash"sv, Request.RawHash);
+                if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData))
+                {
+                    RpcResponse.AddAttachment(CbAttachment(Result, Request.RawHash));
+                }
+                else
+                {
+                    ResponseObject.AddInteger("RawSize"sv, Request.RawSize);
+                }
+            }
+            else if (Request.RawHash != IoHash::Zero)
+            {
+                // Existence-only hit (e.g. SkipData fulfilled from the upstream).
+                ResponseObject.AddHash("RawHash"sv, Request.RawHash);
+                ResponseObject.AddInteger("RawSize"sv, Request.RawSize);
+            }
+        }
+        ResponseObject.EndObject();
+    }
+    ResponseObject.EndArray();
+
+    RpcResponse.SetObject(ResponseObject.Save());
+    return RpcResponse;
+}
+
+// Scratch types shared by the GetCacheChunks helper methods below.
+namespace cache::detail {
+
+    // One value entry parsed from a cache record's "Values" array.
+    struct RecordValue
+    {
+        Oid ValueId;       // value identifier within the record ("Id" field)
+        IoHash ContentId;  // hash of the value's content ("RawHash" field)
+        uint64_t RawSize;  // uncompressed size ("RawSize" field)
+    };
+    // Scratch state for one unique record key while servicing GetCacheChunks.
+    struct RecordBody
+    {
+        IoBuffer CacheValue;                          // serialized record (kCbObject) once loaded
+        std::vector<RecordValue> Values;              // value table, parsed lazily (see ValuesRead)
+        const UpstreamEndpointInfo* Source = nullptr; // set when the record came from an upstream
+        CachePolicy DownstreamPolicy;                 // union of the policies of the requests against this record
+        bool Exists = false;                          // record was found locally or upstream
+        bool HasRequest = false;                      // at least one chunk request references this record
+        bool ValuesRead = false;                      // Values has been parsed out of CacheValue
+    };
+    // Intermediate and result state for one chunk request.
+    struct ChunkRequest
+    {
+        CacheChunkRequest* Key = nullptr;             // identifying key data (shared with the upstream request)
+        RecordBody* Record = nullptr;                 // owning record, when addressing a record value
+        CompressedBuffer Value;                       // payload to return (empty when SkipData or missing)
+        const UpstreamEndpointInfo* Source = nullptr; // set when fulfilled by an upstream
+        uint64_t RawSize = 0;                         // uncompressed size, valid when RawSizeKnown
+        uint64_t RequestedSize = 0;                   // requested byte count (UINT64_MAX = whole value)
+        uint64_t RequestedOffset = 0;                 // requested byte offset into the raw value
+        CachePolicy DownstreamPolicy;                 // policy requested by the downstream client
+        bool Exists = false;                          // the chunk was found (locally or upstream)
+        bool RawSizeKnown = false;                    // RawSize holds a trustworthy value
+        bool IsRecordRequest = false;                 // true when the request carries a ValueId
+        uint64_t ElapsedTimeUs = 0;                   // accumulated servicing time, for logging
+    };
+
+} // namespace cache::detail
+
+// Handles the "GetCacheChunks" RPC.
+//
+// Returns whole or partial value payloads addressed either directly by a value
+// key, or indirectly by (record key, ValueId). The work is staged through the
+// helper methods below: parse, local record resolution, local value
+// resolution, a batched upstream fetch for any remaining misses, and finally
+// response serialization. Returns an empty CbPackage on parse failure.
+CbPackage
+HttpStructuredCacheService::HandleRpcGetCacheChunks(CbObjectView RpcRequest)
+{
+    using namespace cache::detail;
+
+    std::string Namespace;
+    std::vector<CacheKeyRequest> RecordKeys;        // Data about a Record necessary to identify it to the upstream
+    std::vector<RecordBody> Records;                // Scratch-space data about a Record when fulfilling RecordRequests
+    std::vector<CacheChunkRequest> RequestKeys;     // Data about a ChunkRequest necessary to identify it to the upstream
+    std::vector<ChunkRequest> Requests;             // Intermediate and result data about a ChunkRequest
+    std::vector<ChunkRequest*> RecordRequests;      // The ChunkRequests that are requesting a subvalue from a Record Key
+    std::vector<ChunkRequest*> ValueRequests;       // The ChunkRequests that are requesting a Value Key
+    std::vector<CacheChunkRequest*> UpstreamChunks; // ChunkRequests that we need to send to the upstream
+
+    // Parse requests from the CompactBinary body of the RpcRequest and divide it into RecordRequests and ValueRequests
+    if (!ParseGetCacheChunksRequest(Namespace, RecordKeys, Records, RequestKeys, Requests, RecordRequests, ValueRequests, RpcRequest))
+    {
+        return CbPackage{};
+    }
+
+    // For each Record request, load the Record if necessary to find the Chunk's ContentId, load its Payloads if we
+    // have it locally, and otherwise append a request for the payload to UpstreamChunks
+    GetLocalCacheRecords(Namespace, RecordKeys, Records, RecordRequests, UpstreamChunks);
+
+    // For each Value request, load the Value if we have it locally and otherwise append a request for the payload to UpstreamChunks
+    GetLocalCacheValues(Namespace, ValueRequests, UpstreamChunks);
+
+    // Call GetCacheChunks on the upstream for any payloads we do not have locally
+    GetUpstreamCacheChunks(Namespace, UpstreamChunks, RequestKeys, Requests);
+
+    // Send the payload and descriptive data about each chunk to the client
+    return WriteGetCacheChunksResponse(Namespace, Requests);
+}
+
+// Parses a "GetCacheChunks" RPC body into the scratch vectors used by the
+// later phases, partitioning requests into RecordRequests (ValueId present)
+// and ValueRequests (direct value key).
+//
+// Requests addressing the same record key must arrive sorted and adjacent so
+// that consecutive record requests share one RecordKeys/Records entry.
+// Returns false (aborting the RPC) on an invalid namespace, an invalid key,
+// unsorted record keys, or an empty request list.
+bool
+HttpStructuredCacheService::ParseGetCacheChunksRequest(std::string& Namespace,
+                                                       std::vector<CacheKeyRequest>& RecordKeys,
+                                                       std::vector<cache::detail::RecordBody>& Records,
+                                                       std::vector<CacheChunkRequest>& RequestKeys,
+                                                       std::vector<cache::detail::ChunkRequest>& Requests,
+                                                       std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+                                                       std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+                                                       CbObjectView RpcRequest)
+{
+    using namespace cache::detail;
+
+    ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv);
+
+    CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+    std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+    CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default;
+
+    std::optional<std::string> NamespaceText = GetRpcRequestNamespace(Params);
+    if (!NamespaceText)
+    {
+        ZEN_WARN("GetCacheChunks: Invalid namespace in ChunkRequest.");
+        return false;
+    }
+    Namespace = *NamespaceText;
+
+    CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView();
+    size_t NumRequests = static_cast<size_t>(ChunkRequestsArray.Num());
+
+    // Note that these reservations allow us to take pointers to the elements while populating them. If the reservation is removed,
+    // we will need to change the pointers to indexes to handle reallocations.
+    RecordKeys.reserve(NumRequests);
+    Records.reserve(NumRequests);
+    RequestKeys.reserve(NumRequests);
+    Requests.reserve(NumRequests);
+    RecordRequests.reserve(NumRequests);
+    ValueRequests.reserve(NumRequests);
+
+    // Tracks the most recent record entry so adjacent requests for the same
+    // record key are coalesced onto it.
+    CacheKeyRequest* PreviousRecordKey = nullptr;
+    RecordBody* PreviousRecord = nullptr;
+
+    for (CbFieldView RequestView : ChunkRequestsArray)
+    {
+        CbObjectView RequestObject = RequestView.AsObjectView();
+        CacheChunkRequest& RequestKey = RequestKeys.emplace_back();
+        ChunkRequest& Request = Requests.emplace_back();
+        CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+        Request.Key = &RequestKey;
+        if (!GetRpcRequestCacheKey(KeyObject, Request.Key->Key))
+        {
+            ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest.");
+            return false;
+        }
+
+        RequestKey.ChunkId = RequestObject["ChunkId"sv].AsHash();
+        RequestKey.ValueId = RequestObject["ValueId"sv].AsObjectId();
+        RequestKey.RawOffset = RequestObject["RawOffset"sv].AsUInt64();
+        // Missing RawSize means "to the end of the value".
+        RequestKey.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX);
+        Request.RequestedSize = RequestKey.RawSize;
+        Request.RequestedOffset = RequestKey.RawOffset;
+        std::string_view PolicyText = RequestObject["Policy"sv].AsString();
+        Request.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+        // A ValueId marks the request as addressing a value inside a record.
+        Request.IsRecordRequest = (bool)RequestKey.ValueId;
+
+        if (!Request.IsRecordRequest)
+        {
+            ValueRequests.push_back(&Request);
+        }
+        else
+        {
+            RecordRequests.push_back(&Request);
+            CacheKeyRequest* RecordKey = nullptr;
+            RecordBody* Record = nullptr;
+
+            if (!PreviousRecordKey || PreviousRecordKey->Key < RequestKey.Key)
+            {
+                // First request for a new (strictly greater) record key.
+                RecordKey = &RecordKeys.emplace_back();
+                PreviousRecordKey = RecordKey;
+                Record = &Records.emplace_back();
+                PreviousRecord = Record;
+                RecordKey->Key = RequestKey.Key;
+            }
+            else if (RequestKey.Key == PreviousRecordKey->Key)
+            {
+                // Another request for the record we just created; share it.
+                RecordKey = PreviousRecordKey;
+                Record = PreviousRecord;
+            }
+            else
+            {
+                ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.",
+                         RequestKey.Key.Bucket,
+                         RequestKey.Key.Hash,
+                         PreviousRecordKey->Key.Bucket,
+                         PreviousRecordKey->Key.Hash);
+                return false;
+            }
+            Request.Record = Record;
+            if (RequestKey.ChunkId == RequestKey.ChunkId.Zero)
+            {
+                // No ChunkId: the record itself must be fetched to resolve the
+                // ValueId, so fold this request's policy into the record's.
+                Record->DownstreamPolicy =
+                    Record->HasRequest ? Union(Record->DownstreamPolicy, Request.DownstreamPolicy) : Request.DownstreamPolicy;
+                Record->HasRequest = true;
+            }
+        }
+    }
+    if (Requests.empty())
+    {
+        return false;
+    }
+    return true;
+}
+
+// Resolves record-addressed chunk requests.
+//
+// Phase 1: load each referenced record from the local store, or fetch it from
+// the upstream (metadata only — SkipData|SkipMeta) when permitted. Phase 2:
+// for each record request, resolve its ValueId to a ChunkId via the record's
+// value table, then try to satisfy the payload from the local CID store;
+// anything still unresolved that allows QueryRemote is appended to
+// OutUpstreamChunks for the batched upstream fetch.
+void
+HttpStructuredCacheService::GetLocalCacheRecords(std::string_view Namespace,
+                                                 std::vector<CacheKeyRequest>& RecordKeys,
+                                                 std::vector<cache::detail::RecordBody>& Records,
+                                                 std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+                                                 std::vector<CacheChunkRequest*>& OutUpstreamChunks)
+{
+    using namespace cache::detail;
+
+    std::vector<CacheKeyRequest*> UpstreamRecordRequests;
+    for (size_t RecordIndex = 0; RecordIndex < Records.size(); ++RecordIndex)
+    {
+        Stopwatch Timer;
+        CacheKeyRequest& RecordKey = RecordKeys[RecordIndex];
+        RecordBody& Record = Records[RecordIndex];
+        if (Record.HasRequest)
+        {
+            // We only need the record's value table here, never its payloads.
+            Record.DownstreamPolicy |= CachePolicy::SkipData | CachePolicy::SkipMeta;
+
+            if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal))
+            {
+                ZenCacheValue CacheValue;
+                if (m_CacheStore.Get(Namespace, RecordKey.Key.Bucket, RecordKey.Key.Hash, CacheValue))
+                {
+                    Record.Exists = true;
+                    Record.CacheValue = std::move(CacheValue.Value);
+                }
+            }
+            if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryRemote))
+            {
+                RecordKey.Policy = CacheRecordPolicy(ConvertToUpstream(Record.DownstreamPolicy));
+                UpstreamRecordRequests.push_back(&RecordKey);
+            }
+            // NOTE(review): RecordRequests is one entry per chunk request while this
+            // loop is per unique record, so the elapsed time is attributed to the
+            // request at the record's index — confirm this attribution is intended.
+            RecordRequests[RecordIndex]->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+        }
+    }
+
+    if (!UpstreamRecordRequests.empty())
+    {
+        const auto OnCacheRecordGetComplete =
+            [this, Namespace, &RecordKeys, &Records, &RecordRequests](CacheRecordGetCompleteParams&& Params) {
+                if (!Params.Record)
+                {
+                    return;
+                }
+                // Map the completed request back to its record via its offset in RecordKeys.
+                CacheKeyRequest& RecordKey = Params.Request;
+                size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey);
+                RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+                RecordBody& Record = Records[RecordIndex];
+
+                const CacheKey& Key = RecordKey.Key;
+                Record.Exists = true;
+                // Clone so the record's lifetime is independent of the upstream response.
+                CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+                Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+                Record.CacheValue.SetContentType(ZenContentType::kCbObject);
+                Record.Source = Params.Source;
+
+                if (EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal))
+                {
+                    m_CacheStore.Put(Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue});
+                }
+            };
+        m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete));
+    }
+
+    std::vector<CacheChunkRequest*> UpstreamPayloadRequests;
+    for (ChunkRequest* Request : RecordRequests)
+    {
+        Stopwatch Timer;
+        if (Request->Key->ChunkId == IoHash::Zero)
+        {
+            // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId)
+            // is missing, parse the cache record and try to find the raw hash from the ValueId.
+            RecordBody& Record = *Request->Record;
+            if (!Record.ValuesRead)
+            {
+                // Parse the record's value table once and cache it on the RecordBody.
+                Record.ValuesRead = true;
+                if (Record.CacheValue && Record.CacheValue.GetContentType() == ZenContentType::kCbObject)
+                {
+                    CbObjectView RecordObject = CbObjectView(Record.CacheValue.GetData());
+                    CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView();
+                    Record.Values.reserve(ValuesArray.Num());
+                    for (CbFieldView ValueField : ValuesArray)
+                    {
+                        CbObjectView ValueObject = ValueField.AsObjectView();
+                        Oid ValueId = ValueObject["Id"sv].AsObjectId();
+                        CbFieldView RawHashField = ValueObject["RawHash"sv];
+                        IoHash RawHash = RawHashField.AsBinaryAttachment();
+                        if (ValueId && !RawHashField.HasError())
+                        {
+                            Record.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()});
+                        }
+                    }
+                }
+            }
+
+            // Linear scan for the request's ValueId in the record's value table.
+            for (const RecordValue& Value : Record.Values)
+            {
+                if (Value.ValueId == Request->Key->ValueId)
+                {
+                    Request->Key->ChunkId = Value.ContentId;
+                    Request->RawSize = Value.RawSize;
+                    Request->RawSizeKnown = true;
+                    break;
+                }
+            }
+        }
+
+        // Now load the ContentId from the local ContentIdStore or from the upstream
+        if (Request->Key->ChunkId != IoHash::Zero)
+        {
+            if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal))
+            {
+                if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData) && Request->RawSizeKnown)
+                {
+                    // Existence check is sufficient; size is already known from the record.
+                    if (m_CidStore.ContainsChunk(Request->Key->ChunkId))
+                    {
+                        Request->Exists = true;
+                    }
+                }
+                else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Request->Key->ChunkId))
+                {
+                    if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData))
+                    {
+                        Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(Payload));
+                        if (Request->Value)
+                        {
+                            Request->Exists = true;
+                            Request->RawSizeKnown = false;
+                        }
+                    }
+                    else
+                    {
+                        // SkipData with unknown size: read the size from the compressed header.
+                        IoHash _;
+                        if (CompressedBuffer::ValidateCompressedHeader(Payload, _, Request->RawSize))
+                        {
+                            Request->Exists = true;
+                            Request->RawSizeKnown = true;
+                        }
+                    }
+                }
+            }
+            if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote))
+            {
+                Request->Key->Policy = ConvertToUpstream(Request->DownstreamPolicy);
+                OutUpstreamChunks.push_back(Request->Key);
+            }
+        }
+        Request->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+    }
+}
+
+// Resolves value-addressed chunk requests against the local cache store.
+//
+// A local hit fills in the chunk's hash, raw size, and (unless SkipData) its
+// compressed payload. Misses whose policy allows QueryRemote are appended to
+// OutUpstreamChunks for the batched upstream fetch; when StoreLocal is also
+// set, the range request is widened to the whole value so it can be cached.
+void
+HttpStructuredCacheService::GetLocalCacheValues(std::string_view Namespace,
+                                                std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+                                                std::vector<CacheChunkRequest*>& OutUpstreamChunks)
+{
+    using namespace cache::detail;
+
+    for (ChunkRequest* Request : ValueRequests)
+    {
+        Stopwatch Timer;
+        if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal))
+        {
+            ZenCacheValue CacheValue;
+            if (m_CacheStore.Get(Namespace, Request->Key->Key.Bucket, Request->Key->Key.Hash, CacheValue))
+            {
+                // Only compressed-binary values count; other content types are ignored.
+                if (IsCompressedBinary(CacheValue.Value.GetContentType()))
+                {
+                    Request->Key->ChunkId = CacheValue.RawHash;
+                    Request->Exists = true;
+                    Request->RawSize = CacheValue.RawSize;
+                    Request->RawSizeKnown = true;
+                    if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData))
+                    {
+                        Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value));
+                    }
+                }
+            }
+        }
+        if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote))
+        {
+            if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::StoreLocal))
+            {
+                // Convert the Offset,Size request into a request for the entire value; we will need it all to be able to store it locally
+                Request->Key->RawOffset = 0;
+                Request->Key->RawSize = UINT64_MAX;
+            }
+            OutUpstreamChunks.push_back(Request->Key);
+        }
+        Request->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+    }
+}
+
+// Fetches any unresolved chunks from the upstream in a single batched call.
+//
+// The completion callback maps each upstream result back to its ChunkRequest
+// (by offset into RequestKeys), optionally persists the payload locally —
+// record-addressed payloads go to the CID store, value-addressed payloads to
+// the cache store — and fills in the request's result fields. Results with a
+// zero RawHash are upstream misses and are ignored.
+void
+HttpStructuredCacheService::GetUpstreamCacheChunks(std::string_view Namespace,
+                                                   std::vector<CacheChunkRequest*>& UpstreamChunks,
+                                                   std::vector<CacheChunkRequest>& RequestKeys,
+                                                   std::vector<cache::detail::ChunkRequest>& Requests)
+{
+    using namespace cache::detail;
+
+    if (!UpstreamChunks.empty())
+    {
+        // NOTE(review): the callback captures RequestKeys/Requests by reference, so
+        // GetCacheChunks is assumed to complete synchronously before returning.
+        const auto OnCacheChunksGetComplete = [this, Namespace, &RequestKeys, &Requests](CacheChunkGetCompleteParams&& Params) {
+            if (Params.RawHash == Params.RawHash.Zero)
+            {
+                // Upstream miss; leave the request unfulfilled.
+                return;
+            }
+
+            CacheChunkRequest& Key = Params.Request;
+            size_t RequestIndex = std::distance(RequestKeys.data(), &Key);
+            ChunkRequest& Request = Requests[RequestIndex];
+            Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+            // Touch the payload only when we need it for the response or for local storage.
+            if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) ||
+                !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+            {
+                CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value));
+                if (!Compressed)
+                {
+                    return;
+                }
+
+                if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal))
+                {
+                    if (Request.IsRecordRequest)
+                    {
+                        // Record values are content-addressed; store by CID.
+                        m_CidStore.AddChunk(Params.Value, Params.RawHash);
+                    }
+                    else
+                    {
+                        // Standalone values are keyed; store in the cache store.
+                        m_CacheStore.Put(Namespace,
+                                         Key.Key.Bucket,
+                                         Key.Key.Hash,
+                                         {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash});
+                    }
+                }
+                if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+                {
+                    Request.Value = std::move(Compressed);
+                }
+            }
+            Key.ChunkId = Params.RawHash;
+            Request.Exists = true;
+            Request.RawSize = Params.RawSize;
+            Request.RawSizeKnown = true;
+            Request.Source = Params.Source;
+
+            m_CacheStats.UpstreamHitCount++;
+        };
+
+        m_UpstreamCache.GetCacheChunks(Namespace, UpstreamChunks, std::move(OnCacheChunksGetComplete));
+    }
+}
+
+// HandleRpcGetCacheChunks helper: build the RPC response package. Each request
+// becomes one object in the "Result" array; payloads (when requested) travel as
+// package attachments keyed by chunk id, otherwise only RawSize is reported.
+CbPackage
+HttpStructuredCacheService::WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests)
+{
+    using namespace cache::detail;
+
+    CbPackage RpcResponse;
+    CbObjectWriter Writer;
+
+    Writer.BeginArray("Result"sv);
+    for (ChunkRequest& Request : Requests)
+    {
+        Writer.BeginObject();
+        {
+            if (Request.Exists)
+            {
+                Writer.AddHash("RawHash"sv, Request.Key->ChunkId);
+                if (Request.Value && !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+                {
+                    // Data requested: ship the payload as an attachment; the
+                    // attachment itself carries the size.
+                    RpcResponse.AddAttachment(CbAttachment(Request.Value, Request.Key->ChunkId));
+                }
+                else
+                {
+                    // Existence-only result: report the raw size inline.
+                    Writer.AddInteger("RawSize"sv, Request.RawSize);
+                }
+
+                ZEN_DEBUG("GETCACHECHUNKS HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+                          Namespace,
+                          Request.Key->Key.Bucket,
+                          Request.Key->Key.Hash,
+                          Request.Key->ValueId,
+                          NiceBytes(Request.RawSize),
+                          Request.IsRecordRequest ? "Record"sv : "Value"sv,
+                          Request.Source ? Request.Source->Url : "LOCAL"sv,
+                          NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+                m_CacheStats.HitCount++;
+            }
+            else if (!EnumHasAnyFlags(Request.DownstreamPolicy, CachePolicy::Query))
+            {
+                // Querying was disabled by policy — not counted as a miss.
+                ZEN_DEBUG("GETCACHECHUNKS DISABLEDQUERY - '{}/{}/{}/{}' in {}",
+                          Namespace,
+                          Request.Key->Key.Bucket,
+                          Request.Key->Key.Hash,
+                          Request.Key->ValueId,
+                          NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+            }
+            else
+            {
+                ZEN_DEBUG("GETCACHECHUNKS MISS - '{}/{}/{}/{}' in {}",
+                          Namespace,
+                          Request.Key->Key.Bucket,
+                          Request.Key->Key.Hash,
+                          Request.Key->ValueId,
+                          NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+                m_CacheStats.MissCount++;
+            }
+        }
+        Writer.EndObject();
+    }
+    Writer.EndArray();
+
+    RpcResponse.SetObject(Writer.Save());
+    return RpcResponse;
+}
+
+// Emit a compact-binary stats snapshot: request timings, cache hit/miss
+// counters and ratios, upstream status, and CID store sizes.
+// Fix: the original wrote the "upstream_ratio" field twice into the "cache"
+// object (duplicate field), and "upstream_hits" re-read the atomic counter
+// instead of using the sampled local, which could disagree with the ratio.
+void
+HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
+{
+    CbObjectWriter Cbo;
+
+    EmitSnapshot("requests", m_HttpRequests, Cbo);
+    EmitSnapshot("upstream_gets", m_UpstreamGetRequestTiming, Cbo);
+
+    // Sample the counters once so the emitted values and ratios are mutually
+    // consistent even while the counters keep advancing.
+    const uint64_t HitCount = m_CacheStats.HitCount;
+    const uint64_t UpstreamHitCount = m_CacheStats.UpstreamHitCount;
+    const uint64_t MissCount = m_CacheStats.MissCount;
+    const uint64_t TotalCount = HitCount + MissCount;
+
+    const CidStoreSize CidSize = m_CidStore.TotalSize();
+    const GcStorageSize CacheSize = m_CacheStore.StorageSize();
+
+    Cbo.BeginObject("cache");
+    {
+        Cbo.BeginObject("size");
+        {
+            Cbo << "disk" << CacheSize.DiskSize;
+            Cbo << "memory" << CacheSize.MemorySize;
+        }
+        Cbo.EndObject();
+
+        Cbo << "hits" << HitCount << "misses" << MissCount;
+        Cbo << "hit_ratio" << (TotalCount > 0 ? (double(HitCount) / double(TotalCount)) : 0.0);
+        Cbo << "upstream_hits" << UpstreamHitCount;
+        // Fraction of hits that were satisfied via an upstream server.
+        Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
+    }
+    Cbo.EndObject();
+
+    Cbo.BeginObject("upstream");
+    {
+        m_UpstreamCache.GetStatus(Cbo);
+    }
+    Cbo.EndObject();
+
+    Cbo.BeginObject("cid");
+    {
+        Cbo.BeginObject("size");
+        {
+            Cbo << "tiny" << CidSize.TinySize;
+            Cbo << "small" << CidSize.SmallSize;
+            Cbo << "large" << CidSize.LargeSize;
+            Cbo << "total" << CidSize.TotalSize;
+        }
+        Cbo.EndObject();
+    }
+    Cbo.EndObject();
+
+    Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+// Minimal liveness probe: always answers 200 OK with { ok: true }.
+void
+HttpStructuredCacheService::HandleStatusRequest(HttpServerRequest& Request)
+{
+    CbObjectWriter StatusWriter;
+    StatusWriter << "ok" << true;
+    Request.WriteResponse(HttpResponseCode::OK, StatusWriter.Save());
+}
+
+#if ZEN_WITH_TESTS
+
+// Exercises HttpRequestParseRelativeUri across empty, legacy (bucket-first),
+// and v2 (namespace-first) URI shapes, then a set of malformed inputs.
+TEST_CASE("z$service.parse.relative.Uri")
+{
+    // Empty / root URIs parse successfully but yield no components.
+    HttpRequestData RootRequest;
+    CHECK(HttpRequestParseRelativeUri("", RootRequest));
+    CHECK(!RootRequest.Namespace.has_value());
+    CHECK(!RootRequest.Bucket.has_value());
+    CHECK(!RootRequest.HashKey.has_value());
+    CHECK(!RootRequest.ValueContentId.has_value());
+
+    RootRequest = {};
+    CHECK(HttpRequestParseRelativeUri("/", RootRequest));
+    CHECK(!RootRequest.Namespace.has_value());
+    CHECK(!RootRequest.Bucket.has_value());
+    CHECK(!RootRequest.HashKey.has_value());
+    CHECK(!RootRequest.ValueContentId.has_value());
+
+    // A single segment is interpreted as a namespace, not a legacy bucket.
+    HttpRequestData LegacyBucketRequestBecomesNamespaceRequest;
+    CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest));
+    CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test"sv);
+    CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value());
+    CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value());
+    CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value());
+
+    // Legacy two/three-segment form: bucket/hash[/value-id], with the
+    // namespace defaulting to ZenCacheStore::DefaultNamespace.
+    HttpRequestData LegacyHashKeyRequest;
+    CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest));
+    CHECK(LegacyHashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace);
+    CHECK(LegacyHashKeyRequest.Bucket == "test"sv);
+    CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+    CHECK(!LegacyHashKeyRequest.ValueContentId.has_value());
+
+    HttpRequestData LegacyValueContentIdRequest;
+    CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+                                      LegacyValueContentIdRequest));
+    CHECK(LegacyValueContentIdRequest.Namespace == ZenCacheStore::DefaultNamespace);
+    CHECK(LegacyValueContentIdRequest.Bucket == "test"sv);
+    CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+    CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv));
+
+    // V2 forms: an explicit namespace may precede the bucket.
+    HttpRequestData V2DefaultNamespaceRequest;
+    CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest));
+    CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc");
+    CHECK(!V2DefaultNamespaceRequest.Bucket.has_value());
+    CHECK(!V2DefaultNamespaceRequest.HashKey.has_value());
+    CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value());
+
+    HttpRequestData V2NamespaceRequest;
+    CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest));
+    CHECK(V2NamespaceRequest.Namespace == "nicenamespace"sv);
+    CHECK(!V2NamespaceRequest.Bucket.has_value());
+    CHECK(!V2NamespaceRequest.HashKey.has_value());
+    CHECK(!V2NamespaceRequest.ValueContentId.has_value());
+
+    HttpRequestData V2BucketRequestWithDefaultNamespace;
+    CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace));
+    CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc");
+    CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test"sv);
+    CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value());
+    CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value());
+
+    HttpRequestData V2BucketRequestWithNamespace;
+    CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace));
+    CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace"sv);
+    CHECK(V2BucketRequestWithNamespace.Bucket == "test"sv);
+    CHECK(!V2BucketRequestWithNamespace.HashKey.has_value());
+    CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value());
+
+    HttpRequestData V2HashKeyRequest;
+    CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest));
+    CHECK(V2HashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace);
+    CHECK(V2HashKeyRequest.Bucket == "test");
+    CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+    CHECK(!V2HashKeyRequest.ValueContentId.has_value());
+
+    HttpRequestData V2ValueContentIdRequest;
+    CHECK(
+        HttpRequestParseRelativeUri("nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+                                    V2ValueContentIdRequest));
+    CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace"sv);
+    CHECK(V2ValueContentIdRequest.Bucket == "test"sv);
+    CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+    CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv));
+
+    // Malformed inputs: control characters in segments, truncated or
+    // non-hexadecimal hash strings must all be rejected.
+    HttpRequestData Invalid;
+    CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid));
+    CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid));
+    CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid));
+    CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid));
+    CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid));
+    CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid));
+    CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789",
+                                       Invalid));
+}
+
+#endif
+
+// Empty anchor symbol; presumably referenced from elsewhere so the linker
+// retains this translation unit (and its self-registering statics, e.g. the
+// TEST_CASE above) — TODO confirm against the callers of z$service_forcelink().
+void
+z$service_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcache.h b/src/zenserver/cache/structuredcache.h
new file mode 100644
index 000000000..4e7b98ac9
--- /dev/null
+++ b/src/zenserver/cache/structuredcache.h
@@ -0,0 +1,187 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/stats.h>
+#include <zenhttp/httpserver.h>
+
+#include "monitoring/httpstats.h"
+#include "monitoring/httpstatus.h"
+
+#include <memory>
+#include <vector>
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+struct CacheChunkRequest;
+struct CacheKeyRequest;
+class CidStore;
+class CbObjectView;
+struct PutRequestData;
+class ScrubContext;
+class UpstreamCache;
+class ZenCacheStore;
+enum class CachePolicy : uint32_t;
+enum class RpcAcceptOptions : uint16_t;
+
+namespace cache {
+ class IRpcRequestReplayer;
+ class IRpcRequestRecorder;
+ namespace detail {
+ struct RecordBody;
+ struct ChunkRequest;
+ } // namespace detail
+} // namespace cache
+
+/**
+ * Structured cache service. Imposes constraints on keys, supports blobs and
+ * structured values
+ *
+ * Keys are structured as:
+ *
+ * {BucketId}/{KeyHash}
+ *
+ * Where BucketId is a lower-case alphanumeric string, and KeyHash is a 40-character
+ * hexadecimal sequence. The hash value may be derived in any number of ways, it's
+ * up to the application to pick an approach.
+ *
+ * Values may be structured or unstructured. Structured values are encoded using Unreal
+ * Engine's compact binary encoding (see CbObject)
+ *
+ * Additionally, attachments may be addressed as:
+ *
+ * {BucketId}/{KeyHash}/{ValueHash}
+ *
+ * Where the two initial components are the same as for the main endpoint
+ *
+ * The storage strategy is as follows:
+ *
+ * - Structured values are stored in a dedicated backing store per bucket
+ * - Unstructured values and attachments are stored in the CAS pool
+ *
+ */
+
+class HttpStructuredCacheService : public HttpService, public IHttpStatsProvider, public IHttpStatusProvider
+{
+public:
+    HttpStructuredCacheService(ZenCacheStore& InCacheStore,
+                               CidStore& InCidStore,
+                               HttpStatsService& StatsService,
+                               HttpStatusService& StatusService,
+                               UpstreamCache& UpstreamCache);
+    ~HttpStructuredCacheService();
+
+    virtual const char* BaseUri() const override;
+    virtual void HandleRequest(HttpServerRequest& Request) override;
+
+    void Flush();
+    void Scrub(ScrubContext& Ctx);
+
+private:
+    // Parsed URI components addressing a single cache entry:
+    // {Namespace}/{Bucket}/{KeyHash}[/{ValueContentId}].
+    struct CacheRef
+    {
+        std::string Namespace;
+        std::string BucketSegment;
+        IoHash HashKey;
+        IoHash ValueContentId;
+    };
+
+    // Hit/miss counters; atomics so request handlers can update concurrently.
+    struct CacheStats
+    {
+        std::atomic_uint64_t HitCount{};
+        std::atomic_uint64_t UpstreamHitCount{};
+        std::atomic_uint64_t MissCount{};
+    };
+    enum class PutResult
+    {
+        Success,
+        Fail,
+        Invalid,
+    };
+
+    void HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+    void HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+    void HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+    void HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+    void HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+    void HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+    void HandleRpcRequest(HttpServerRequest& Request);
+    void HandleDetailsRequest(HttpServerRequest& Request);
+
+    CbPackage HandleRpcPutCacheRecords(const CbPackage& BatchRequest);
+    CbPackage HandleRpcGetCacheRecords(CbObjectView BatchRequest);
+    CbPackage HandleRpcPutCacheValues(const CbPackage& BatchRequest);
+    CbPackage HandleRpcGetCacheValues(CbObjectView BatchRequest);
+    CbPackage HandleRpcGetCacheChunks(CbObjectView BatchRequest);
+    CbPackage HandleRpcRequest(const ZenContentType ContentType,
+                               IoBuffer&& Body,
+                               uint32_t& OutAcceptMagic,
+                               RpcAcceptOptions& OutAcceptFlags,
+                               int& OutTargetProcessId);
+
+    void HandleCacheRequest(HttpServerRequest& Request);
+    void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace);
+    void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket);
+    virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+    virtual void HandleStatusRequest(HttpServerRequest& Request) override;
+    PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package);
+
+    /** HandleRpcGetCacheChunks Helper: Parse the Body object into RecordValue Requests and Value Requests. */
+    bool ParseGetCacheChunksRequest(std::string& Namespace,
+                                    std::vector<CacheKeyRequest>& RecordKeys,
+                                    std::vector<cache::detail::RecordBody>& Records,
+                                    std::vector<CacheChunkRequest>& RequestKeys,
+                                    std::vector<cache::detail::ChunkRequest>& Requests,
+                                    std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+                                    std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+                                    CbObjectView RpcRequest);
+    /** HandleRpcGetCacheChunks Helper: Load records to get ContentId for RecordRequests, and load their payloads if they exist locally. */
+    void GetLocalCacheRecords(std::string_view Namespace,
+                              std::vector<CacheKeyRequest>& RecordKeys,
+                              std::vector<cache::detail::RecordBody>& Records,
+                              std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+                              std::vector<CacheChunkRequest*>& OutUpstreamChunks);
+    /** HandleRpcGetCacheChunks Helper: For ValueRequests, load their payloads if they exist locally. */
+    void GetLocalCacheValues(std::string_view Namespace,
+                             std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+                             std::vector<CacheChunkRequest*>& OutUpstreamChunks);
+    /** HandleRpcGetCacheChunks Helper: Load payloads from upstream that did not exist locally. */
+    void GetUpstreamCacheChunks(std::string_view Namespace,
+                                std::vector<CacheChunkRequest*>& UpstreamChunks,
+                                std::vector<CacheChunkRequest>& RequestKeys,
+                                std::vector<cache::detail::ChunkRequest>& Requests);
+    /** HandleRpcGetCacheChunks Helper: Send response message containing all chunk results. */
+    CbPackage WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests);
+
+    spdlog::logger& Log() { return m_Log; }
+    spdlog::logger& m_Log;
+    ZenCacheStore& m_CacheStore;
+    HttpStatsService& m_StatsService;
+    HttpStatusService& m_StatusService;
+    CidStore& m_CidStore;
+    UpstreamCache& m_UpstreamCache;
+    // Timestamp of the last completed scrub pass (see Scrub()).
+    uint64_t m_LastScrubTime = 0;
+    metrics::OperationTiming m_HttpRequests;
+    metrics::OperationTiming m_UpstreamGetRequestTiming;
+    CacheStats m_CacheStats;
+
+    void ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount);
+
+    std::unique_ptr<cache::IRpcRequestRecorder> m_RequestRecorder;
+};
+
+/** Recognize both kBinary and kCompressedBinary as kCompressedBinary for structured cache value keys.
+ * We need this until the content type is preserved for kCompressedBinary when passing to and from upstream servers. */
+inline bool
+IsCompressedBinary(ZenContentType Type)
+{
+    switch (Type)
+    {
+        case ZenContentType::kBinary:
+        case ZenContentType::kCompressedBinary:
+            return true;
+        default:
+            return false;
+    }
+}
+
+void z$service_forcelink();
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcachestore.cpp b/src/zenserver/cache/structuredcachestore.cpp
new file mode 100644
index 000000000..26e970073
--- /dev/null
+++ b/src/zenserver/cache/structuredcachestore.cpp
@@ -0,0 +1,3648 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "structuredcachestore.h"
+
+#include <zencore/except.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/scrubcontext.h>
+
+#include <xxhash.h>
+
+#include <limits>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/core.h>
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zencore/workthreadpool.h>
+# include <random>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+namespace {
+
+#pragma pack(push)
+#pragma pack(1)
+
+    // On-disk header of a bucket index file (.uidx); packed to exactly 32 bytes
+    // (enforced by the static_assert below).
+    struct CacheBucketIndexHeader
+    {
+        static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx';
+        static constexpr uint32_t Version2 = 2;
+        static constexpr uint32_t CurrentVersion = Version2;
+
+        uint32_t Magic = ExpectedMagic;
+        uint32_t Version = CurrentVersion;
+        uint64_t EntryCount = 0;
+        uint64_t LogPosition = 0;
+        uint32_t PayloadAlignment = 0;
+        uint32_t Checksum = 0;
+
+        // Checksum covers every header field except the trailing Checksum
+        // field itself (hence sizeof(header) - sizeof(uint32_t)).
+        static uint32_t ComputeChecksum(const CacheBucketIndexHeader& Header)
+        {
+            return XXH32(&Header.Magic, sizeof(CacheBucketIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA);
+        }
+    };
+
+    static_assert(sizeof(CacheBucketIndexHeader) == 32);
+
+#pragma pack(pop)
+
+    const char* IndexExtension = ".uidx";
+    const char* LogExtension = ".slog";
+
+    // Path helpers for the per-bucket index, temp index, and append log files.
+    std::filesystem::path GetIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName)
+    {
+        return BucketDir / (BucketName + IndexExtension);
+    }
+
+    std::filesystem::path GetTempIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName)
+    {
+        return BucketDir / (BucketName + ".tmp");
+    }
+
+    std::filesystem::path GetLogPath(const std::filesystem::path& BucketDir, const std::string& BucketName)
+    {
+        return BucketDir / (BucketName + LogExtension);
+    }
+
+    // Sanity-check a disk index entry (key, flags, reserved bits, size).
+    // Returns false and fills OutReason on the first violation; tombstones are
+    // accepted without further checks.
+    bool ValidateEntry(const DiskIndexEntry& Entry, std::string& OutReason)
+    {
+        if (Entry.Key == IoHash::Zero)
+        {
+            OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString());
+            return false;
+        }
+        if (Entry.Location.GetFlags() &
+            ~(DiskLocation::kStandaloneFile | DiskLocation::kStructured | DiskLocation::kTombStone | DiskLocation::kCompressed))
+        {
+            OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Location.GetFlags(), Entry.Key.ToHexString());
+            return false;
+        }
+        if (Entry.Location.IsFlagSet(DiskLocation::kTombStone))
+        {
+            return true;
+        }
+        if (Entry.Location.Reserved != 0)
+        {
+            OutReason = fmt::format("Invalid reserved field {} for entry {}", Entry.Location.Reserved, Entry.Key.ToHexString());
+            return false;
+        }
+        uint64_t Size = Entry.Location.Size();
+        if (Size == 0)
+        {
+            OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString());
+            return false;
+        }
+        return true;
+    }
+
+    // Atomically retire a directory: rename it to a unique "[dropped]" sibling
+    // then delete that. Retries (with a 100ms sleep) while the rename fails —
+    // NOTE(review): this loop has no bail-out, see the TODO below.
+    bool MoveAndDeleteDirectory(const std::filesystem::path& Dir)
+    {
+        int DropIndex = 0;
+        do
+        {
+            if (!std::filesystem::exists(Dir))
+            {
+                return false;
+            }
+
+            std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex);
+            std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName;
+            if (std::filesystem::exists(DroppedBucketPath))
+            {
+                DropIndex++;
+                continue;
+            }
+
+            std::error_code Ec;
+            std::filesystem::rename(Dir, DroppedBucketPath, Ec);
+            if (!Ec)
+            {
+                DeleteDirectories(DroppedBucketPath);
+                return true;
+            }
+            // TODO: Do we need to bail at some point?
+            zen::Sleep(100);
+        } while (true);
+    }
+
+} // namespace
+
+namespace fs = std::filesystem;
+
+// Read a compact-binary object from disk. Returns an empty CbObject when the
+// file cannot be read or the content fails full compact-binary validation.
+static CbObject
+LoadCompactBinaryObject(const fs::path& Path)
+{
+    FileContents Contents = ReadFile(Path);
+    if (Contents.ErrorCode)
+    {
+        return CbObject();
+    }
+
+    IoBuffer Flattened = Contents.Flatten();
+    if (ValidateCompactBinary(Flattened, CbValidateMode::All) != CbValidateError::None)
+    {
+        return CbObject();
+    }
+
+    return LoadCompactBinaryObject(Flattened);
+}
+
+// Persist a compact-binary object to the given path (overwrites; no
+// validation — the caller is trusted to pass a well-formed object).
+static void
+SaveCompactBinaryObject(const fs::path& Path, const CbObject& Object)
+{
+    WriteFile(Path, Object.GetBuffer().AsIoBuffer());
+}
+
+// Create/open the namespace root directory, discover existing on-disk buckets,
+// and (when compiled in) attach the access tracker.
+ZenCacheNamespace::ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir)
+: GcStorage(Gc)
+, GcContributor(Gc)
+, m_RootDir(RootDir)
+, m_DiskLayer(RootDir)
+{
+    ZEN_INFO("initializing structured cache at '{}'", RootDir);
+    CreateDirectories(RootDir);
+
+    m_DiskLayer.DiscoverBuckets();
+
+#if ZEN_USE_CACHE_TRACKER
+    m_AccessTracker.reset(new ZenCacheTracker(RootDir));
+#endif
+}
+
+// Out-of-line destructor; presumably anchored here so members that are only
+// forward-declared in the header (e.g. the access tracker) can be destroyed —
+// TODO confirm against the header's member declarations.
+ZenCacheNamespace::~ZenCacheNamespace()
+{
+}
+
+// Look up a value: memory layer first, then the disk layer. Disk hits small
+// enough (<= m_DiskLayerSizeThreshold) are promoted into the memory layer.
+// Successful lookups are recorded with the access tracker when enabled.
+bool
+ZenCacheNamespace::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+    ZEN_TRACE_CPU("Z$::Get");
+
+    bool Ok = m_MemLayer.Get(InBucket, HashKey, OutValue);
+
+#if ZEN_USE_CACHE_TRACKER
+    // Deferred so a hit from either layer is tracked exactly once on return.
+    auto _ = MakeGuard([&] {
+        if (!Ok)
+            return;
+
+        m_AccessTracker->TrackAccess(InBucket, HashKey);
+    });
+#endif
+
+    if (Ok)
+    {
+        ZEN_ASSERT(OutValue.Value.Size());
+
+        return true;
+    }
+
+    Ok = m_DiskLayer.Get(InBucket, HashKey, OutValue);
+
+    if (Ok)
+    {
+        ZEN_ASSERT(OutValue.Value.Size());
+
+        // Promote small disk hits into the memory layer for faster re-access.
+        if (OutValue.Value.Size() <= m_DiskLayerSizeThreshold)
+        {
+            m_MemLayer.Put(InBucket, HashKey, OutValue);
+        }
+    }
+
+    return Ok;
+}
+
+// Store a value: always written to the disk layer; additionally cached in the
+// memory layer when small enough. With ref tracking enabled, valid CbObject
+// payloads have their attachment hashes reported to the GC as CID references.
+void
+ZenCacheNamespace::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+    ZEN_TRACE_CPU("Z$::Put");
+
+    // Store value and index
+
+    ZEN_ASSERT(Value.Value.Size());
+
+    m_DiskLayer.Put(InBucket, HashKey, Value);
+
+#if ZEN_USE_REF_TRACKING
+    if (Value.Value.GetContentType() == ZenContentType::kCbObject)
+    {
+        // Only fully valid compact binary is walked for attachment references.
+        if (ValidateCompactBinary(Value.Value, CbValidateMode::All) == CbValidateError::None)
+        {
+            CbObject Object{SharedBuffer(Value.Value)};
+
+            // Small stack arena; spills to the heap only past 8 hashes.
+            uint8_t TempBuffer[8 * sizeof(IoHash)];
+            std::pmr::monotonic_buffer_resource Linear{TempBuffer, sizeof TempBuffer};
+            std::pmr::polymorphic_allocator Allocator{&Linear};
+            std::pmr::vector<IoHash> CidReferences{Allocator};
+
+            Object.IterateAttachments([&](CbFieldView Field) { CidReferences.push_back(Field.AsAttachment()); });
+
+            m_Gc.OnNewCidReferences(CidReferences);
+        }
+    }
+#endif
+
+    if (Value.Value.Size() <= m_DiskLayerSizeThreshold)
+    {
+        m_MemLayer.Put(InBucket, HashKey, Value);
+    }
+}
+
+// Drop a bucket from both layers; returns true if either layer had it.
+bool
+ZenCacheNamespace::DropBucket(std::string_view Bucket)
+{
+    ZEN_INFO("dropping bucket '{}'", Bucket);
+
+    // TODO: should ensure this is done atomically across all layers
+
+    // Both layers are always asked to drop, regardless of earlier results.
+    bool Dropped = m_MemLayer.DropBucket(Bucket);
+    Dropped = m_DiskLayer.DropBucket(Bucket) || Dropped;
+
+    ZEN_INFO("bucket '{}' was {}", Bucket, Dropped ? "dropped" : "not found");
+
+    return Dropped;
+}
+
+// Drop the whole namespace; the disk layer's result determines success.
+bool
+ZenCacheNamespace::Drop()
+{
+    m_MemLayer.Drop();
+    const bool DiskDropped = m_DiskLayer.Drop();
+    return DiskDropped;
+}
+
+// Flush pending disk-layer state; the memory layer has nothing to flush.
+void
+ZenCacheNamespace::Flush()
+{
+    m_DiskLayer.Flush();
+}
+
+// Run an integrity scrub over both layers, at most once per scrub pass
+// (passes are identified by the context's timestamp).
+void
+ZenCacheNamespace::Scrub(ScrubContext& Ctx)
+{
+    const auto PassStamp = Ctx.ScrubTimestamp();
+    if (PassStamp == m_LastScrubTime)
+    {
+        return;
+    }
+    m_LastScrubTime = PassStamp;
+
+    m_DiskLayer.Scrub(Ctx);
+    m_MemLayer.Scrub(Ctx);
+}
+
+// GC phase 1: push memory-layer access times down to the disk layer, then let
+// the disk layer report its references into the GC context. Timed for logging.
+void
+ZenCacheNamespace::GatherReferences(GcContext& GcCtx)
+{
+    Stopwatch Timer;
+    const auto Guard =
+        MakeGuard([&] { ZEN_DEBUG("cache gathered all references from '{}' in {}", m_RootDir, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); });
+
+    access_tracking::AccessTimes AccessTimes;
+    m_MemLayer.GatherAccessTimes(AccessTimes);
+
+    m_DiskLayer.UpdateAccessTimes(AccessTimes);
+    m_DiskLayer.GatherReferences(GcCtx);
+}
+
+// GC phase 2: the memory layer is simply reset (it is a cache of the disk
+// layer); actual reclamation happens in the disk layer.
+void
+ZenCacheNamespace::CollectGarbage(GcContext& GcCtx)
+{
+    m_MemLayer.Reset();
+    m_DiskLayer.CollectGarbage(GcCtx);
+}
+
+// Report disk and memory footprints from the respective layers.
+GcStorageSize
+ZenCacheNamespace::StorageSize() const
+{
+    return {.DiskSize = m_DiskLayer.TotalSize(), .MemorySize = m_MemLayer.TotalSize()};
+}
+
+// Snapshot configuration plus per-layer details, then merge the two layers'
+// bucket name lists into a single de-duplicated list (order unspecified).
+ZenCacheNamespace::Info
+ZenCacheNamespace::GetInfo() const
+{
+    ZenCacheNamespace::Info Info = {.Config = {.RootDir = m_RootDir, .DiskLayerThreshold = m_DiskLayerSizeThreshold},
+                                    .DiskLayerInfo = m_DiskLayer.GetInfo(),
+                                    .MemoryLayerInfo = m_MemLayer.GetInfo()};
+    std::unordered_set<std::string> UniqueNames(Info.DiskLayerInfo.BucketNames.begin(), Info.DiskLayerInfo.BucketNames.end());
+    UniqueNames.insert(Info.MemoryLayerInfo.BucketNames.begin(), Info.MemoryLayerInfo.BucketNames.end());
+    Info.BucketNames.insert(Info.BucketNames.end(), UniqueNames.begin(), UniqueNames.end());
+    return Info;
+}
+
+// Per-bucket info; a bucket only "exists" if the disk layer knows it — the
+// memory-layer portion is best-effort and defaults to empty.
+std::optional<ZenCacheNamespace::BucketInfo>
+ZenCacheNamespace::GetBucketInfo(std::string_view Bucket) const
+{
+    std::optional<ZenCacheDiskLayer::BucketInfo> DiskBucketInfo = m_DiskLayer.GetBucketInfo(Bucket);
+    if (!DiskBucketInfo.has_value())
+    {
+        return {};
+    }
+    ZenCacheNamespace::BucketInfo Info = {.DiskLayerInfo = *DiskBucketInfo,
+                                          .MemoryLayerInfo = m_MemLayer.GetBucketInfo(Bucket).value_or(ZenCacheMemoryLayer::BucketInfo{})};
+    return Info;
+}
+
+// Diagnostic value details come from the disk layer only (the authoritative
+// store); filters narrow by bucket and value.
+CacheValueDetails::NamespaceDetails
+ZenCacheNamespace::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const
+{
+    return m_DiskLayer.GetValueDetails(BucketFilter, ValueFilter);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Default construction; buckets are created lazily on first Put.
+ZenCacheMemoryLayer::ZenCacheMemoryLayer()
+{
+}
+
+// Out-of-line destructor (members clean themselves up).
+ZenCacheMemoryLayer::~ZenCacheMemoryLayer()
+{
+}
+
+// Look up a key in the named in-memory bucket. The table lock is released
+// before the bucket lookup to avoid blocking concurrent inserts — see the
+// pre-existing race note below.
+bool
+ZenCacheMemoryLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+    RwLock::SharedLockScope _(m_Lock);
+
+    auto It = m_Buckets.find(std::string(InBucket));
+
+    if (It == m_Buckets.end())
+    {
+        return false;
+    }
+
+    CacheBucket* Bucket = It->second.get();
+
+    _.ReleaseNow();
+
+    // There's a race here. Since the lock is released early to allow
+    // inserts, the bucket delete path could end up deleting the
+    // underlying data structure
+
+    return Bucket->Get(HashKey, OutValue);
+}
+
+// Insert/update a value in the named bucket. Optimistic two-phase locking:
+// first a shared-lock lookup for the common existing-bucket case, then (only
+// on miss) an exclusive lock with a re-check before creating the bucket.
+void
+ZenCacheMemoryLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+    const auto BucketName = std::string(InBucket);
+    CacheBucket* Bucket = nullptr;
+
+    {
+        RwLock::SharedLockScope _(m_Lock);
+
+        if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end())
+        {
+            Bucket = It->second.get();
+        }
+    }
+
+    if (Bucket == nullptr)
+    {
+        // New bucket
+
+        RwLock::ExclusiveLockScope _(m_Lock);
+
+        // Re-check: another writer may have created the bucket between locks.
+        if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end())
+        {
+            Bucket = It->second.get();
+        }
+        else
+        {
+            auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>());
+            Bucket = InsertResult.first->second.get();
+        }
+    }
+
+    // Note that since the underlying IoBuffer is retained, the content type is also
+
+    Bucket->Put(HashKey, Value);
+}
+
+// Remove a bucket from the table. The bucket object is parked on
+// m_DroppedBuckets rather than destroyed, presumably because unlocked readers
+// may still hold a pointer to it (see the race note in Get()).
+bool
+ZenCacheMemoryLayer::DropBucket(std::string_view InBucket)
+{
+    RwLock::ExclusiveLockScope _(m_Lock);
+
+    auto It = m_Buckets.find(std::string(InBucket));
+
+    if (It != m_Buckets.end())
+    {
+        CacheBucket& Bucket = *It->second;
+        m_DroppedBuckets.push_back(std::move(It->second));
+        m_Buckets.erase(It);
+        Bucket.Drop();
+        return true;
+    }
+    return false;
+}
+
+// Drop every bucket. Bucket objects are parked on m_DroppedBuckets rather than
+// destroyed because unlocked readers may still hold pointers to them (see the
+// race note in Get()).
+// Fix: removed a dead local vector ('Buckets') that was declared and reserved
+// but never used, and erase by iterator instead of re-looking-up by key.
+void
+ZenCacheMemoryLayer::Drop()
+{
+    RwLock::ExclusiveLockScope _(m_Lock);
+    while (!m_Buckets.empty())
+    {
+        auto It = m_Buckets.begin();
+        CacheBucket& Bucket = *It->second;
+        m_DroppedBuckets.push_back(std::move(It->second));
+        m_Buckets.erase(It);
+        Bucket.Drop();
+    }
+}
+
+// Scrub every in-memory bucket; a shared lock suffices since the bucket table
+// itself is only read.
+void
+ZenCacheMemoryLayer::Scrub(ScrubContext& Ctx)
+{
+    RwLock::SharedLockScope _(m_Lock);
+
+    for (auto& [BucketName, Bucket] : m_Buckets)
+    {
+        Bucket->Scrub(Ctx);
+    }
+}
+
+// Collect per-key access times from every bucket into the caller's map,
+// keyed by bucket name (used to refresh disk-layer access times before GC).
+void
+ZenCacheMemoryLayer::GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes)
+{
+    using namespace zen::access_tracking;
+
+    RwLock::SharedLockScope _(m_Lock);
+
+    for (auto& Kv : m_Buckets)
+    {
+        std::vector<KeyAccessTime>& Bucket = AccessTimes.Buckets[Kv.first];
+        Kv.second->GatherAccessTimes(Bucket);
+    }
+}
+
+// Destroy all buckets outright (unlike Drop(), this does not park them on
+// m_DroppedBuckets). Used when the backing disk layer is collected.
+void
+ZenCacheMemoryLayer::Reset()
+{
+    RwLock::ExclusiveLockScope _(m_Lock);
+    m_Buckets.clear();
+}
+
+// Sum of the payload sizes of all buckets, taken under a shared lock.
+uint64_t
+ZenCacheMemoryLayer::TotalSize() const
+{
+    RwLock::SharedLockScope _(m_Lock);
+
+    uint64_t Sum = 0;
+    for (const auto& [BucketName, Bucket] : m_Buckets)
+    {
+        Sum += Bucket->TotalSize();
+    }
+    return Sum;
+}
+
+// Snapshot of configuration, total size, bucket names, and total entry count.
+// NOTE(review): TotalSize() is called before the lock is taken, so the size
+// and the per-bucket counts are not a single atomic snapshot.
+ZenCacheMemoryLayer::Info
+ZenCacheMemoryLayer::GetInfo() const
+{
+    ZenCacheMemoryLayer::Info Info = {.Config = m_Configuration, .TotalSize = TotalSize()};
+
+    RwLock::SharedLockScope _(m_Lock);
+    Info.BucketNames.reserve(m_Buckets.size());
+    for (auto& Kv : m_Buckets)
+    {
+        Info.BucketNames.push_back(Kv.first);
+        Info.EntryCount += Kv.second->EntryCount();
+    }
+    return Info;
+}
+
+// Entry count and size for one bucket, or empty if the bucket is unknown.
+std::optional<ZenCacheMemoryLayer::BucketInfo>
+ZenCacheMemoryLayer::GetBucketInfo(std::string_view Bucket) const
+{
+    RwLock::SharedLockScope _(m_Lock);
+
+    if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end())
+    {
+        return ZenCacheMemoryLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()};
+    }
+    return {};
+}
+
+// Validate every in-memory payload: compact-binary payloads get full CB
+// validation; compressed-binary payloads must have a parseable header whose
+// raw hash matches the entry key. Failures are reported to the scrub context.
+void
+ZenCacheMemoryLayer::CacheBucket::Scrub(ScrubContext& Ctx)
+{
+    RwLock::SharedLockScope _(m_BucketLock);
+
+    std::vector<IoHash> BadHashes;
+
+    auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) {
+        if (ContentType == ZenContentType::kCbObject)
+        {
+            CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All);
+            return Error == CbValidateError::None;
+        }
+        if (ContentType == ZenContentType::kCompressedBinary)
+        {
+            IoHash RawHash;
+            uint64_t RawSize;
+            if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize))
+            {
+                return false;
+            }
+            if (Hash != RawHash)
+            {
+                return false;
+            }
+        }
+        // Other content types are accepted without validation.
+        return true;
+    };
+
+    for (auto& Kv : m_CacheMap)
+    {
+        const BucketPayload& Payload = m_Payloads[Kv.second];
+        if (!ValidateEntry(Kv.first, Payload.Payload.GetContentType(), Payload.Payload))
+        {
+            BadHashes.push_back(Kv.first);
+        }
+    }
+
+    if (!BadHashes.empty())
+    {
+        Ctx.ReportBadCidChunks(BadHashes);
+    }
+}
+
+// Append one (key, last-access tick) record per cached entry to AccessTimes.
+void
+ZenCacheMemoryLayer::CacheBucket::GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes)
+{
+    RwLock::SharedLockScope _(m_BucketLock);
+    std::transform(m_CacheMap.begin(), m_CacheMap.end(), std::back_inserter(AccessTimes), [this](const auto& Kv) {
+        return access_tracking::KeyAccessTime{.Key = Kv.first, .LastAccess = m_AccessTimes[Kv.second]};
+    });
+}
+
+// Look up an entry; on hit, copy out the payload and refresh its access tick.
+// NOTE(review): m_AccessTimes is written under only a shared lock — benign for
+// a timestamp, but technically a data race with concurrent Gets.
+bool
+ZenCacheMemoryLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+    RwLock::SharedLockScope _(m_BucketLock);
+
+    if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end())
+    {
+        uint32_t EntryIndex = It.value();
+        ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size());
+        ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size());
+
+        const BucketPayload& Payload = m_Payloads[EntryIndex];
+        OutValue = {.Value = Payload.Payload, .RawSize = Payload.RawSize, .RawHash = Payload.RawHash};
+        m_AccessTimes[EntryIndex] = GcClock::TickCount();
+
+        return true;
+    }
+
+    return false;
+}
+
+void
+ZenCacheMemoryLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ size_t PayloadSize = Value.Value.GetSize();
+ {
+ GcClock::Tick AccessTime = GcClock::TickCount();
+ RwLock::ExclusiveLockScope _(m_BucketLock);
+ if (m_CacheMap.size() == std::numeric_limits<uint32_t>::max())
+ {
+ // No more space in our memory cache!
+ return;
+ }
+ if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end())
+ {
+ uint32_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size());
+
+ m_TotalSize.fetch_sub(PayloadSize, std::memory_order::relaxed);
+ BucketPayload& Payload = m_Payloads[EntryIndex];
+ Payload.Payload = Value.Value;
+ Payload.RawHash = Value.RawHash;
+ Payload.RawSize = gsl::narrow<uint32_t>(Value.RawSize);
+ m_AccessTimes[EntryIndex] = AccessTime;
+ }
+ else
+ {
+ uint32_t EntryIndex = gsl::narrow<uint32_t>(m_Payloads.size());
+ m_Payloads.emplace_back(
+ BucketPayload{.Payload = Value.Value, .RawSize = gsl::narrow<uint32_t>(Value.RawSize), .RawHash = Value.RawHash});
+ m_AccessTimes.emplace_back(AccessTime);
+ m_CacheMap.insert_or_assign(HashKey, EntryIndex);
+ }
+ ZEN_ASSERT_SLOW(m_Payloads.size() == m_CacheMap.size());
+ ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size());
+ }
+
+ m_TotalSize.fetch_add(PayloadSize, std::memory_order::relaxed);
+}
+
+void
+ZenCacheMemoryLayer::CacheBucket::Drop()
+{
+ RwLock::ExclusiveLockScope _(m_BucketLock);
+ m_CacheMap.clear();
+ m_AccessTimes.clear();
+ m_Payloads.clear();
+ m_TotalSize.store(0);
+}
+
+uint64_t
+ZenCacheMemoryLayer::CacheBucket::EntryCount() const
+{
+ RwLock::SharedLockScope _(m_BucketLock);
+ return static_cast<uint64_t>(m_CacheMap.size());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
// Constructs a disk cache bucket shell with a zero (unassigned) bucket id;
// no I/O is performed until OpenOrCreate() is called.
ZenCacheDiskLayer::CacheBucket::CacheBucket(std::string BucketName) : m_BucketName(std::move(BucketName)), m_BucketId(Oid::Zero)
{
}
+
// Out-of-line destructor; members release their own resources (RAII).
ZenCacheDiskLayer::CacheBucket::~CacheBucket()
{
}
+
bool
ZenCacheDiskLayer::CacheBucket::OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate)
{
    // Opens the on-disk bucket rooted at BucketDir, creating it when
    // AllowCreate is set. Returns false when the manifest exists but holds a
    // zero bucket id, or when it is missing and creation is not allowed.
    using namespace std::literals;

    m_BlocksBasePath = BucketDir / "blocks";
    m_BucketDir = BucketDir;

    CreateDirectories(m_BucketDir);

    std::filesystem::path ManifestPath{m_BucketDir / "zen_manifest"};

    bool IsNew = false;

    CbObject Manifest = LoadCompactBinaryObject(ManifestPath);

    if (Manifest)
    {
        // Existing bucket: the manifest must carry a valid (non-zero) id.
        m_BucketId = Manifest["BucketId"sv].AsObjectId();
        if (m_BucketId == Oid::Zero)
        {
            return false;
        }
    }
    else if (AllowCreate)
    {
        // Fresh bucket: mint an id and persist a minimal manifest.
        m_BucketId.Generate();

        CbObjectWriter Writer;
        Writer << "BucketId"sv << m_BucketId;
        Manifest = Writer.Save();
        SaveCompactBinaryObject(ManifestPath, Manifest);
        IsNew = true;
    }
    else
    {
        return false;
    }

    // Rebuild the in-memory index from snapshot + cas log (wipes stale disk
    // state when the bucket is new).
    OpenLog(IsNew);

    if (!IsNew)
    {
        // Overlay optional manifest metadata onto the entries recovered by
        // OpenLog(): last-access ticks and cached raw hash/size. Must run
        // AFTER OpenLog so m_Index is populated.
        Stopwatch Timer;
        const auto _ =
            MakeGuard([&] { ZEN_INFO("read store manifest '{}' in {}", ManifestPath, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); });

        for (CbFieldView Entry : Manifest["Timestamps"sv])
        {
            const CbObjectView Obj = Entry.AsObjectView();
            const IoHash       Key = Obj["Key"sv].AsHash();

            if (auto It = m_Index.find(Key); It != m_Index.end())
            {
                size_t EntryIndex = It.value();
                ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
                m_AccessTimes[EntryIndex] = Obj["LastAccess"sv].AsInt64();
            }
        }
        for (CbFieldView Entry : Manifest["RawInfo"sv])
        {
            const CbObjectView Obj = Entry.AsObjectView();
            const IoHash       Key = Obj["Key"sv].AsHash();
            if (auto It = m_Index.find(Key); It != m_Index.end())
            {
                size_t EntryIndex = It.value();
                ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size());
                m_Payloads[EntryIndex].RawHash = Obj["RawHash"sv].AsHash();
                m_Payloads[EntryIndex].RawSize = Obj["RawSize"sv].AsUInt64();
            }
        }
    }

    return true;
}
+
void
ZenCacheDiskLayer::CacheBucket::MakeIndexSnapshot()
{
    // Serializes the in-memory index (key -> disk location) to the snapshot
    // file so startup can skip replaying already-covered cas log entries.
    // No-op when the log has not grown since the last snapshot.
    // NOTE(review): reads m_Index/m_Payloads without taking a lock here --
    // presumably callers hold the index lock (see Flush); confirm.
    uint64_t LogCount = m_SlogFile.GetLogCount();
    if (m_LogFlushPosition == LogCount)
    {
        return;
    }

    ZEN_DEBUG("write store snapshot for '{}'", m_BucketDir / m_BucketName);
    uint64_t EntryCount = 0;
    Stopwatch Timer;
    const auto _ = MakeGuard([&] {
        ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}",
                 m_BucketDir / m_BucketName,
                 EntryCount,
                 NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
    });

    namespace fs = std::filesystem;

    fs::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName);
    fs::path STmpIndexPath = GetTempIndexPath(m_BucketDir, m_BucketName);

    // Move index away, we keep it if something goes wrong
    if (fs::is_regular_file(STmpIndexPath))
    {
        fs::remove(STmpIndexPath);
    }
    if (fs::is_regular_file(IndexPath))
    {
        fs::rename(IndexPath, STmpIndexPath);
    }

    try
    {
        // Write the current state of the location map to a new index state
        std::vector<DiskIndexEntry> Entries;

        {
            Entries.resize(m_Index.size());

            uint64_t EntryIndex = 0;
            for (auto& Entry : m_Index)
            {
                DiskIndexEntry& IndexEntry = Entries[EntryIndex++];
                IndexEntry.Key = Entry.first;
                IndexEntry.Location = m_Payloads[Entry.second].Location;
            }
        }

        // Layout: checksummed header followed by the raw entry array.
        BasicFile ObjectIndexFile;
        ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate);
        CacheBucketIndexHeader Header = {.EntryCount = Entries.size(),
                                         .LogPosition = LogCount,
                                         .PayloadAlignment = gsl::narrow<uint32_t>(m_PayloadAlignment)};

        Header.Checksum = CacheBucketIndexHeader::ComputeChecksum(Header);

        ObjectIndexFile.Write(&Header, sizeof(CacheBucketIndexHeader), 0);
        ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(DiskIndexEntry), sizeof(CacheBucketIndexHeader));
        ObjectIndexFile.Flush();
        ObjectIndexFile.Close();
        EntryCount = Entries.size();
        m_LogFlushPosition = LogCount;
    }
    catch (std::exception& Err)
    {
        ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what());

        // Restore any previous snapshot

        if (fs::is_regular_file(STmpIndexPath))
        {
            fs::remove(IndexPath);
            fs::rename(STmpIndexPath, IndexPath);
        }
    }
    // The temporary copy is no longer needed: on success it is stale, on
    // failure it was renamed back above.
    if (fs::is_regular_file(STmpIndexPath))
    {
        fs::remove(STmpIndexPath);
    }
}
+
+uint64_t
+ZenCacheDiskLayer::CacheBucket::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion)
+{
+ if (std::filesystem::is_regular_file(IndexPath))
+ {
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead);
+ uint64_t Size = ObjectIndexFile.FileSize();
+ if (Size >= sizeof(CacheBucketIndexHeader))
+ {
+ CacheBucketIndexHeader Header;
+ ObjectIndexFile.Read(&Header, sizeof(Header), 0);
+ if ((Header.Magic == CacheBucketIndexHeader::ExpectedMagic) &&
+ (Header.Checksum == CacheBucketIndexHeader::ComputeChecksum(Header)) && (Header.PayloadAlignment > 0))
+ {
+ switch (Header.Version)
+ {
+ case CacheBucketIndexHeader::Version2:
+ {
+ uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CacheBucketIndexHeader))) / sizeof(DiskIndexEntry);
+ if (Header.EntryCount > ExpectedEntryCount)
+ {
+ break;
+ }
+ size_t EntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' index containing {} entries in {}",
+ IndexPath,
+ EntryCount,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ m_PayloadAlignment = Header.PayloadAlignment;
+
+ std::vector<DiskIndexEntry> Entries;
+ Entries.resize(Header.EntryCount);
+ ObjectIndexFile.Read(Entries.data(),
+ Header.EntryCount * sizeof(DiskIndexEntry),
+ sizeof(CacheBucketIndexHeader));
+
+ m_Payloads.reserve(Header.EntryCount);
+ m_AccessTimes.reserve(Header.EntryCount);
+ m_Index.reserve(Header.EntryCount);
+
+ std::string InvalidEntryReason;
+ for (const DiskIndexEntry& Entry : Entries)
+ {
+ if (!ValidateEntry(Entry, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason);
+ continue;
+ }
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Entry.Location, .RawSize = 0, .RawHash = IoHash::Zero});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(Entry.Key, EntryIndex);
+ EntryCount++;
+ }
+ OutVersion = CacheBucketIndexHeader::Version2;
+ return Header.LogPosition;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ ZEN_WARN("skipping invalid index file '{}'", IndexPath);
+ }
+ return 0;
+}
+
uint64_t
ZenCacheDiskLayer::CacheBucket::ReadLog(const std::filesystem::path& LogPath, uint64_t SkipEntryCount)
{
    // Replays the cas log at LogPath into the in-memory index, skipping the
    // first SkipEntryCount records (already covered by the index snapshot).
    // Tombstone records erase their key; invalid records are counted and
    // skipped. Returns the number of entries replayed, or 0 when the log is
    // missing or fails to initialize.
    if (std::filesystem::is_regular_file(LogPath))
    {
        uint64_t LogEntryCount = 0;
        Stopwatch Timer;
        const auto _ = MakeGuard([&] {
            ZEN_INFO("read store '{}' log containing {} entries in {}", LogPath, LogEntryCount, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
        });
        TCasLogFile<DiskIndexEntry> CasLog;
        CasLog.Open(LogPath, CasLogFile::Mode::kRead);
        if (CasLog.Initialize())
        {
            uint64_t EntryCount = CasLog.GetLogCount();
            if (EntryCount < SkipEntryCount)
            {
                // The snapshot claims to cover more entries than the log
                // holds; distrust it and replay from the beginning.
                ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath);
                SkipEntryCount = 0;
            }
            LogEntryCount = EntryCount - SkipEntryCount;
            m_Index.reserve(LogEntryCount);
            uint64_t InvalidEntryCount = 0;
            CasLog.Replay(
                [&](const DiskIndexEntry& Record) {
                    std::string InvalidEntryReason;
                    if (Record.Location.Flags & DiskLocation::kTombStone)
                    {
                        // Deletion record: drop the key if present.
                        m_Index.erase(Record.Key);
                        return;
                    }
                    if (!ValidateEntry(Record, InvalidEntryReason))
                    {
                        ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason);
                        ++InvalidEntryCount;
                        return;
                    }
                    // Raw size/hash are unknown here; they may be filled in
                    // later from the bucket manifest (see OpenOrCreate).
                    size_t EntryIndex = m_Payloads.size();
                    m_Payloads.emplace_back(BucketPayload{.Location = Record.Location, .RawSize = 0u, .RawHash = IoHash::Zero});
                    m_AccessTimes.emplace_back(GcClock::TickCount());
                    m_Index.insert_or_assign(Record.Key, EntryIndex);
                },
                SkipEntryCount);
            if (InvalidEntryCount)
            {
                ZEN_WARN("found {} invalid entries in '{}'", InvalidEntryCount, m_BucketDir / m_BucketName);
            }
            return LogEntryCount;
        }
    }
    return 0;
};
+
void
ZenCacheDiskLayer::CacheBucket::OpenLog(const bool IsNew)
{
    // Rebuilds all in-memory bucket state from disk: index snapshot first,
    // then the cas log tail beyond the snapshot position, then the block
    // store. When IsNew is set, any pre-existing log/index/block files are
    // wiped instead of loaded.
    m_TotalStandaloneSize = 0;

    m_Index.clear();
    m_Payloads.clear();
    m_AccessTimes.clear();

    std::filesystem::path LogPath = GetLogPath(m_BucketDir, m_BucketName);
    std::filesystem::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName);

    if (IsNew)
    {
        fs::remove(LogPath);
        fs::remove(IndexPath);
        fs::remove_all(m_BlocksBasePath);
    }

    uint64_t LogEntryCount = 0;
    {
        uint32_t IndexVersion = 0;
        // The snapshot reports how far into the log it reaches; only the
        // tail beyond that position needs replaying.
        m_LogFlushPosition = ReadIndexFile(IndexPath, IndexVersion);
        if (IndexVersion == 0 && std::filesystem::is_regular_file(IndexPath))
        {
            ZEN_WARN("removing invalid index file at '{}'", IndexPath);
            fs::remove(IndexPath);
        }

        if (TCasLogFile<DiskIndexEntry>::IsValid(LogPath))
        {
            LogEntryCount = ReadLog(LogPath, m_LogFlushPosition);
        }
        else
        {
            ZEN_WARN("removing invalid cas log at '{}'", LogPath);
            fs::remove(LogPath);
        }
    }

    CreateDirectories(m_BucketDir);

    m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite);

    // Tell the block store which block-resident locations are live; entries
    // stored as standalone files only contribute to the size total.
    std::vector<BlockStoreLocation> KnownLocations;
    KnownLocations.reserve(m_Index.size());
    for (const auto& Entry : m_Index)
    {
        size_t EntryIndex = Entry.second;
        const BucketPayload& Payload = m_Payloads[EntryIndex];
        const DiskLocation& Location = Payload.Location;

        if (Location.IsFlagSet(DiskLocation::kStandaloneFile))
        {
            m_TotalStandaloneSize.fetch_add(Location.Size(), std::memory_order::relaxed);
            continue;
        }
        const BlockStoreLocation& BlockLocation = Location.GetBlockLocation(m_PayloadAlignment);
        KnownLocations.push_back(BlockLocation);
    }

    m_BlockStore.Initialize(m_BlocksBasePath, MaxBlockSize, BlockStoreDiskLocation::MaxBlockIndex + 1, KnownLocations);
    // Snapshot now so a crash before the next Flush does not force a full
    // log replay on the next startup.
    if (IsNew || LogEntryCount > 0)
    {
        MakeIndexSnapshot();
    }
    // TODO: should validate integrity of container files here
}
+
+void
+ZenCacheDiskLayer::CacheBucket::BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const
+{
+ char HexString[sizeof(HashKey.Hash) * 2];
+ ToHexBytes(HashKey.Hash, sizeof HashKey.Hash, HexString);
+
+ Path.Append(m_BucketDir);
+ Path.AppendSeparator();
+ Path.Append(L"blob");
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(HexString, HexString + 3);
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(HexString + 3, HexString + 5);
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(HexString + 5, HexString + sizeof(HexString));
+}
+
+IoBuffer
+ZenCacheDiskLayer::CacheBucket::GetInlineCacheValue(const DiskLocation& Loc) const
+{
+ BlockStoreLocation Location = Loc.GetBlockLocation(m_PayloadAlignment);
+
+ IoBuffer Value = m_BlockStore.TryGetChunk(Location);
+ if (Value)
+ {
+ Value.SetContentType(Loc.GetContentType());
+ }
+
+ return Value;
+}
+
+IoBuffer
+ZenCacheDiskLayer::CacheBucket::GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const
+{
+ ExtendablePathBuilder<256> DataFilePath;
+ BuildPath(DataFilePath, HashKey);
+
+ RwLock::SharedLockScope ValueLock(LockForHash(HashKey));
+
+ if (IoBuffer Data = IoBufferBuilder::MakeFromFile(DataFilePath.ToPath()))
+ {
+ Data.SetContentType(Loc.GetContentType());
+
+ return Data;
+ }
+
+ return {};
+}
+
bool
ZenCacheDiskLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue)
{
    // Looks up a key and loads its payload from either a standalone file or
    // the block store. For non-structured entries whose raw hash/size were
    // never recorded, derives them from the payload and writes them back
    // into the index. Returns true iff a payload buffer was produced.
    RwLock::SharedLockScope _(m_IndexLock);
    auto It = m_Index.find(HashKey);
    if (It == m_Index.end())
    {
        return false;
    }
    size_t EntryIndex = It.value();
    const BucketPayload& Payload = m_Payloads[EntryIndex];
    // NOTE(review): access-time slot written under the shared lock --
    // presumably a benign last-writer-wins race; confirm.
    m_AccessTimes[EntryIndex] = GcClock::TickCount();
    DiskLocation Location = Payload.Location;
    OutValue.RawSize = Payload.RawSize;
    OutValue.RawHash = Payload.RawHash;
    if (Location.IsFlagSet(DiskLocation::kStandaloneFile))
    {
        // We don't need to hold the index lock when we read a standalone file
        _.ReleaseNow();
        OutValue.Value = GetStandaloneCacheValue(Location, HashKey);
    }
    else
    {
        OutValue.Value = GetInlineCacheValue(Location);
    }
    // NOTE(review): on the standalone path this is the second ReleaseNow()
    // call on the same scope -- assumes ReleaseNow is idempotent; confirm.
    _.ReleaseNow();

    if (!Location.IsFlagSet(DiskLocation::kStructured))
    {
        if (OutValue.RawHash == IoHash::Zero && OutValue.RawSize == 0 && OutValue.Value.GetSize() > 0)
        {
            // Lazily compute the raw identity of the payload.
            if (Location.IsFlagSet(DiskLocation::kCompressed))
            {
                (void)CompressedBuffer::FromCompressed(SharedBuffer(OutValue.Value), OutValue.RawHash, OutValue.RawSize);
            }
            else
            {
                OutValue.RawHash = IoHash::HashBuffer(OutValue.Value);
                OutValue.RawSize = OutValue.Value.GetSize();
            }
            // Re-acquire exclusively and re-check: the entry may have been
            // removed or replaced while the lock was dropped.
            RwLock::ExclusiveLockScope __(m_IndexLock);
            if (auto WriteIt = m_Index.find(HashKey); WriteIt != m_Index.end())
            {
                BucketPayload& WritePayload = m_Payloads[WriteIt.value()];
                WritePayload.RawHash = OutValue.RawHash;
                WritePayload.RawSize = OutValue.RawSize;

                m_LogFlushPosition = 0; // Force resave of index on exit
            }
        }
    }

    return (bool)OutValue.Value;
}
+
+void
+ZenCacheDiskLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ if (Value.Value.Size() >= m_LargeObjectThreshold)
+ {
+ return PutStandaloneCacheValue(HashKey, Value);
+ }
+ PutInlineCacheValue(HashKey, Value);
+}
+
bool
ZenCacheDiskLayer::CacheBucket::Drop()
{
    // Irreversibly deletes the bucket: takes the index lock and then every
    // per-hash shard lock (so no reader/writer is mid-operation), closes the
    // backing files and removes the whole bucket directory. Returns whether
    // the directory removal succeeded.
    RwLock::ExclusiveLockScope _(m_IndexLock);

    // Acquire all shard locks; held via unique_ptr since the lock scope type
    // is not movable. Order (index lock first, then shards) must match other
    // call sites to avoid deadlock.
    std::vector<std::unique_ptr<RwLock::ExclusiveLockScope>> ShardLocks;
    ShardLocks.reserve(256);
    for (RwLock& Lock : m_ShardedLocks)
    {
        ShardLocks.push_back(std::make_unique<RwLock::ExclusiveLockScope>(Lock));
    }
    m_BlockStore.Close();
    m_SlogFile.Close();

    bool Deleted = MoveAndDeleteDirectory(m_BucketDir);

    m_Index.clear();
    m_Payloads.clear();
    m_AccessTimes.clear();
    return Deleted;
}
+
void
ZenCacheDiskLayer::CacheBucket::Flush()
{
    // Pushes all pending state to disk: block store data first, then (under
    // the index lock) the cas log, a fresh index snapshot and the manifest.
    m_BlockStore.Flush();

    RwLock::SharedLockScope _(m_IndexLock);
    m_SlogFile.Flush();
    MakeIndexSnapshot();
    SaveManifest();
}
+
+void
+ZenCacheDiskLayer::CacheBucket::SaveManifest()
+{
+ using namespace std::literals;
+
+ CbObjectWriter Writer;
+ Writer << "BucketId"sv << m_BucketId;
+
+ if (!m_Index.empty())
+ {
+ Writer.BeginArray("Timestamps"sv);
+ for (auto& Kv : m_Index)
+ {
+ const IoHash& Key = Kv.first;
+ GcClock::Tick AccessTime = m_AccessTimes[Kv.second];
+
+ Writer.BeginObject();
+ Writer << "Key"sv << Key;
+ Writer << "LastAccess"sv << AccessTime;
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("RawInfo"sv);
+ {
+ for (auto& Kv : m_Index)
+ {
+ const IoHash& Key = Kv.first;
+ const BucketPayload& Payload = m_Payloads[Kv.second];
+ if (Payload.RawHash != IoHash::Zero)
+ {
+ Writer.BeginObject();
+ Writer << "Key"sv << Key;
+ Writer << "RawHash"sv << Payload.RawHash;
+ Writer << "RawSize"sv << Payload.RawSize;
+ Writer.EndObject();
+ }
+ }
+ }
+ Writer.EndArray();
+ }
+
+ SaveCompactBinaryObject(m_BucketDir / "zen_manifest", Writer.Save());
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::Scrub(ScrubContext& Ctx)
+{
+ std::vector<IoHash> BadKeys;
+ uint64_t ChunkCount{0}, ChunkBytes{0};
+ std::vector<BlockStoreLocation> ChunkLocations;
+ std::vector<IoHash> ChunkIndexToChunkHash;
+
+ auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) {
+ if (ContentType == ZenContentType::kCbObject)
+ {
+ CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All);
+ return Error == CbValidateError::None;
+ }
+ if (ContentType == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize))
+ {
+ return false;
+ }
+ if (RawHash != Hash)
+ {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ RwLock::SharedLockScope _(m_IndexLock);
+
+ const size_t BlockChunkInitialCount = m_Index.size() / 4;
+ ChunkLocations.reserve(BlockChunkInitialCount);
+ ChunkIndexToChunkHash.reserve(BlockChunkInitialCount);
+
+ for (auto& Kv : m_Index)
+ {
+ const IoHash& HashKey = Kv.first;
+ const BucketPayload& Payload = m_Payloads[Kv.second];
+ const DiskLocation& Loc = Payload.Location;
+
+ if (Loc.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ ++ChunkCount;
+ ChunkBytes += Loc.Size();
+ if (Loc.GetContentType() == ZenContentType::kBinary)
+ {
+ ExtendablePathBuilder<256> DataFilePath;
+ BuildPath(DataFilePath, HashKey);
+
+ RwLock::SharedLockScope ValueLock(LockForHash(HashKey));
+
+ std::error_code Ec;
+ uintmax_t size = std::filesystem::file_size(DataFilePath.ToPath(), Ec);
+ if (Ec)
+ {
+ BadKeys.push_back(HashKey);
+ }
+ if (size != Loc.Size())
+ {
+ BadKeys.push_back(HashKey);
+ }
+ continue;
+ }
+ IoBuffer Buffer = GetStandaloneCacheValue(Loc, HashKey);
+ if (!Buffer)
+ {
+ BadKeys.push_back(HashKey);
+ continue;
+ }
+ if (!ValidateEntry(HashKey, Loc.GetContentType(), Buffer))
+ {
+ BadKeys.push_back(HashKey);
+ continue;
+ }
+ }
+ else
+ {
+ ChunkLocations.emplace_back(Loc.GetBlockLocation(m_PayloadAlignment));
+ ChunkIndexToChunkHash.push_back(HashKey);
+ continue;
+ }
+ }
+
+ const auto ValidateSmallChunk = [&](size_t ChunkIndex, const void* Data, uint64_t Size) {
+ ++ChunkCount;
+ ChunkBytes += Size;
+ const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex];
+ if (!Data)
+ {
+ // ChunkLocation out of range of stored blocks
+ BadKeys.push_back(Hash);
+ return;
+ }
+ IoBuffer Buffer(IoBuffer::Wrap, Data, Size);
+ if (!Buffer)
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)];
+ ZenContentType ContentType = Payload.Location.GetContentType();
+ if (!ValidateEntry(Hash, ContentType, Buffer))
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ };
+
+ const auto ValidateLargeChunk = [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) {
+ ++ChunkCount;
+ ChunkBytes += Size;
+ const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex];
+ // TODO: Add API to verify compressed buffer and possible structure data without having to memorymap the whole file
+ IoBuffer Buffer(IoBuffer::BorrowedFile, File.GetBasicFile().Handle(), Offset, Size);
+ if (!Buffer)
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)];
+ ZenContentType ContentType = Payload.Location.GetContentType();
+ if (!ValidateEntry(Hash, ContentType, Buffer))
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ };
+
+ m_BlockStore.IterateChunks(ChunkLocations, ValidateSmallChunk, ValidateLargeChunk);
+
+ _.ReleaseNow();
+
+ Ctx.ReportScrubbed(ChunkCount, ChunkBytes);
+
+ if (!BadKeys.empty())
+ {
+ ZEN_WARN("Scrubbing found {} bad chunks in '{}'", BadKeys.size(), m_BucketDir / m_BucketName);
+
+ if (Ctx.RunRecovery())
+ {
+ // Deal with bad chunks by removing them from our lookup map
+
+ std::vector<DiskIndexEntry> LogEntries;
+ LogEntries.reserve(BadKeys.size());
+
+ {
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ for (const IoHash& BadKey : BadKeys)
+ {
+ // Log a tombstone and delete the in-memory index for the bad entry
+ const auto It = m_Index.find(BadKey);
+ const BucketPayload& Payload = m_Payloads[It->second];
+ DiskLocation Location = Payload.Location;
+ Location.Flags |= DiskLocation::kTombStone;
+ LogEntries.push_back(DiskIndexEntry{.Key = BadKey, .Location = Location});
+ m_Index.erase(BadKey);
+ }
+ }
+ for (const DiskIndexEntry& Entry : LogEntries)
+ {
+ if (Entry.Location.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ ExtendablePathBuilder<256> Path;
+ BuildPath(Path, Entry.Key);
+ fs::path FilePath = Path.ToPath();
+ RwLock::ExclusiveLockScope ValueLock(LockForHash(Entry.Key));
+ if (fs::is_regular_file(FilePath))
+ {
+ ZEN_DEBUG("deleting bad standalone cache file '{}'", Path.ToUtf8());
+ std::error_code Ec;
+ fs::remove(FilePath, Ec); // We don't care if we fail, we are no longer tracking this file...
+ }
+ m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed);
+ }
+ }
+ m_SlogFile.Append(LogEntries);
+
+ // Clean up m_AccessTimes and m_Payloads vectors
+ {
+ std::vector<BucketPayload> Payloads;
+ std::vector<AccessTime> AccessTimes;
+ IndexMap Index;
+
+ {
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ size_t EntryCount = m_Index.size();
+ Payloads.reserve(EntryCount);
+ AccessTimes.reserve(EntryCount);
+ Index.reserve(EntryCount);
+ for (auto It : m_Index)
+ {
+ size_t EntryIndex = Payloads.size();
+ Payloads.push_back(m_Payloads[EntryIndex]);
+ AccessTimes.push_back(m_AccessTimes[EntryIndex]);
+ Index.insert({It.first, EntryIndex});
+ }
+ m_Index.swap(Index);
+ m_Payloads.swap(Payloads);
+ m_AccessTimes.swap(AccessTimes);
+ }
+ }
+ }
+ }
+
+ // Let whomever it concerns know about the bad chunks. This could
+ // be used to invalidate higher level data structures more efficiently
+ // than a full validation pass might be able to do
+ Ctx.ReportBadCidChunks(BadKeys);
+
+ ZEN_INFO("cache bucket scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes));
+}
+
void
ZenCacheDiskLayer::CacheBucket::GatherReferences(GcContext& GcCtx)
{
    // GC mark phase for this bucket: partitions entries into expired keys
    // (last access older than the GC expire time) and live ones, and for
    // live structured (CbObject) payloads collects their attachment CIDs so
    // the GC retains the referenced chunks.
    ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::GatherReferences");

    // NOTE(review): all lock-hold timing below accumulates into the
    // "WriteBlock" counters even though only shared (read) locks are taken;
    // Read* counters are never updated -- looks like a copy/paste of the
    // guard, confirm intent.
    uint64_t WriteBlockTimeUs = 0;
    uint64_t WriteBlockLongestTimeUs = 0;
    uint64_t ReadBlockTimeUs = 0;
    uint64_t ReadBlockLongestTimeUs = 0;

    Stopwatch TotalTimer;
    const auto _ = MakeGuard([&] {
        // NOTE(review): NiceLatencyNs is fed microsecond values here --
        // verify the formatter's expected unit.
        ZEN_DEBUG("gathered references from '{}' in {} write lock: {} ({}), read lock: {} ({})",
                  m_BucketDir / m_BucketName,
                  NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
                  NiceLatencyNs(WriteBlockTimeUs),
                  NiceLatencyNs(WriteBlockLongestTimeUs),
                  NiceLatencyNs(ReadBlockTimeUs),
                  NiceLatencyNs(ReadBlockLongestTimeUs));
    });

    const GcClock::TimePoint ExpireTime = GcCtx.ExpireTime();

    const GcClock::Tick ExpireTicks = ExpireTime.time_since_epoch().count();

    // Snapshot the index state under the lock so the scan below can run
    // without blocking writers.
    IndexMap Index;
    std::vector<AccessTime> AccessTimes;
    std::vector<BucketPayload> Payloads;
    {
        RwLock::SharedLockScope __(m_IndexLock);
        Stopwatch Timer;
        const auto ___ = MakeGuard([&] {
            uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
            WriteBlockTimeUs += ElapsedUs;
            WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
        });
        Index = m_Index;
        AccessTimes = m_AccessTimes;
        Payloads = m_Payloads;
    }

    std::vector<IoHash> ExpiredKeys;
    ExpiredKeys.reserve(1024);

    std::vector<IoHash> Cids;
    Cids.reserve(1024);

    for (const auto& Entry : Index)
    {
        const IoHash& Key = Entry.first;
        GcClock::Tick AccessTime = AccessTimes[Entry.second];
        if (AccessTime < ExpireTicks)
        {
            ExpiredKeys.push_back(Key);
            continue;
        }

        const DiskLocation& Loc = Payloads[Entry.second].Location;

        if (Loc.IsFlagSet(DiskLocation::kStructured))
        {
            // Flush retained CIDs in batches to bound memory.
            if (Cids.size() > 1024)
            {
                GcCtx.AddRetainedCids(Cids);
                Cids.clear();
            }

            IoBuffer Buffer;
            {
                RwLock::SharedLockScope __(m_IndexLock);
                Stopwatch Timer;
                const auto ___ = MakeGuard([&] {
                    uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
                    WriteBlockTimeUs += ElapsedUs;
                    WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
                });
                if (Loc.IsFlagSet(DiskLocation::kStandaloneFile))
                {
                    // We don't need to hold the index lock when we read a standalone file
                    __.ReleaseNow();
                    if (Buffer = GetStandaloneCacheValue(Loc, Key); !Buffer)
                    {
                        continue;
                    }
                }
                else if (Buffer = GetInlineCacheValue(Loc); !Buffer)
                {
                    continue;
                }
            }

            // Structured payloads are compact-binary objects; record every
            // attachment they reference as retained.
            ZEN_ASSERT(Buffer);
            ZEN_ASSERT(Buffer.GetContentType() == ZenContentType::kCbObject);
            CbObject Obj(SharedBuffer{Buffer});
            Obj.IterateAttachments([&Cids](CbFieldView Field) { Cids.push_back(Field.AsAttachment()); });
        }
    }

    GcCtx.AddRetainedCids(Cids);
    GcCtx.SetExpiredCacheKeys(m_BucketDir.string(), std::move(ExpiredKeys));
}
+
+void
+ZenCacheDiskLayer::CacheBucket::CollectGarbage(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::CollectGarbage");
+
+ ZEN_DEBUG("collecting garbage from '{}'", m_BucketDir / m_BucketName);
+
+ Stopwatch TotalTimer;
+ uint64_t WriteBlockTimeUs = 0;
+ uint64_t WriteBlockLongestTimeUs = 0;
+ uint64_t ReadBlockTimeUs = 0;
+ uint64_t ReadBlockLongestTimeUs = 0;
+ uint64_t TotalChunkCount = 0;
+ uint64_t DeletedSize = 0;
+ uint64_t OldTotalSize = TotalSize();
+
+ std::unordered_set<IoHash> DeletedChunks;
+ uint64_t MovedCount = 0;
+
+ const auto _ = MakeGuard([&] {
+ ZEN_DEBUG(
+ "garbage collect from '{}' DONE after {}, write lock: {} ({}), read lock: {} ({}), collected {} bytes, deleted {} and moved "
+ "{} "
+ "of {} "
+ "entires ({}).",
+ m_BucketDir / m_BucketName,
+ NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
+ NiceLatencyNs(WriteBlockTimeUs),
+ NiceLatencyNs(WriteBlockLongestTimeUs),
+ NiceLatencyNs(ReadBlockTimeUs),
+ NiceLatencyNs(ReadBlockLongestTimeUs),
+ NiceBytes(DeletedSize),
+ DeletedChunks.size(),
+ MovedCount,
+ TotalChunkCount,
+ NiceBytes(OldTotalSize));
+ RwLock::SharedLockScope _(m_IndexLock);
+ SaveManifest();
+ });
+
+ m_SlogFile.Flush();
+
+ std::span<const IoHash> ExpiredCacheKeys = GcCtx.ExpiredCacheKeys(m_BucketDir.string());
+ std::vector<IoHash> DeleteCacheKeys;
+ DeleteCacheKeys.reserve(ExpiredCacheKeys.size());
+ GcCtx.FilterCids(ExpiredCacheKeys, [&](const IoHash& ChunkHash, bool Keep) {
+ if (Keep)
+ {
+ return;
+ }
+ DeleteCacheKeys.push_back(ChunkHash);
+ });
+ if (DeleteCacheKeys.empty())
+ {
+ ZEN_DEBUG("garbage collect SKIPPED, for '{}', no expired cache keys found", m_BucketDir / m_BucketName);
+ return;
+ }
+
+ auto __ = MakeGuard([&]() {
+ if (!DeletedChunks.empty())
+ {
+ // Clean up m_AccessTimes and m_Payloads vectors
+ std::vector<BucketPayload> Payloads;
+ std::vector<AccessTime> AccessTimes;
+ IndexMap Index;
+
+ {
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ size_t EntryCount = m_Index.size();
+ Payloads.reserve(EntryCount);
+ AccessTimes.reserve(EntryCount);
+ Index.reserve(EntryCount);
+ for (auto It : m_Index)
+ {
+ size_t EntryIndex = Payloads.size();
+ Payloads.push_back(m_Payloads[EntryIndex]);
+ AccessTimes.push_back(m_AccessTimes[EntryIndex]);
+ Index.insert({It.first, EntryIndex});
+ }
+ m_Index.swap(Index);
+ m_Payloads.swap(Payloads);
+ m_AccessTimes.swap(AccessTimes);
+ }
+ GcCtx.AddDeletedCids(std::vector<IoHash>(DeletedChunks.begin(), DeletedChunks.end()));
+ }
+ });
+
+ std::vector<DiskIndexEntry> ExpiredStandaloneEntries;
+ IndexMap Index;
+ BlockStore::ReclaimSnapshotState BlockStoreState;
+ {
+ RwLock::SharedLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ if (m_Index.empty())
+ {
+ ZEN_DEBUG("garbage collect SKIPPED, for '{}', container is empty", m_BucketDir / m_BucketName);
+ return;
+ }
+ BlockStoreState = m_BlockStore.GetReclaimSnapshotState();
+
+ SaveManifest();
+ Index = m_Index;
+
+ for (const IoHash& Key : DeleteCacheKeys)
+ {
+ if (auto It = Index.find(Key); It != Index.end())
+ {
+ const BucketPayload& Payload = m_Payloads[It->second];
+ DiskIndexEntry Entry = {.Key = It->first, .Location = Payload.Location};
+ if (Entry.Location.Flags & DiskLocation::kStandaloneFile)
+ {
+ Entry.Location.Flags |= DiskLocation::kTombStone;
+ ExpiredStandaloneEntries.push_back(Entry);
+ }
+ }
+ }
+ if (GcCtx.IsDeletionMode())
+ {
+ for (const auto& Entry : ExpiredStandaloneEntries)
+ {
+ m_Index.erase(Entry.Key);
+ m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed);
+ DeletedChunks.insert(Entry.Key);
+ }
+ m_SlogFile.Append(ExpiredStandaloneEntries);
+ }
+ }
+
+ if (GcCtx.IsDeletionMode())
+ {
+ std::error_code Ec;
+ ExtendablePathBuilder<256> Path;
+
+ for (const auto& Entry : ExpiredStandaloneEntries)
+ {
+ const IoHash& Key = Entry.Key;
+ const DiskLocation& Loc = Entry.Location;
+
+ Path.Reset();
+ BuildPath(Path, Key);
+ fs::path FilePath = Path.ToPath();
+
+ {
+ RwLock::SharedLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ if (m_Index.contains(Key))
+ {
+ // Someone added it back, let the file on disk be
+ ZEN_DEBUG("skipping z$ delete standalone of file '{}' FAILED, it has been added back", Path.ToUtf8());
+ continue;
+ }
+ __.ReleaseNow();
+
+ RwLock::ExclusiveLockScope ValueLock(LockForHash(Key));
+ if (fs::is_regular_file(FilePath))
+ {
+ ZEN_DEBUG("deleting standalone cache file '{}'", Path.ToUtf8());
+ fs::remove(FilePath, Ec);
+ }
+ }
+
+ if (Ec)
+ {
+ ZEN_WARN("delete expired z$ standalone file '{}' FAILED, reason: '{}'", Path.ToUtf8(), Ec.message());
+ Ec.clear();
+ DiskLocation RestoreLocation = Loc;
+ RestoreLocation.Flags &= ~DiskLocation::kTombStone;
+
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ if (m_Index.contains(Key))
+ {
+ continue;
+ }
+ m_SlogFile.Append(DiskIndexEntry{.Key = Key, .Location = RestoreLocation});
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = RestoreLocation});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert({Key, EntryIndex});
+ m_TotalStandaloneSize.fetch_add(RestoreLocation.Size(), std::memory_order::relaxed);
+ DeletedChunks.erase(Key);
+ continue;
+ }
+ DeletedSize += Entry.Location.Size();
+ }
+ }
+
+ TotalChunkCount = Index.size();
+
+ std::vector<IoHash> TotalChunkHashes;
+ TotalChunkHashes.reserve(TotalChunkCount);
+ for (const auto& Entry : Index)
+ {
+ const DiskLocation& Location = m_Payloads[Entry.second].Location;
+
+ if (Location.Flags & DiskLocation::kStandaloneFile)
+ {
+ continue;
+ }
+ TotalChunkHashes.push_back(Entry.first);
+ }
+
+ if (TotalChunkHashes.empty())
+ {
+ return;
+ }
+ TotalChunkCount = TotalChunkHashes.size();
+
+ std::vector<BlockStoreLocation> ChunkLocations;
+ BlockStore::ChunkIndexArray KeepChunkIndexes;
+ std::vector<IoHash> ChunkIndexToChunkHash;
+ ChunkLocations.reserve(TotalChunkCount);
+ ChunkLocations.reserve(TotalChunkCount);
+ ChunkIndexToChunkHash.reserve(TotalChunkCount);
+
+ GcCtx.FilterCids(TotalChunkHashes, [&](const IoHash& ChunkHash, bool Keep) {
+ auto KeyIt = Index.find(ChunkHash);
+ const DiskLocation& DiskLocation = m_Payloads[KeyIt->second].Location;
+ BlockStoreLocation Location = DiskLocation.GetBlockLocation(m_PayloadAlignment);
+ size_t ChunkIndex = ChunkLocations.size();
+ ChunkLocations.push_back(Location);
+ ChunkIndexToChunkHash[ChunkIndex] = ChunkHash;
+ if (Keep)
+ {
+ KeepChunkIndexes.push_back(ChunkIndex);
+ }
+ });
+
+ size_t DeleteCount = TotalChunkCount - KeepChunkIndexes.size();
+
+ const bool PerformDelete = GcCtx.IsDeletionMode() && GcCtx.CollectSmallObjects();
+ if (!PerformDelete)
+ {
+ m_BlockStore.ReclaimSpace(BlockStoreState, ChunkLocations, KeepChunkIndexes, m_PayloadAlignment, true);
+ uint64_t CurrentTotalSize = TotalSize();
+ ZEN_DEBUG("garbage collect from '{}' DISABLED, found {} chunks of total {} {}",
+ m_BucketDir / m_BucketName,
+ DeleteCount,
+ TotalChunkCount,
+ NiceBytes(CurrentTotalSize));
+ return;
+ }
+
+ m_BlockStore.ReclaimSpace(
+ BlockStoreState,
+ ChunkLocations,
+ KeepChunkIndexes,
+ m_PayloadAlignment,
+ false,
+ [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& RemovedChunks) {
+ std::vector<DiskIndexEntry> LogEntries;
+ LogEntries.reserve(MovedChunks.size() + RemovedChunks.size());
+ for (const auto& Entry : MovedChunks)
+ {
+ size_t ChunkIndex = Entry.first;
+ const BlockStoreLocation& NewLocation = Entry.second;
+ const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex];
+ const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]];
+ const DiskLocation& OldDiskLocation = OldPayload.Location;
+ LogEntries.push_back(
+ {.Key = ChunkHash, .Location = DiskLocation(NewLocation, m_PayloadAlignment, OldDiskLocation.GetFlags())});
+ }
+ for (const size_t ChunkIndex : RemovedChunks)
+ {
+ const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex];
+ const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]];
+ const DiskLocation& OldDiskLocation = OldPayload.Location;
+ LogEntries.push_back({.Key = ChunkHash,
+ .Location = DiskLocation(OldDiskLocation.GetBlockLocation(m_PayloadAlignment),
+ m_PayloadAlignment,
+ OldDiskLocation.GetFlags() | DiskLocation::kTombStone)});
+ DeletedChunks.insert(ChunkHash);
+ }
+
+ m_SlogFile.Append(LogEntries);
+ m_SlogFile.Flush();
+ {
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ for (const DiskIndexEntry& Entry : LogEntries)
+ {
+ if (Entry.Location.GetFlags() & DiskLocation::kTombStone)
+ {
+ m_Index.erase(Entry.Key);
+ continue;
+ }
+ m_Payloads[m_Index[Entry.Key]].Location = Entry.Location;
+ }
+ }
+ },
+ [&]() { return GcCtx.CollectSmallObjects(); });
+}
+
void
ZenCacheDiskLayer::CacheBucket::UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes)
{
    // Applies externally tracked last-access timestamps to this bucket's
    // entries. Keys that are not (or no longer) present in the index are
    // silently ignored.
    //
    // NOTE(review): m_Index and m_AccessTimes are accessed here without taking
    // m_IndexLock, unlike the other accessors in this class (e.g. EntryCount,
    // GetValueDetails) -- confirm the caller serializes this against
    // concurrent Put/GC activity.
    using namespace access_tracking;

    for (const KeyAccessTime& KeyTime : AccessTimes)
    {
        if (auto It = m_Index.find(KeyTime.Key); It != m_Index.end())
        {
            size_t EntryIndex = It.value();
            ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
            m_AccessTimes[EntryIndex] = KeyTime.LastAccess;
        }
    }
}
+
+uint64_t
+ZenCacheDiskLayer::CacheBucket::EntryCount() const
+{
+ RwLock::SharedLockScope _(m_IndexLock);
+ return static_cast<uint64_t>(m_Index.size());
+}
+
CacheValueDetails::ValueDetails
ZenCacheDiskLayer::CacheBucket::GetValueDetails(const IoHash& Key, size_t Index) const
{
    // Builds the diagnostic description for a single cache value. `Index` is
    // the entry's slot in the parallel m_Payloads/m_AccessTimes arrays; the
    // public GetValueDetails overload holds m_IndexLock around this call.
    std::vector<IoHash> Attachments;
    const BucketPayload& Payload = m_Payloads[Index];
    if (Payload.Location.IsFlagSet(DiskLocation::kStructured))
    {
        // Structured (compact binary) values can carry attachments: load the
        // object from wherever it lives -- its own file or the block store --
        // and collect the attachment hashes.
        IoBuffer Value = Payload.Location.IsFlagSet(DiskLocation::kStandaloneFile) ? GetStandaloneCacheValue(Payload.Location, Key)
                                                                                   : GetInlineCacheValue(Payload.Location);
        CbObject Obj(SharedBuffer{Value});
        Obj.IterateAttachments([&Attachments](CbFieldView Field) { Attachments.emplace_back(Field.AsAttachment()); });
    }
    return CacheValueDetails::ValueDetails{.Size = Payload.Location.Size(),
                                           .RawSize = Payload.RawSize,
                                           .RawHash = Payload.RawHash,
                                           .LastAccess = m_AccessTimes[Index],
                                           .Attachments = std::move(Attachments),
                                           .ContentType = Payload.Location.GetContentType()};
}
+
+CacheValueDetails::BucketDetails
+ZenCacheDiskLayer::CacheBucket::GetValueDetails(const std::string_view ValueFilter) const
+{
+ CacheValueDetails::BucketDetails Details;
+ RwLock::SharedLockScope _(m_IndexLock);
+ if (ValueFilter.empty())
+ {
+ Details.Values.reserve(m_Index.size());
+ for (const auto& It : m_Index)
+ {
+ Details.Values.insert_or_assign(It.first, GetValueDetails(It.first, It.second));
+ }
+ }
+ else
+ {
+ IoHash Key = IoHash::FromHexString(ValueFilter);
+ if (auto It = m_Index.find(Key); It != m_Index.end())
+ {
+ Details.Values.insert_or_assign(It->first, GetValueDetails(It->first, It->second));
+ }
+ }
+ return Details;
+}
+
+void
+ZenCacheDiskLayer::CollectGarbage(GcContext& GcCtx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket& Bucket = *Kv.second;
+ Bucket.CollectGarbage(GcCtx);
+ }
+}
+
+void
+ZenCacheDiskLayer::UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (const auto& Kv : AccessTimes.Buckets)
+ {
+ if (auto It = m_Buckets.find(Kv.first); It != m_Buckets.end())
+ {
+ CacheBucket& Bucket = *It->second;
+ Bucket.UpdateAccessTimes(Kv.second);
+ }
+ }
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ uint64_t NewFileSize = Value.Value.Size();
+
+ TemporaryFile DataFile;
+
+ std::error_code Ec;
+ DataFile.CreateTemporary(m_BucketDir.c_str(), Ec);
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to open temporary file for put in '{}'", m_BucketDir));
+ }
+
+ bool CleanUpTempFile = false;
+ auto __ = MakeGuard([&] {
+ if (CleanUpTempFile)
+ {
+ std::error_code Ec;
+ std::filesystem::remove(DataFile.GetPath(), Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to clean up temporary file '{}' for put in '{}', reason '{}'",
+ DataFile.GetPath(),
+ m_BucketDir,
+ Ec.message());
+ }
+ }
+ });
+
+ DataFile.WriteAll(Value.Value, Ec);
+ if (Ec)
+ {
+ throw std::system_error(Ec,
+ fmt::format("Failed to write payload ({} bytes) to temporary file '{}' for put in '{}'",
+ NiceBytes(NewFileSize),
+ DataFile.GetPath().string(),
+ m_BucketDir));
+ }
+
+ ExtendablePathBuilder<256> DataFilePath;
+ BuildPath(DataFilePath, HashKey);
+ std::filesystem::path FsPath{DataFilePath.ToPath()};
+
+ RwLock::ExclusiveLockScope ValueLock(LockForHash(HashKey));
+
+ // We do a speculative remove of the file instead of probing with a exists call and check the error code instead
+ std::filesystem::remove(FsPath, Ec);
+ if (Ec)
+ {
+ if (Ec.value() != ENOENT)
+ {
+ ZEN_WARN("Failed to remove file '{}' for put in '{}', reason: '{}', retrying.", FsPath, m_BucketDir, Ec.message());
+ Sleep(100);
+ Ec.clear();
+ std::filesystem::remove(FsPath, Ec);
+ if (Ec && Ec.value() != ENOENT)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to remove file '{}' for put in '{}'", FsPath, m_BucketDir));
+ }
+ }
+ }
+
+ DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
+ if (Ec)
+ {
+ CreateDirectories(FsPath.parent_path());
+ Ec.clear();
+
+ // Try again
+ DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to finalize file '{}', moving from '{}' for put in '{}', reason: '{}', retrying.",
+ FsPath,
+ DataFile.GetPath(),
+ m_BucketDir,
+ Ec.message());
+ Sleep(100);
+ Ec.clear();
+ DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
+ if (Ec)
+ {
+ throw std::system_error(
+ Ec,
+ fmt::format("Failed to finalize file '{}', moving from '{}' for put in '{}'", FsPath, DataFile.GetPath(), m_BucketDir));
+ }
+ }
+ }
+
+ // Once we have called MoveTemporaryIntoPlace automatic clean up the temp file
+ // will be disabled as the file handle has already been closed
+ CleanUpTempFile = false;
+
+ uint8_t EntryFlags = DiskLocation::kStandaloneFile;
+
+ if (Value.Value.GetContentType() == ZenContentType::kCbObject)
+ {
+ EntryFlags |= DiskLocation::kStructured;
+ }
+ else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary)
+ {
+ EntryFlags |= DiskLocation::kCompressed;
+ }
+
+ DiskLocation Loc(NewFileSize, EntryFlags);
+
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+ if (auto It = m_Index.find(HashKey); It == m_Index.end())
+ {
+ // Previously unknown object
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(HashKey, EntryIndex);
+ }
+ else
+ {
+ // TODO: should check if write is idempotent and bail out if it is?
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
+ m_Payloads[EntryIndex] = BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash};
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_TotalStandaloneSize.fetch_sub(Loc.Size(), std::memory_order::relaxed);
+ }
+
+ m_SlogFile.Append({.Key = HashKey, .Location = Loc});
+ m_TotalStandaloneSize.fetch_add(NewFileSize, std::memory_order::relaxed);
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ uint8_t EntryFlags = 0;
+
+ if (Value.Value.GetContentType() == ZenContentType::kCbObject)
+ {
+ EntryFlags |= DiskLocation::kStructured;
+ }
+ else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary)
+ {
+ EntryFlags |= DiskLocation::kCompressed;
+ }
+
+ m_BlockStore.WriteChunk(Value.Value.Data(), Value.Value.Size(), m_PayloadAlignment, [&](const BlockStoreLocation& BlockStoreLocation) {
+ DiskLocation Location(BlockStoreLocation, m_PayloadAlignment, EntryFlags);
+ m_SlogFile.Append({.Key = HashKey, .Location = Location});
+
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+ if (auto It = m_Index.find(HashKey); It != m_Index.end())
+ {
+ // TODO: should check if write is idempotent and bail out if it is?
+ // this would requiring comparing contents on disk unless we add a
+ // content hash to the index entry
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
+ m_Payloads[EntryIndex] = (BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash});
+ m_AccessTimes[EntryIndex] = GcClock::TickCount();
+ }
+ else
+ {
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(HashKey, EntryIndex);
+ }
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
// Disk-backed cache layer rooted at RootDir. Buckets are opened/created
// lazily on first access (see Get/Put) or eagerly via DiscoverBuckets().
ZenCacheDiskLayer::ZenCacheDiskLayer(const std::filesystem::path& RootDir) : m_RootDir(RootDir)
{
}

ZenCacheDiskLayer::~ZenCacheDiskLayer() = default;
+
+bool
+ZenCacheDiskLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ const auto BucketName = std::string(InBucket);
+ CacheBucket* Bucket = nullptr;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(BucketName);
+
+ if (It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ }
+
+ if (Bucket == nullptr)
+ {
+ // Bucket needs to be opened/created
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ else
+ {
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName));
+ Bucket = InsertResult.first->second.get();
+
+ std::filesystem::path BucketPath = m_RootDir;
+ BucketPath /= BucketName;
+
+ if (!Bucket->OpenOrCreate(BucketPath))
+ {
+ m_Buckets.erase(InsertResult.first);
+ return false;
+ }
+ }
+ }
+
+ ZEN_ASSERT(Bucket != nullptr);
+ return Bucket->Get(HashKey, OutValue);
+}
+
+void
+ZenCacheDiskLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ const auto BucketName = std::string(InBucket);
+ CacheBucket* Bucket = nullptr;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(BucketName);
+
+ if (It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ }
+
+ if (Bucket == nullptr)
+ {
+ // New bucket needs to be created
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ else
+ {
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName));
+ Bucket = InsertResult.first->second.get();
+
+ std::filesystem::path BucketPath = m_RootDir;
+ BucketPath /= BucketName;
+
+ try
+ {
+ if (!Bucket->OpenOrCreate(BucketPath))
+ {
+ ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir);
+ m_Buckets.erase(InsertResult.first);
+ return;
+ }
+ }
+ catch (const std::exception& Err)
+ {
+ ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what());
+ return;
+ }
+ }
+ }
+
+ ZEN_ASSERT(Bucket != nullptr);
+
+ Bucket->Put(HashKey, Value);
+}
+
+void
+ZenCacheDiskLayer::DiscoverBuckets()
+{
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_RootDir, DirectoryContent::IncludeDirsFlag, DirContent);
+
+ // Initialize buckets
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (const std::filesystem::path& BucketPath : DirContent.Directories)
+ {
+ const std::string BucketName = PathToUtf8(BucketPath.stem());
+ // New bucket needs to be created
+ if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end())
+ {
+ continue;
+ }
+
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName));
+ CacheBucket& Bucket = *InsertResult.first->second;
+
+ try
+ {
+ if (!Bucket.OpenOrCreate(BucketPath, /* AllowCreate */ false))
+ {
+ ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir);
+
+ m_Buckets.erase(InsertResult.first);
+ continue;
+ }
+ }
+ catch (const std::exception& Err)
+ {
+ ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what());
+ return;
+ }
+ ZEN_INFO("Discovered bucket '{}'", BucketName);
+ }
+}
+
+bool
+ZenCacheDiskLayer::DropBucket(std::string_view InBucket)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(std::string(InBucket));
+
+ if (It != m_Buckets.end())
+ {
+ CacheBucket& Bucket = *It->second;
+ m_DroppedBuckets.push_back(std::move(It->second));
+ m_Buckets.erase(It);
+
+ return Bucket.Drop();
+ }
+
+ // Make sure we remove the folder even if we don't know about the bucket
+ std::filesystem::path BucketPath = m_RootDir;
+ BucketPath /= std::string(InBucket);
+ return MoveAndDeleteDirectory(BucketPath);
+}
+
+bool
+ZenCacheDiskLayer::Drop()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ std::vector<std::unique_ptr<CacheBucket>> Buckets;
+ Buckets.reserve(m_Buckets.size());
+ while (!m_Buckets.empty())
+ {
+ const auto& It = m_Buckets.begin();
+ CacheBucket& Bucket = *It->second;
+ m_DroppedBuckets.push_back(std::move(It->second));
+ m_Buckets.erase(It->first);
+ if (!Bucket.Drop())
+ {
+ return false;
+ }
+ }
+ return MoveAndDeleteDirectory(m_RootDir);
+}
+
+void
+ZenCacheDiskLayer::Flush()
+{
+ std::vector<CacheBucket*> Buckets;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ Buckets.reserve(m_Buckets.size());
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket* Bucket = Kv.second.get();
+ Buckets.push_back(Bucket);
+ }
+ }
+
+ for (auto& Bucket : Buckets)
+ {
+ Bucket->Flush();
+ }
+}
+
+void
+ZenCacheDiskLayer::Scrub(ScrubContext& Ctx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket& Bucket = *Kv.second;
+ Bucket.Scrub(Ctx);
+ }
+}
+
+void
+ZenCacheDiskLayer::GatherReferences(GcContext& GcCtx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket& Bucket = *Kv.second;
+ Bucket.GatherReferences(GcCtx);
+ }
+}
+
+uint64_t
+ZenCacheDiskLayer::TotalSize() const
+{
+ uint64_t TotalSize{};
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ TotalSize += Kv.second->TotalSize();
+ }
+
+ return TotalSize;
+}
+
+ZenCacheDiskLayer::Info
+ZenCacheDiskLayer::GetInfo() const
+{
+ ZenCacheDiskLayer::Info Info = {.Config = {.RootDir = m_RootDir}, .TotalSize = TotalSize()};
+
+ RwLock::SharedLockScope _(m_Lock);
+ Info.BucketNames.reserve(m_Buckets.size());
+ for (auto& Kv : m_Buckets)
+ {
+ Info.BucketNames.push_back(Kv.first);
+ Info.EntryCount += Kv.second->EntryCount();
+ }
+ return Info;
+}
+
+std::optional<ZenCacheDiskLayer::BucketInfo>
+ZenCacheDiskLayer::GetBucketInfo(std::string_view Bucket) const
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end())
+ {
+ return ZenCacheDiskLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()};
+ }
+ return {};
+}
+
+CacheValueDetails::NamespaceDetails
+ZenCacheDiskLayer::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const
+{
+ RwLock::SharedLockScope _(m_Lock);
+ CacheValueDetails::NamespaceDetails Details;
+ if (BucketFilter.empty())
+ {
+ Details.Buckets.reserve(BucketFilter.empty() ? m_Buckets.size() : 1);
+ for (auto& Kv : m_Buckets)
+ {
+ Details.Buckets[Kv.first] = Kv.second->GetValueDetails(ValueFilter);
+ }
+ }
+ else if (auto It = m_Buckets.find(std::string(BucketFilter)); It != m_Buckets.end())
+ {
+ Details.Buckets[It->first] = It->second->GetValueDetails(ValueFilter);
+ }
+ return Details;
+}
+
+//////////////////////////// ZenCacheStore
+
// Namespace that DefaultNamespace lookups fall back to (see GetNamespace /
// FindNamespace); created eagerly by the ZenCacheStore constructor when it
// does not already exist on disk.
static constexpr std::string_view UE4DDCNamespaceName = "ue4.ddc";
+
+ZenCacheStore::ZenCacheStore(GcManager& Gc, const Configuration& Configuration) : m_Gc(Gc), m_Configuration(Configuration)
+{
+ CreateDirectories(m_Configuration.BasePath);
+
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_Configuration.BasePath, DirectoryContent::IncludeDirsFlag, DirContent);
+
+ std::vector<std::string> Namespaces;
+ for (const std::filesystem::path& DirPath : DirContent.Directories)
+ {
+ std::string DirName = PathToUtf8(DirPath.filename());
+ if (DirName.starts_with(NamespaceDiskPrefix))
+ {
+ Namespaces.push_back(DirName.substr(NamespaceDiskPrefix.length()));
+ continue;
+ }
+ }
+
+ ZEN_INFO("Found {} namespaces in '{}'", Namespaces.size(), m_Configuration.BasePath);
+
+ if (std::find(Namespaces.begin(), Namespaces.end(), UE4DDCNamespaceName) == Namespaces.end())
+ {
+ // default (unspecified) and ue4-ddc namespace points to the same namespace instance
+
+ std::filesystem::path DefaultNamespaceFolder =
+ m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, UE4DDCNamespaceName);
+ CreateDirectories(DefaultNamespaceFolder);
+ Namespaces.push_back(std::string(UE4DDCNamespaceName));
+ }
+
+ for (const std::string& NamespaceName : Namespaces)
+ {
+ m_Namespaces[NamespaceName] =
+ std::make_unique<ZenCacheNamespace>(Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, NamespaceName));
+ }
+}
+
ZenCacheStore::~ZenCacheStore()
{
    // Explicitly release all namespace instances before the remaining members
    // are destroyed.
    m_Namespaces.clear();
}
+
+bool
+ZenCacheStore::Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store)
+ {
+ return Store->Get(Bucket, HashKey, OutValue);
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString());
+
+ return false;
+}
+
+void
+ZenCacheStore::Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store)
+ {
+ return Store->Put(Bucket, HashKey, Value);
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString());
+}
+
+bool
+ZenCacheStore::DropBucket(std::string_view Namespace, std::string_view Bucket)
+{
+ if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store)
+ {
+ return Store->DropBucket(Bucket);
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropBucket, bucket '{}'", Namespace, Bucket);
+ return false;
+}
+
+bool
+ZenCacheStore::DropNamespace(std::string_view InNamespace)
+{
+ RwLock::SharedLockScope _(m_NamespacesLock);
+ if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end())
+ {
+ ZenCacheNamespace& Namespace = *It->second;
+ m_DroppedNamespaces.push_back(std::move(It->second));
+ m_Namespaces.erase(It);
+ return Namespace.Drop();
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropNamespace", InNamespace);
+ return false;
+}
+
+void
+ZenCacheStore::Flush()
+{
+ IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Flush(); });
+}
+
+void
+ZenCacheStore::Scrub(ScrubContext& Ctx)
+{
+ IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Scrub(Ctx); });
+}
+
+CacheValueDetails
+ZenCacheStore::GetValueDetails(const std::string_view NamespaceFilter,
+ const std::string_view BucketFilter,
+ const std::string_view ValueFilter) const
+{
+ CacheValueDetails Details;
+ if (NamespaceFilter.empty())
+ {
+ IterateNamespaces([&](std::string_view Namespace, ZenCacheNamespace& Store) {
+ Details.Namespaces[std::string(Namespace)] = Store.GetValueDetails(BucketFilter, ValueFilter);
+ });
+ }
+ else if (const ZenCacheNamespace* Store = FindNamespace(NamespaceFilter); Store != nullptr)
+ {
+ Details.Namespaces[std::string(NamespaceFilter)] = Store->GetValueDetails(BucketFilter, ValueFilter);
+ }
+ return Details;
+}
+
ZenCacheNamespace*
ZenCacheStore::GetNamespace(std::string_view Namespace)
{
    // Resolves a namespace by name, creating it on first use when the
    // configuration allows. Uses a double-checked pattern: an optimistic
    // lookup under the shared lock, then a re-check and insert under the
    // exclusive lock.
    RwLock::SharedLockScope _(m_NamespacesLock);
    if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end())
    {
        return It->second.get();
    }
    // The default namespace resolves to the ue4.ddc namespace instance.
    if (Namespace == DefaultNamespace)
    {
        if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end())
        {
            return It->second.get();
        }
    }
    _.ReleaseNow();

    if (!m_Configuration.AllowAutomaticCreationOfNamespaces)
    {
        return nullptr;
    }

    // NOTE(review): the re-check below only looks up the literal name, not the
    // DefaultNamespace -> ue4.ddc alias, so a DefaultNamespace request that
    // reaches this point creates a namespace under the default name itself --
    // confirm this is intended.
    RwLock::ExclusiveLockScope __(m_NamespacesLock);
    if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end())
    {
        return It->second.get();
    }

    auto NewNamespace = m_Namespaces.insert_or_assign(
        std::string(Namespace),
        std::make_unique<ZenCacheNamespace>(m_Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, Namespace)));
    return NewNamespace.first->second.get();
}
+
+const ZenCacheNamespace*
+ZenCacheStore::FindNamespace(std::string_view Namespace) const
+{
+ RwLock::SharedLockScope _(m_NamespacesLock);
+ if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+ if (Namespace == DefaultNamespace)
+ {
+ if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+ }
+ return nullptr;
+}
+
void
ZenCacheStore::IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const
{
    // Invokes Callback for every namespace, skipping any map entry keyed as
    // DefaultNamespace. Snapshots name/reference pairs under the shared lock
    // and runs the callbacks after releasing it, so callbacks may themselves
    // take locks (e.g. Flush, Scrub).
    //
    // NOTE(review): the references are used after the lock is released; this
    // relies on dropped namespaces being retained (see m_DroppedNamespaces in
    // DropNamespace) rather than destroyed while a snapshot is outstanding --
    // confirm no path frees a namespace concurrently.
    std::vector<std::pair<std::string, ZenCacheNamespace&>> Namespaces;
    {
        RwLock::SharedLockScope _(m_NamespacesLock);
        Namespaces.reserve(m_Namespaces.size());
        for (const auto& Entry : m_Namespaces)
        {
            if (Entry.first == DefaultNamespace)
            {
                continue;
            }
            Namespaces.push_back({Entry.first, *Entry.second});
        }
    }
    for (auto& Entry : Namespaces)
    {
        Callback(Entry.first, Entry.second);
    }
}
+
+GcStorageSize
+ZenCacheStore::StorageSize() const
+{
+ GcStorageSize Size;
+ IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) {
+ GcStorageSize StoreSize = Store.StorageSize();
+ Size.MemorySize += StoreSize.MemorySize;
+ Size.DiskSize += StoreSize.DiskSize;
+ });
+ return Size;
+}
+
+ZenCacheStore::Info
+ZenCacheStore::GetInfo() const
+{
+ ZenCacheStore::Info Info = {.Config = m_Configuration, .StorageSize = StorageSize()};
+
+ IterateNamespaces([&Info](std::string_view NamespaceName, ZenCacheNamespace& Namespace) {
+ Info.NamespaceNames.push_back(std::string(NamespaceName));
+ ZenCacheNamespace::Info NamespaceInfo = Namespace.GetInfo();
+ Info.DiskEntryCount += NamespaceInfo.DiskLayerInfo.EntryCount;
+ Info.MemoryEntryCount += NamespaceInfo.MemoryLayerInfo.EntryCount;
+ });
+
+ return Info;
+}
+
+std::optional<ZenCacheNamespace::Info>
+ZenCacheStore::GetNamespaceInfo(std::string_view NamespaceName)
+{
+ if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace)
+ {
+ return Namespace->GetInfo();
+ }
+ return {};
+}
+
+std::optional<ZenCacheNamespace::BucketInfo>
+ZenCacheStore::GetBucketInfo(std::string_view NamespaceName, std::string_view BucketName)
+{
+ if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace)
+ {
+ return Namespace->GetBucketInfo(BucketName);
+ }
+ return {};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+using namespace std::literals;
+
+namespace testutils {
+ IoHash CreateKey(size_t KeyValue) { return IoHash::HashBuffer(&KeyValue, sizeof(size_t)); }
+
+ IoBuffer CreateBinaryCacheValue(uint64_t Size)
+ {
+ static std::random_device rd;
+ static std::mt19937 g(rd());
+
+ std::vector<uint8_t> Values;
+ Values.resize(Size);
+ for (size_t Idx = 0; Idx < Size; ++Idx)
+ {
+ Values[Idx] = static_cast<uint8_t>(Idx);
+ }
+ std::shuffle(Values.begin(), Values.end(), g);
+
+ IoBuffer Buf(IoBuffer::Clone, Values.data(), Values.size());
+ Buf.SetContentType(ZenContentType::kBinary);
+ return Buf;
+ };
+
+} // namespace testutils
+
+TEST_CASE("z$.store")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ const int kIterationCount = 100;
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ const IoHash Key = IoHash::HashBuffer(&i, sizeof i);
+
+ CbObjectWriter Cbo;
+ Cbo << "hey" << i;
+ CbObject Obj = Cbo.Save();
+
+ ZenCacheValue Value;
+ Value.Value = Obj.GetBuffer().AsIoBuffer();
+ Value.Value.SetContentType(ZenContentType::kCbObject);
+
+ Zcs.Put("test_bucket"sv, Key, Value);
+ }
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ const IoHash Key = IoHash::HashBuffer(&i, sizeof i);
+
+ ZenCacheValue Value;
+ Zcs.Get("test_bucket"sv, Key, /* out */ Value);
+
+ REQUIRE(Value.Value);
+ CHECK(Value.Value.GetContentType() == ZenContentType::kCbObject);
+ CHECK_EQ(ValidateCompactBinary(Value.Value, CbValidateMode::All), CbValidateError::None);
+ CbObject Obj = LoadCompactBinaryObject(Value.Value);
+ CHECK_EQ(Obj["hey"].AsInt32(), i);
+ }
+}
+
TEST_CASE("z$.size")
{
    // Validates StorageSize() accounting: bytes written must be reflected in
    // the reported disk/memory sizes, survive reopening the namespace from
    // disk, and drop to zero once all buckets are dropped.
    const auto CreateCacheValue = [](size_t Size) -> CbObject {
        // Object wrapping a zero-filled binary blob of `Size` bytes.
        std::vector<uint8_t> Buf;
        Buf.resize(Size);

        CbObjectWriter Writer;
        Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
        return Writer.Save();
    };

    SUBCASE("mem/disklayer")
    {
        // Values just below the disk-layer threshold: the assertions below
        // expect them to be counted in both the memory and disk sizes.
        const size_t Count = 16;
        ScopedTemporaryDirectory TempDir;

        GcStorageSize CacheSize;

        {
            GcManager Gc;
            ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");

            CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() - 256);

            IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
            Buffer.SetContentType(ZenContentType::kCbObject);

            // Spread the writes over four buckets.
            // NOTE(review): the hash covers only sizeof(uint32_t) bytes of a
            // size_t key -- fine for distinct small keys on little-endian,
            // confirm if this is intentional.
            for (size_t Key = 0; Key < Count; ++Key)
            {
                const size_t Bucket = Key % 4;
                Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), ZenCacheValue{.Value = Buffer});
            }

            CacheSize = Zcs.StorageSize();
            CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize);
            CHECK_LE(CacheValue.GetSize() * Count, CacheSize.MemorySize);
        }

        {
            // Reopen from disk: the memory layer starts empty while the disk
            // contents remain.
            GcManager Gc;
            ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");

            const GcStorageSize SerializedSize = Zcs.StorageSize();
            CHECK_EQ(SerializedSize.MemorySize, 0);
            CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize);

            // Dropping every bucket must release all disk usage.
            for (size_t Bucket = 0; Bucket < 4; ++Bucket)
            {
                Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket));
            }
            CHECK_EQ(0, Zcs.StorageSize().DiskSize);
        }
    }

    SUBCASE("disklayer")
    {
        // Values above the disk-layer threshold: the assertions below expect
        // them to bypass the memory layer entirely.
        const size_t Count = 16;
        ScopedTemporaryDirectory TempDir;

        GcStorageSize CacheSize;

        {
            GcManager Gc;
            ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");

            CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() + 64);

            IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
            Buffer.SetContentType(ZenContentType::kCbObject);

            for (size_t Key = 0; Key < Count; ++Key)
            {
                const size_t Bucket = Key % 4;
                Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), {.Value = Buffer});
            }

            CacheSize = Zcs.StorageSize();
            CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize);
            CHECK_EQ(0, CacheSize.MemorySize);
        }

        {
            // Reopen and verify persisted size, then drop everything.
            GcManager Gc;
            ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");

            const GcStorageSize SerializedSize = Zcs.StorageSize();
            CHECK_EQ(SerializedSize.MemorySize, 0);
            CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize);

            for (size_t Bucket = 0; Bucket < 4; ++Bucket)
            {
                Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket));
            }
            CHECK_EQ(0, Zcs.StorageSize().DiskSize);
        }
    }
}
+
TEST_CASE("z$.gc")
{
    using namespace testutils;

    SUBCASE("gather references does NOT add references for expired cache entries")
    {
        ScopedTemporaryDirectory TempDir;
        // Attachment CIDs referenced by the cache record below.
        std::vector<IoHash> Cids{CreateKey(1), CreateKey(2), CreateKey(3)};

        // Runs a GC pass with an expiry horizon of `Time - MaxDuration` and
        // returns (via OutKeep) which of the given CIDs survive filtering.
        const auto CollectAndFilter = [](GcManager& Gc,
                                         GcClock::TimePoint Time,
                                         GcClock::Duration MaxDuration,
                                         std::span<const IoHash> Cids,
                                         std::vector<IoHash>& OutKeep) {
            GcContext GcCtx(Time - MaxDuration);
            Gc.CollectGarbage(GcCtx);
            OutKeep.clear();
            GcCtx.FilterCids(Cids, [&OutKeep](const IoHash& Hash) { OutKeep.push_back(Hash); });
        };

        {
            GcManager Gc;
            ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
            const auto Bucket = "teardrinker"sv;

            // Create a cache record
            const IoHash Key = CreateKey(42);
            CbObjectWriter Record;
            Record << "Key"sv
                   << "SomeRecord"sv;

            // Attach all CIDs to the record so GC sees them as referenced
            // while the record itself is alive.
            for (size_t Idx = 0; auto& Cid : Cids)
            {
                Record.AddBinaryAttachment(fmt::format("attachment-{}", Idx++), Cid);
            }

            IoBuffer Buffer = Record.Save().GetBuffer().AsIoBuffer();
            Buffer.SetContentType(ZenContentType::kCbObject);

            Zcs.Put(Bucket, Key, {.Value = Buffer});

            std::vector<IoHash> Keep;

            // Collect garbage with 1 hour max cache duration
            {
                CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep);
                CHECK_EQ(Cids.size(), Keep.size());
            }

            // Move forward in time
            {
                // The record is now past its expiry, so its attachments must
                // no longer be referenced.
                CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep);
                CHECK_EQ(0, Keep.size());
            }
        }

        // Expect timestamps to be serialized
        {
            // Reopen the namespace from disk and repeat: behavior must match
            // the in-memory run, proving access times were persisted.
            GcManager Gc;
            ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
            std::vector<IoHash> Keep;

            // Collect garbage with 1 hour max cache duration
            {
                CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep);
                CHECK_EQ(3, Keep.size());
            }

            // Move forward in time
            {
                CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep);
                CHECK_EQ(0, Keep.size());
            }
        }
    }

    SUBCASE("gc removes standalone values")
    {
        // Values of 128 KiB are written as standalone files; expired entries
        // must become unreachable and release their disk space.
        ScopedTemporaryDirectory TempDir;
        GcManager Gc;
        ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
        const auto Bucket = "fortysixandtwo"sv;
        const GcClock::TimePoint CurrentTime = GcClock::Now();

        std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)};

        for (const auto& Key : Keys)
        {
            IoBuffer Value = testutils::CreateBinaryCacheValue(128 << 10);
            Zcs.Put(Bucket, Key, {.Value = Value});
        }

        {
            // Horizon far in the past: nothing should be collected yet.
            GcContext GcCtx(CurrentTime - std::chrono::hours(46));

            Gc.CollectGarbage(GcCtx);

            for (const auto& Key : Keys)
            {
                ZenCacheValue CacheValue;
                const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
                CHECK(Exists);
            }
        }

        // Move forward in time and collect again
        {
            // Horizon past the write time: every entry must be gone and the
            // disk usage back to zero.
            GcContext GcCtx(CurrentTime + std::chrono::minutes(2));
            Gc.CollectGarbage(GcCtx);

            for (const auto& Key : Keys)
            {
                ZenCacheValue CacheValue;
                const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
                CHECK(!Exists);
            }

            CHECK_EQ(0, Zcs.StorageSize().DiskSize);
        }
    }

    SUBCASE("gc removes small objects")
    {
        // Small (128-byte) values live in the block store; they are only
        // collected when the GC context opts into small-object collection.
        ScopedTemporaryDirectory TempDir;
        GcManager Gc;
        ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
        const auto Bucket = "rightintwo"sv;

        std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)};

        for (const auto& Key : Keys)
        {
            IoBuffer Value = testutils::CreateBinaryCacheValue(128);
            Zcs.Put(Bucket, Key, {.Value = Value});
        }

        {
            // Horizon in the past: entries survive the pass.
            GcContext GcCtx(GcClock::Now() - std::chrono::hours(2));
            GcCtx.CollectSmallObjects(true);

            Gc.CollectGarbage(GcCtx);

            for (const auto& Key : Keys)
            {
                ZenCacheValue CacheValue;
                const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
                CHECK(Exists);
            }
        }

        // Move forward in time and collect again
        {
            // Horizon in the future: flush pending writes first so the block
            // store sees them, then expect everything collected.
            GcContext GcCtx(GcClock::Now() + std::chrono::minutes(2));
            GcCtx.CollectSmallObjects(true);

            Zcs.Flush();
            Gc.CollectGarbage(GcCtx);

            for (const auto& Key : Keys)
            {
                ZenCacheValue CacheValue;
                const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
                CHECK(!Exists);
            }

            CHECK_EQ(0, Zcs.StorageSize().DiskSize);
        }
    }
}
+
+TEST_CASE("z$.threadedinsert") // * doctest::skip(true))
+{
+ // for (uint32_t i = 0; i < 100; ++i)
+ {
+ ScopedTemporaryDirectory TempDir;
+
+ const uint64_t kChunkSize = 1048;
+ const int32_t kChunkCount = 8192;
+
+ struct Chunk
+ {
+ std::string Bucket;
+ IoBuffer Buffer;
+ };
+ std::unordered_map<IoHash, Chunk, IoHash::Hasher> Chunks;
+ Chunks.reserve(kChunkCount);
+
+ const std::string Bucket1 = "rightinone";
+ const std::string Bucket2 = "rightintwo";
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ while (true)
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ if (Chunks.contains(Hash))
+ {
+ continue;
+ }
+ Chunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk};
+ break;
+ }
+ while (true)
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ if (Chunks.contains(Hash))
+ {
+ continue;
+ }
+ Chunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk};
+ break;
+ }
+ }
+
+ CreateDirectories(TempDir.Path());
+
+ WorkerThreadPool ThreadPool(4);
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path());
+
+ {
+ std::atomic<size_t> WorkCompleted = 0;
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() {
+ Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer});
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+ }
+
+ const uint64_t TotalSize = Zcs.StorageSize().DiskSize;
+ CHECK_LE(kChunkSize * Chunks.size(), TotalSize);
+
+ {
+ std::atomic<size_t> WorkCompleted = 0;
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() {
+ std::string Bucket = Chunk.second.Bucket;
+ IoHash ChunkHash = Chunk.first;
+ ZenCacheValue CacheValue;
+
+ CHECK(Zcs.Get(Bucket, ChunkHash, CacheValue));
+ IoHash Hash = IoHash::HashBuffer(CacheValue.Value);
+ CHECK(ChunkHash == Hash);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+ }
+ std::unordered_map<IoHash, std::string, IoHash::Hasher> GcChunkHashes;
+ GcChunkHashes.reserve(Chunks.size());
+ for (const auto& Chunk : Chunks)
+ {
+ GcChunkHashes[Chunk.first] = Chunk.second.Bucket;
+ }
+ {
+ std::unordered_map<IoHash, Chunk, IoHash::Hasher> NewChunks;
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ NewChunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk};
+ }
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ NewChunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk};
+ }
+ }
+
+ std::atomic<size_t> WorkCompleted = 0;
+ std::atomic_uint32_t AddedChunkCount = 0;
+ for (const auto& Chunk : NewChunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk, &AddedChunkCount]() {
+ Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer});
+ AddedChunkCount.fetch_add(1);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() {
+ ZenCacheValue CacheValue;
+ if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue))
+ {
+ CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value));
+ }
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (AddedChunkCount.load() < NewChunks.size())
+ {
+ // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope
+ for (const auto& Chunk : NewChunks)
+ {
+ ZenCacheValue CacheValue;
+ if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue))
+ {
+ GcChunkHashes[Chunk.first] = Chunk.second.Bucket;
+ }
+ }
+ std::vector<IoHash> KeepHashes;
+ KeepHashes.reserve(GcChunkHashes.size());
+ for (const auto& Entry : GcChunkHashes)
+ {
+ KeepHashes.push_back(Entry.first);
+ }
+ size_t C = 0;
+ while (C < KeepHashes.size())
+ {
+ if (C % 155 == 0)
+ {
+ if (C < KeepHashes.size() - 1)
+ {
+ KeepHashes[C] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ if (C + 3 < KeepHashes.size() - 1)
+ {
+ KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ }
+ C++;
+ }
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ GcCtx.AddRetainedCids(KeepHashes);
+ Zcs.CollectGarbage(GcCtx);
+ const HashKeySet& Deleted = GcCtx.DeletedCids();
+ Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); });
+ }
+
+ while (WorkCompleted < NewChunks.size() + Chunks.size())
+ {
+ Sleep(1);
+ }
+
+ {
+ // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope
+ for (const auto& Chunk : NewChunks)
+ {
+ ZenCacheValue CacheValue;
+ if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue))
+ {
+ GcChunkHashes[Chunk.first] = Chunk.second.Bucket;
+ }
+ }
+ std::vector<IoHash> KeepHashes;
+ KeepHashes.reserve(GcChunkHashes.size());
+ for (const auto& Entry : GcChunkHashes)
+ {
+ KeepHashes.push_back(Entry.first);
+ }
+ size_t C = 0;
+ while (C < KeepHashes.size())
+ {
+ if (C % 155 == 0)
+ {
+ if (C < KeepHashes.size() - 1)
+ {
+ KeepHashes[C] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ if (C + 3 < KeepHashes.size() - 1)
+ {
+ KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ }
+ C++;
+ }
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ GcCtx.AddRetainedCids(KeepHashes);
+ Zcs.CollectGarbage(GcCtx);
+ const HashKeySet& Deleted = GcCtx.DeletedCids();
+ Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); });
+ }
+ }
+ {
+ std::atomic<size_t> WorkCompleted = 0;
+ for (const auto& Chunk : GcChunkHashes)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() {
+ ZenCacheValue CacheValue;
+ CHECK(Zcs.Get(Chunk.second, Chunk.first, CacheValue));
+ CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value));
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < GcChunkHashes.size())
+ {
+ Sleep(1);
+ }
+ }
+ }
+}
+
+TEST_CASE("z$.namespaces")
+{
+ using namespace testutils;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ ScopedTemporaryDirectory TempDir;
+ CreateDirectories(TempDir.Path());
+
+ IoHash Key1;
+ IoHash Key2;
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = false});
+ const auto Bucket = "teardrinker"sv;
+ const auto CustomNamespace = "mynamespace"sv;
+
+ // Create a cache record
+ Key1 = CreateKey(42);
+ CbObject CacheValue = CreateCacheValue(4096);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ ZenCacheValue PutValue = {.Value = Buffer};
+ Zcs.Put(ZenCacheStore::DefaultNamespace, Bucket, Key1, PutValue);
+
+ ZenCacheValue GetValue;
+ CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue));
+ CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue));
+
+ // This should just be dropped as we don't allow creating of namespaces on the fly
+ Zcs.Put(CustomNamespace, Bucket, Key1, PutValue);
+ CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue));
+ }
+
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true});
+ const auto Bucket = "teardrinker"sv;
+ const auto CustomNamespace = "mynamespace"sv;
+
+ Key2 = CreateKey(43);
+ CbObject CacheValue2 = CreateCacheValue(4096);
+
+ IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer();
+ Buffer2.SetContentType(ZenContentType::kCbObject);
+ ZenCacheValue PutValue2 = {.Value = Buffer2};
+ Zcs.Put(CustomNamespace, Bucket, Key2, PutValue2);
+
+ ZenCacheValue GetValue;
+ CHECK(!Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key2, GetValue));
+ CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue));
+ CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue));
+ CHECK(Zcs.Get(CustomNamespace, Bucket, Key2, GetValue));
+ }
+}
+
+TEST_CASE("z$.drop.bucket")
+{
+ using namespace testutils;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ ScopedTemporaryDirectory TempDir;
+ CreateDirectories(TempDir.Path());
+
+ IoHash Key1;
+ IoHash Key2;
+
+ auto PutValue =
+ [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) {
+ // Create a cache record
+ IoHash Key = CreateKey(KeyIndex);
+ CbObject CacheValue = CreateCacheValue(Size);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ ZenCacheValue PutValue = {.Value = Buffer};
+ Zcs.Put(Namespace, Bucket, Key, PutValue);
+ return Key;
+ };
+ auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) {
+ ZenCacheValue GetValue;
+ Zcs.Get(Namespace, Bucket, Key, GetValue);
+ return GetValue;
+ };
+ WorkerThreadPool Workers(1);
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true});
+ const auto Bucket = "teardrinker"sv;
+ const auto Namespace = "mynamespace"sv;
+
+ Key1 = PutValue(Zcs, Namespace, Bucket, 42, 4096);
+ Key2 = PutValue(Zcs, Namespace, Bucket, 43, 2048);
+
+ ZenCacheValue Value1 = GetValue(Zcs, Namespace, Bucket, Key1);
+ CHECK(Value1.Value);
+
+ std::atomic_bool WorkComplete = false;
+ Workers.ScheduleWork([&]() {
+ zen::Sleep(100);
+ Value1.Value = IoBuffer{};
+ WorkComplete = true;
+ });
+ // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket
+ // Our DropBucket execution blocks any incoming request from completing until we are done with the drop
+ CHECK(Zcs.DropBucket(Namespace, Bucket));
+ while (!WorkComplete)
+ {
+ zen::Sleep(1);
+ }
+
+ // Entire bucket should be dropped, but doing a request should will re-create the namespace but it must still be empty
+ Value1 = GetValue(Zcs, Namespace, Bucket, Key1);
+ CHECK(!Value1.Value);
+ ZenCacheValue Value2 = GetValue(Zcs, Namespace, Bucket, Key2);
+ CHECK(!Value2.Value);
+ }
+}
+
+TEST_CASE("z$.drop.namespace")
+{
+ using namespace testutils;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ ScopedTemporaryDirectory TempDir;
+ CreateDirectories(TempDir.Path());
+
+ auto PutValue =
+ [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) {
+ // Create a cache record
+ IoHash Key = CreateKey(KeyIndex);
+ CbObject CacheValue = CreateCacheValue(Size);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ ZenCacheValue PutValue = {.Value = Buffer};
+ Zcs.Put(Namespace, Bucket, Key, PutValue);
+ return Key;
+ };
+ auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) {
+ ZenCacheValue GetValue;
+ Zcs.Get(Namespace, Bucket, Key, GetValue);
+ return GetValue;
+ };
+ WorkerThreadPool Workers(1);
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true});
+ const auto Bucket1 = "teardrinker1"sv;
+ const auto Bucket2 = "teardrinker2"sv;
+ const auto Namespace1 = "mynamespace1"sv;
+ const auto Namespace2 = "mynamespace2"sv;
+
+ IoHash Key1 = PutValue(Zcs, Namespace1, Bucket1, 42, 4096);
+ IoHash Key2 = PutValue(Zcs, Namespace1, Bucket2, 43, 2048);
+ IoHash Key3 = PutValue(Zcs, Namespace2, Bucket1, 44, 4096);
+ IoHash Key4 = PutValue(Zcs, Namespace2, Bucket2, 45, 2048);
+
+ ZenCacheValue Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1);
+ CHECK(Value1.Value);
+ ZenCacheValue Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2);
+ CHECK(Value2.Value);
+ ZenCacheValue Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3);
+ CHECK(Value3.Value);
+ ZenCacheValue Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4);
+ CHECK(Value4.Value);
+
+ std::atomic_bool WorkComplete = false;
+ Workers.ScheduleWork([&]() {
+ zen::Sleep(100);
+ Value1.Value = IoBuffer{};
+ Value2.Value = IoBuffer{};
+ Value3.Value = IoBuffer{};
+ Value4.Value = IoBuffer{};
+ WorkComplete = true;
+ });
+ // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket
+ // Our DropBucket execution blocks any incoming request from completing until we are done with the drop
+ CHECK(Zcs.DropNamespace(Namespace1));
+ while (!WorkComplete)
+ {
+ zen::Sleep(1);
+ }
+
+ // Entire namespace should be dropped, but doing a request should will re-create the namespace but it must still be empty
+ Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1);
+ CHECK(!Value1.Value);
+ Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2);
+ CHECK(!Value2.Value);
+ Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3);
+ CHECK(Value3.Value);
+ Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4);
+ CHECK(Value4.Value);
+ }
+}
+
+TEST_CASE("z$.blocked.disklayer.put")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ GcStorageSize CacheSize;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size, Size & 0xff);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ CbObject CacheValue = CreateCacheValue(64 * 1024 + 64);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ size_t Key = Buffer.Size();
+ IoHash HashKey = IoHash::HashBuffer(&Key, sizeof(uint32_t));
+ Zcs.Put("test_bucket", HashKey, {.Value = Buffer});
+
+ ZenCacheValue BufferGet;
+ CHECK(Zcs.Get("test_bucket", HashKey, BufferGet));
+
+ CbObject CacheValue2 = CreateCacheValue(64 * 1024 + 64 + 1);
+ IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer();
+ Buffer2.SetContentType(ZenContentType::kCbObject);
+
+ // We should be able to overwrite even if the file is open for read
+ Zcs.Put("test_bucket", HashKey, {.Value = Buffer2});
+
+ MemoryView OldView = BufferGet.Value.GetView();
+
+ ZenCacheValue BufferGet2;
+ CHECK(Zcs.Get("test_bucket", HashKey, BufferGet2));
+ MemoryView NewView = BufferGet2.Value.GetView();
+
+ // Make sure file openend for read before we wrote it still have old data
+ CHECK(OldView.GetSize() == Buffer.GetSize());
+ CHECK(memcmp(OldView.GetData(), Buffer.GetData(), OldView.GetSize()) == 0);
+
+ // Make sure we get the new data when reading after we write new data
+ CHECK(NewView.GetSize() == Buffer2.GetSize());
+ CHECK(memcmp(NewView.GetData(), Buffer2.GetData(), NewView.GetSize()) == 0);
+}
+
+TEST_CASE("z$.scrub")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ using namespace testutils;
+
+ struct CacheRecord
+ {
+ IoBuffer Record;
+ std::vector<CompressedBuffer> Attachments;
+ };
+
+ auto CreateCacheRecord = [](bool Structured, std::string_view Bucket, const IoHash& Key, const std::vector<size_t>& AttachmentSizes) {
+ CacheRecord Result;
+ if (Structured)
+ {
+ Result.Attachments.resize(AttachmentSizes.size());
+ CbObjectWriter Record;
+ Record.BeginObject("Key"sv);
+ {
+ Record << "Bucket"sv << Bucket;
+ Record << "Hash"sv << Key;
+ }
+ Record.EndObject();
+ for (size_t Index = 0; Index < AttachmentSizes.size(); Index++)
+ {
+ IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSizes[Index]);
+ CompressedBuffer CompressedAttachmentData = CompressedBuffer::Compress(SharedBuffer(AttachmentData));
+ Record.AddBinaryAttachment(fmt::format("attachment-{}", Index), CompressedAttachmentData.DecodeRawHash());
+ Result.Attachments[Index] = CompressedAttachmentData;
+ }
+ Result.Record = Record.Save().GetBuffer().AsIoBuffer();
+ Result.Record.SetContentType(ZenContentType::kCbObject);
+ }
+ else
+ {
+ std::string RecordData = fmt::format("{}:{}", Bucket, Key.ToHexString());
+ size_t TotalSize = RecordData.length() + 1;
+ for (size_t AttachmentSize : AttachmentSizes)
+ {
+ TotalSize += AttachmentSize;
+ }
+ Result.Record = IoBuffer(TotalSize);
+ char* DataPtr = (char*)Result.Record.MutableData();
+ memcpy(DataPtr, RecordData.c_str(), RecordData.length() + 1);
+ DataPtr += RecordData.length() + 1;
+ for (size_t AttachmentSize : AttachmentSizes)
+ {
+ IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSize);
+ memcpy(DataPtr, AttachmentData.GetData(), AttachmentData.GetSize());
+ DataPtr += AttachmentData.GetSize();
+ }
+ }
+ return Result;
+ };
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+ CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ CidStore.Initialize(CidConfig);
+
+ auto CreateRecords =
+ [&](bool IsStructured, std::string_view BucketName, const std::vector<IoHash>& Cids, const std::vector<size_t>& AttachmentSizes) {
+ for (const IoHash& Cid : Cids)
+ {
+ CacheRecord Record = CreateCacheRecord(IsStructured, BucketName, Cid, AttachmentSizes);
+ Zcs.Put("mybucket", Cid, {.Value = Record.Record});
+ for (const CompressedBuffer& Attachment : Record.Attachments)
+ {
+ CidStore.AddChunk(Attachment.GetCompressed().Flatten().AsIoBuffer(), Attachment.DecodeRawHash());
+ }
+ }
+ };
+
+ std::vector<size_t> AttachmentSizes = {16, 1000, 2000, 4000, 8000, 64000, 80000};
+
+ std::vector<IoHash> UnstructuredCids{CreateKey(4), CreateKey(5), CreateKey(6)};
+ CreateRecords(false, "mybucket"sv, UnstructuredCids, AttachmentSizes);
+
+ std::vector<IoHash> StructuredCids{CreateKey(1), CreateKey(2), CreateKey(3)};
+ CreateRecords(true, "mybucket"sv, StructuredCids, AttachmentSizes);
+
+ ScrubContext ScrubCtx;
+ Zcs.Scrub(ScrubCtx);
+ CidStore.Scrub(ScrubCtx);
+ CHECK(ScrubCtx.ScrubbedChunks() == (StructuredCids.size() + StructuredCids.size() * AttachmentSizes.size()) + UnstructuredCids.size());
+ CHECK(ScrubCtx.BadCids().GetSize() == 0);
+}
+
+#endif
+
+void
+z$_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcachestore.h b/src/zenserver/cache/structuredcachestore.h
new file mode 100644
index 000000000..3fb4f035d
--- /dev/null
+++ b/src/zenserver/cache/structuredcachestore.h
@@ -0,0 +1,535 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+#include <zenstore/blockstore.h>
+#include <zenstore/caslog.h>
+#include <zenstore/gc.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <atomic>
+#include <compare>
+#include <filesystem>
+#include <unordered_map>
+
+#define ZEN_USE_CACHE_TRACKER 0
+
+namespace zen {
+
+class PathBuilderBase;
+class GcManager;
+class ZenCacheTracker;
+class ScrubContext;
+
+/******************************************************************************
+
+ /$$$$$$$$ /$$$$$$ /$$
+ |_____ $$ /$$__ $$ | $$
+ /$$/ /$$$$$$ /$$$$$$$ | $$ \__/ /$$$$$$ /$$$$$$| $$$$$$$ /$$$$$$
+ /$$/ /$$__ $| $$__ $$ | $$ |____ $$/$$_____| $$__ $$/$$__ $$
+ /$$/ | $$$$$$$| $$ \ $$ | $$ /$$$$$$| $$ | $$ \ $| $$$$$$$$
+ /$$/ | $$_____| $$ | $$ | $$ $$/$$__ $| $$ | $$ | $| $$_____/
+ /$$$$$$$| $$$$$$| $$ | $$ | $$$$$$| $$$$$$| $$$$$$| $$ | $| $$$$$$$
+ |________/\_______|__/ |__/ \______/ \_______/\_______|__/ |__/\_______/
+
+ Cache store for UE5. Restricts keys to "{bucket}/{hash}" pairs where the hash
+ is 40 (hex) chars in size. Values may be opaque blobs or structured objects
+ which can in turn contain references to other objects (or blobs).
+
+******************************************************************************/
+
+namespace access_tracking {
+
+ struct KeyAccessTime
+ {
+ IoHash Key;
+ GcClock::Tick LastAccess{};
+ };
+
+ struct AccessTimes
+ {
+ std::unordered_map<std::string, std::vector<KeyAccessTime>> Buckets;
+ };
+}; // namespace access_tracking
+
+struct ZenCacheValue
+{
+ IoBuffer Value;
+ uint64_t RawSize = 0;
+ IoHash RawHash = IoHash::Zero;
+};
+
+struct CacheValueDetails
+{
+ struct ValueDetails
+ {
+ uint64_t Size;
+ uint64_t RawSize;
+ IoHash RawHash;
+ GcClock::Tick LastAccess{};
+ std::vector<IoHash> Attachments;
+ ZenContentType ContentType;
+ };
+
+ struct BucketDetails
+ {
+ std::unordered_map<IoHash, ValueDetails, IoHash::Hasher> Values;
+ };
+
+ struct NamespaceDetails
+ {
+ std::unordered_map<std::string, BucketDetails> Buckets;
+ };
+
+ std::unordered_map<std::string, NamespaceDetails> Namespaces;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+#pragma pack(push)
+#pragma pack(1)
+
+struct DiskLocation
+{
+ inline DiskLocation() = default;
+
+ inline DiskLocation(uint64_t ValueSize, uint8_t Flags) : Flags(Flags | kStandaloneFile) { Location.StandaloneSize = ValueSize; }
+
+ inline DiskLocation(const BlockStoreLocation& Location, uint64_t PayloadAlignment, uint8_t Flags) : Flags(Flags & ~kStandaloneFile)
+ {
+ this->Location.BlockLocation = BlockStoreDiskLocation(Location, PayloadAlignment);
+ }
+
+ inline BlockStoreLocation GetBlockLocation(uint64_t PayloadAlignment) const
+ {
+ ZEN_ASSERT(!(Flags & kStandaloneFile));
+ return Location.BlockLocation.Get(PayloadAlignment);
+ }
+
+ inline uint64_t Size() const { return (Flags & kStandaloneFile) ? Location.StandaloneSize : Location.BlockLocation.GetSize(); }
+ inline uint8_t IsFlagSet(uint64_t Flag) const { return Flags & Flag; }
+ inline uint8_t GetFlags() const { return Flags; }
+ inline ZenContentType GetContentType() const
+ {
+ ZenContentType ContentType = ZenContentType::kBinary;
+
+ if (IsFlagSet(kStructured))
+ {
+ ContentType = ZenContentType::kCbObject;
+ }
+
+ if (IsFlagSet(kCompressed))
+ {
+ ContentType = ZenContentType::kCompressedBinary;
+ }
+
+ return ContentType;
+ }
+
+ union
+ {
+ BlockStoreDiskLocation BlockLocation; // 10 bytes
+ uint64_t StandaloneSize = 0; // 8 bytes
+ } Location;
+
+ static const uint8_t kStandaloneFile = 0x80u; // Stored as a separate file
+ static const uint8_t kStructured = 0x40u; // Serialized as compact binary
+ static const uint8_t kTombStone = 0x20u; // Represents a deleted key/value
+ static const uint8_t kCompressed = 0x10u; // Stored in compressed buffer format
+
+ uint8_t Flags = 0;
+ uint8_t Reserved = 0;
+};
+
+struct DiskIndexEntry
+{
+ IoHash Key; // 20 bytes
+ DiskLocation Location; // 12 bytes
+};
+
+#pragma pack(pop)
+
+static_assert(sizeof(DiskIndexEntry) == 32);
+
+// This store the access time as seconds since epoch internally in a 32-bit value giving is a range of 136 years since epoch
+struct AccessTime
+{
+ explicit AccessTime(GcClock::Tick Tick) noexcept : SecondsSinceEpoch(ToSeconds(Tick)) {}
+ AccessTime& operator=(GcClock::Tick Tick) noexcept
+ {
+ SecondsSinceEpoch.store(ToSeconds(Tick), std::memory_order_relaxed);
+ return *this;
+ }
+ operator GcClock::Tick() const noexcept
+ {
+ return std::chrono::duration_cast<GcClock::Duration>(std::chrono::seconds(SecondsSinceEpoch.load(std::memory_order_relaxed)))
+ .count();
+ }
+
+ AccessTime(AccessTime&& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {}
+ AccessTime(const AccessTime& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {}
+ AccessTime& operator=(AccessTime&& Rhs) noexcept
+ {
+ SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed);
+ return *this;
+ }
+ AccessTime& operator=(const AccessTime& Rhs) noexcept
+ {
+ SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed);
+ return *this;
+ }
+
+private:
+ static uint32_t ToSeconds(GcClock::Tick Tick)
+ {
+ return gsl::narrow<uint32_t>(std::chrono::duration_cast<std::chrono::seconds>(GcClock::Duration(Tick)).count());
+ }
+ std::atomic_uint32_t SecondsSinceEpoch;
+};
+
+/** In-memory cache storage
+
+ Intended for small values which are frequently accessed
+
+ This should have a better memory management policy to maintain reasonable
+ footprint.
+ */
+class ZenCacheMemoryLayer
+{
+public:
+ struct Configuration
+ {
+ uint64_t TargetFootprintBytes = 16 * 1024 * 1024;
+ uint64_t ScavengeThreshold = 4 * 1024 * 1024;
+ };
+
+ struct BucketInfo
+ {
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ struct Info
+ {
+ Configuration Config;
+ std::vector<std::string> BucketNames;
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ ZenCacheMemoryLayer();
+ ~ZenCacheMemoryLayer();
+
+ bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
+ void Drop();
+ bool DropBucket(std::string_view Bucket);
+ void Scrub(ScrubContext& Ctx);
+ void GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes);
+ void Reset();
+ uint64_t TotalSize() const;
+
+ Info GetInfo() const;
+ std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const;
+
+ const Configuration& GetConfiguration() const { return m_Configuration; }
+ void SetConfiguration(const Configuration& NewConfig) { m_Configuration = NewConfig; }
+
+private:
+ struct CacheBucket
+ {
+#pragma pack(push)
+#pragma pack(1)
+ struct BucketPayload
+ {
+ IoBuffer Payload; // 8
+ uint32_t RawSize; // 4
+ IoHash RawHash; // 20
+ };
+#pragma pack(pop)
+ static_assert(sizeof(BucketPayload) == 32u);
+ static_assert(sizeof(AccessTime) == 4u);
+
+ mutable RwLock m_BucketLock;
+ std::vector<AccessTime> m_AccessTimes;
+ std::vector<BucketPayload> m_Payloads;
+ tsl::robin_map<IoHash, uint32_t> m_CacheMap;
+
+ std::atomic_uint64_t m_TotalSize{};
+
+ bool Get(const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(const IoHash& HashKey, const ZenCacheValue& Value);
+ void Drop();
+ void Scrub(ScrubContext& Ctx);
+ void GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes);
+ inline uint64_t TotalSize() const { return m_TotalSize; }
+ uint64_t EntryCount() const;
+ };
+
+ mutable RwLock m_Lock;
+ std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets;
+ std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets;
+ Configuration m_Configuration;
+
+ ZenCacheMemoryLayer(const ZenCacheMemoryLayer&) = delete;
+ ZenCacheMemoryLayer& operator=(const ZenCacheMemoryLayer&) = delete;
+};
+
+class ZenCacheDiskLayer
+{
+public:
+ struct Configuration
+ {
+ std::filesystem::path RootDir;
+ };
+
+ struct BucketInfo
+ {
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ struct Info
+ {
+ Configuration Config;
+ std::vector<std::string> BucketNames;
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ explicit ZenCacheDiskLayer(const std::filesystem::path& RootDir);
+ ~ZenCacheDiskLayer();
+
+ bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
+ bool Drop();
+ bool DropBucket(std::string_view Bucket);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ void GatherReferences(GcContext& GcCtx);
+ void CollectGarbage(GcContext& GcCtx);
+ void UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes);
+
+ void DiscoverBuckets();
+ uint64_t TotalSize() const;
+
+ Info GetInfo() const;
+ std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const;
+
+ CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const;
+
+private:
+ /** A cache bucket manages a single directory containing
+ metadata and data for that bucket
+ */
+ struct CacheBucket
+ {
+ CacheBucket(std::string BucketName);
+ ~CacheBucket();
+
+ bool OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate = true);
+ bool Get(const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(const IoHash& HashKey, const ZenCacheValue& Value);
+ bool Drop();
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ void GatherReferences(GcContext& GcCtx);
+ void CollectGarbage(GcContext& GcCtx);
+ void UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes);
+
+ inline uint64_t TotalSize() const { return m_TotalStandaloneSize.load(std::memory_order::relaxed) + m_BlockStore.TotalSize(); }
+ uint64_t EntryCount() const;
+
+ CacheValueDetails::BucketDetails GetValueDetails(const std::string_view ValueFilter) const;
+
+ private:
+ const uint64_t MaxBlockSize = 1ull << 30;
+ uint64_t m_PayloadAlignment = 1ull << 4;
+
+ std::string m_BucketName;
+ std::filesystem::path m_BucketDir;
+ std::filesystem::path m_BlocksBasePath;
+ BlockStore m_BlockStore;
+ Oid m_BucketId;
+ uint64_t m_LargeObjectThreshold = 128 * 1024;
+
+ // These files are used to manage storage of small objects for this bucket
+
+ TCasLogFile<DiskIndexEntry> m_SlogFile;
+ uint64_t m_LogFlushPosition = 0;
+
+#pragma pack(push)
+#pragma pack(1)
+ struct BucketPayload
+ {
+ DiskLocation Location; // 12
+ uint64_t RawSize; // 8
+ IoHash RawHash; // 20
+ };
+#pragma pack(pop)
+ static_assert(sizeof(BucketPayload) == 40u);
+ static_assert(sizeof(AccessTime) == 4u);
+
+ using IndexMap = tsl::robin_map<IoHash, size_t, IoHash::Hasher>;
+
+ mutable RwLock m_IndexLock;
+ std::vector<AccessTime> m_AccessTimes;
+ std::vector<BucketPayload> m_Payloads;
+ IndexMap m_Index;
+
+ std::atomic_uint64_t m_TotalStandaloneSize{};
+
+ void BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const;
+ void PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value);
+ IoBuffer GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const;
+ void PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value);
+ IoBuffer GetInlineCacheValue(const DiskLocation& Loc) const;
+ void MakeIndexSnapshot();
+ uint64_t ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion);
+ uint64_t ReadLog(const std::filesystem::path& LogPath, uint64_t LogPosition);
+ void OpenLog(const bool IsNew);
+ void SaveManifest();
+ CacheValueDetails::ValueDetails GetValueDetails(const IoHash& Key, size_t Index) const;
+ // These locks are here to avoid contention on file creation, therefore it's sufficient
+ // that we take the same lock for the same hash
+ //
+ // These locks are small and should really be spaced out so they don't share cache lines,
+ // but we don't currently access them at particularly high frequency so it should not be
+ // an issue in practice
+
+ mutable RwLock m_ShardedLocks[256];
+ inline RwLock& LockForHash(const IoHash& Hash) const { return m_ShardedLocks[Hash.Hash[19]]; }
+ };
+
+ std::filesystem::path m_RootDir;
+ mutable RwLock m_Lock;
+ std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets; // TODO: make this case insensitive
+ std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets;
+
+ ZenCacheDiskLayer(const ZenCacheDiskLayer&) = delete;
+ ZenCacheDiskLayer& operator=(const ZenCacheDiskLayer&) = delete;
+};
+
/**
 * A single cache namespace backed by two storage layers: an in-memory layer
 * (m_MemLayer) and an on-disk layer (m_DiskLayer). Participates in garbage
 * collection both as a storage owner (GcStorage) and as a reference
 * contributor (GcContributor).
 */
class ZenCacheNamespace final : public RefCounted, public GcStorage, public GcContributor
{
public:
    struct Configuration
    {
        std::filesystem::path RootDir;   // On-disk root directory for this namespace
        uint64_t DiskLayerThreshold = 0; // Size threshold steering values between layers (see DiskLayerThreshold())
    };
    // Per-bucket information aggregated from both storage layers
    struct BucketInfo
    {
        ZenCacheDiskLayer::BucketInfo DiskLayerInfo;
        ZenCacheMemoryLayer::BucketInfo MemoryLayerInfo;
    };
    // Namespace-wide information aggregated from both storage layers
    struct Info
    {
        Configuration Config;
        std::vector<std::string> BucketNames;
        ZenCacheDiskLayer::Info DiskLayerInfo;
        ZenCacheMemoryLayer::Info MemoryLayerInfo;
    };

    ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir);
    ~ZenCacheNamespace();

    // Key/value access within a named bucket of this namespace
    bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
    void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
    bool Drop();
    bool DropBucket(std::string_view Bucket);
    void Flush();
    void Scrub(ScrubContext& Ctx);
    uint64_t DiskLayerThreshold() const { return m_DiskLayerSizeThreshold; }
    // GcContributor / GcStorage interface
    virtual void GatherReferences(GcContext& GcCtx) override;
    virtual void CollectGarbage(GcContext& GcCtx) override;
    virtual GcStorageSize StorageSize() const override;
    Info GetInfo() const;
    std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const;

    // Diagnostic drill-down; filters narrow the buckets/values inspected
    CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const;

private:
    std::filesystem::path m_RootDir;
    ZenCacheMemoryLayer m_MemLayer;
    ZenCacheDiskLayer m_DiskLayer;
    uint64_t m_DiskLayerSizeThreshold = 1 * 1024; // Default threshold (1 KiB)
    uint64_t m_LastScrubTime = 0;

#if ZEN_USE_CACHE_TRACKER
    // Optional access tracking, compile-time gated
    std::unique_ptr<ZenCacheTracker> m_AccessTracker;
#endif

    // Non-copyable
    ZenCacheNamespace(const ZenCacheNamespace&) = delete;
    ZenCacheNamespace& operator=(const ZenCacheNamespace&) = delete;
};
+
/**
 * Top-level cache store. Owns all cache namespaces and routes get/put/drop
 * requests to the matching ZenCacheNamespace.
 */
class ZenCacheStore final
{
public:
    static constexpr std::string_view DefaultNamespace =
        "!default!"; // This is intentionally not a valid namespace name and will only be used for mapping when no namespace is given
    static constexpr std::string_view NamespaceDiskPrefix = "ns_"; // On-disk directory-name prefix for namespace folders

    struct Configuration
    {
        std::filesystem::path BasePath; // Root directory under which namespaces live
        bool AllowAutomaticCreationOfNamespaces = false; // Presumably creates unknown namespaces on demand - TODO confirm
    };

    // Aggregated store-wide statistics
    struct Info
    {
        Configuration Config;
        std::vector<std::string> NamespaceNames;
        uint64_t DiskEntryCount = 0;
        uint64_t MemoryEntryCount = 0;
        GcStorageSize StorageSize;
    };

    ZenCacheStore(GcManager& Gc, const Configuration& Configuration);
    ~ZenCacheStore();

    // Key/value access, addressed by (namespace, bucket, hash)
    bool Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
    void Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
    bool DropBucket(std::string_view Namespace, std::string_view Bucket);
    bool DropNamespace(std::string_view Namespace);
    void Flush();
    void Scrub(ScrubContext& Ctx);

    // Diagnostic drill-down; filters narrow namespaces/buckets/values inspected
    CacheValueDetails GetValueDetails(const std::string_view NamespaceFilter,
                                      const std::string_view BucketFilter,
                                      const std::string_view ValueFilter) const;

    GcStorageSize StorageSize() const;
    // const Configuration& GetConfiguration() const { return m_Configuration; }

    Info GetInfo() const;
    std::optional<ZenCacheNamespace::Info> GetNamespaceInfo(std::string_view Namespace);
    std::optional<ZenCacheNamespace::BucketInfo> GetBucketInfo(std::string_view Namespace, std::string_view Bucket);

private:
    const ZenCacheNamespace* FindNamespace(std::string_view Namespace) const;
    ZenCacheNamespace* GetNamespace(std::string_view Namespace);
    void IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const;

    typedef std::unordered_map<std::string, std::unique_ptr<ZenCacheNamespace>> NamespaceMap;

    mutable RwLock m_NamespacesLock; // Guards m_Namespaces / m_DroppedNamespaces
    NamespaceMap m_Namespaces;
    // Dropped namespaces are retained here rather than destroyed immediately
    // (presumably until in-flight use has drained - TODO confirm)
    std::vector<std::unique_ptr<ZenCacheNamespace>> m_DroppedNamespaces;

    GcManager& m_Gc;
    Configuration m_Configuration;
};
+
// Presumably a force-link anchor: referenced from another translation unit so
// the linker keeps this object file (and its self-registering globals) - TODO confirm
void z$_forcelink();
+
+} // namespace zen
diff --git a/src/zenserver/cidstore.cpp b/src/zenserver/cidstore.cpp
new file mode 100644
index 000000000..bce4f1dfb
--- /dev/null
+++ b/src/zenserver/cidstore.cpp
@@ -0,0 +1,124 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "cidstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zenstore/cidstore.h>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
// Registers the per-CID route. A CID is addressed by the 40-hex-digit string
// form of its IoHash:
//   GET/HEAD /cid/{cid} - look up a chunk by content id
//   PUT      /cid/{cid} - store a compressed chunk under its content id
HttpCidService::HttpCidService(CidStore& Store) : m_CidStore(Store)
{
    m_Router.AddPattern("cid", "([0-9A-Fa-f]{40})");

    m_Router.RegisterRoute(
        "{cid}",
        [this](HttpRouterRequest& Req) {
            IoHash Hash = IoHash::FromHexString(Req.GetCapture(1));
            ZEN_DEBUG("CID request for {}", Hash);

            HttpServerRequest& ServerRequest = Req.ServerRequest();

            switch (ServerRequest.RequestVerb())
            {
                case HttpVerb::kGet:
                case HttpVerb::kHead:
                {
                    // 200 with the binary payload when present, 404 otherwise
                    if (IoBuffer Value = m_CidStore.FindChunkByCid(Hash))
                    {
                        return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value);
                    }

                    return ServerRequest.WriteResponse(HttpResponseCode::NotFound);
                }
                break;

                case HttpVerb::kPut:
                {
                    IoBuffer Payload = ServerRequest.ReadPayload();
                    IoHash RawHash;
                    uint64_t RawSize;
                    // Payload must carry a valid compressed-buffer header
                    if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize))
                    {
                        return ServerRequest.WriteResponse(HttpResponseCode::UnsupportedMediaType);
                    }

                    // URI hash must match content hash
                    if (RawHash != Hash)
                    {
                        return ServerRequest.WriteResponse(HttpResponseCode::BadRequest);
                    }

                    m_CidStore.AddChunk(Payload, RawHash);

                    return ServerRequest.WriteResponse(HttpResponseCode::OK);
                }
                break;

                default:
                    break;
            }
        },
        HttpVerb::kGet | HttpVerb::kPut | HttpVerb::kHead);
}
+
+const char*
+HttpCidService::BaseUri() const
+{
+ return "/cid/";
+}
+
// Entry point for everything under /cid/.
//
// An empty relative URI is the root collection: PUT/POST uploads a compressed
// chunk whose content id is taken from its compressed header. Any other URI
// is dispatched through the per-CID route table.
void
HttpCidService::HandleRequest(zen::HttpServerRequest& Request)
{
    if (Request.RelativeUri().empty())
    {
        // Root URI request

        switch (Request.RequestVerb())
        {
            case HttpVerb::kPut:
            case HttpVerb::kPost:
            {
                IoBuffer Payload = Request.ReadPayload();
                IoHash RawHash;
                uint64_t RawSize;
                if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize))
                {
                    return Request.WriteResponse(HttpResponseCode::UnsupportedMediaType);
                }

                ZEN_DEBUG("CID POST request for {} ({} bytes)", RawHash, Payload.Size());

                auto InsertResult = m_CidStore.AddChunk(Payload, RawHash);

                // 201 when the chunk was new, 200 when it already existed
                if (InsertResult.New)
                {
                    return Request.WriteResponse(HttpResponseCode::Created);
                }
                else
                {
                    return Request.WriteResponse(HttpResponseCode::OK);
                }
            }
            break;

            case HttpVerb::kGet:
            case HttpVerb::kHead:
                // NOTE(review): no response is written for GET/HEAD on the
                // root URI - confirm the framework supplies a default reply
                break;

            default:
                break;
        }
    }
    else
    {
        m_Router.HandleRequest(Request);
    }
}
+
+} // namespace zen
diff --git a/src/zenserver/cidstore.h b/src/zenserver/cidstore.h
new file mode 100644
index 000000000..8e7832b35
--- /dev/null
+++ b/src/zenserver/cidstore.h
@@ -0,0 +1,35 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+/**
+ * Simple CID store HTTP endpoint
+ *
+ * Note that since this does not end up pinning any of the chunks it's only really useful for a small subset of use cases where you know a
+ * chunk exists in the underlying CID store. Thus it's mainly useful for internal use when communicating between Zen store instances
+ *
+ * Using this interface for adding CID chunks makes little sense except for testing purposes as garbage collection may reap anything you add
+ * before anything ever gets to access it
+ */
+
+class CidStore;
+
class HttpCidService : public HttpService
{
public:
    explicit HttpCidService(CidStore& Store);
    ~HttpCidService() = default;

    // HttpService interface
    virtual const char* BaseUri() const override;
    virtual void HandleRequest(zen::HttpServerRequest& Request) override;

private:
    CidStore& m_CidStore;       // Underlying content-addressed chunk store (not owned)
    HttpRequestRouter m_Router; // Dispatches /cid/{cid} requests
};
+
+} // namespace zen
diff --git a/src/zenserver/compute/function.cpp b/src/zenserver/compute/function.cpp
new file mode 100644
index 000000000..493e2666e
--- /dev/null
+++ b/src/zenserver/compute/function.cpp
@@ -0,0 +1,629 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "function.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <upstream/jupiter.h>
+# include <upstream/upstreamapply.h>
+# include <upstream/upstreamcache.h>
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/scopeguard.h>
+# include <zenstore/cidstore.h>
+
+# include <span>
+
+using namespace std::literals;
+
+namespace zen {
+
+HttpFunctionService::HttpFunctionService(CidStore& InCidStore,
+ const CloudCacheClientOptions& ComputeOptions,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ AuthMgr& Mgr)
+: m_Log(logging::Get("apply"))
+, m_CidStore(InCidStore)
+{
+ m_UpstreamApply = UpstreamApply::Create({}, m_CidStore);
+
+ InitializeThread = std::thread{[this, ComputeOptions, StorageOptions, ComputeAuthConfig, StorageAuthConfig, &Mgr] {
+ auto HordeUpstreamEndpoint = UpstreamApplyEndpoint::CreateHordeEndpoint(ComputeOptions,
+ ComputeAuthConfig,
+ StorageOptions,
+ StorageAuthConfig,
+ m_CidStore,
+ Mgr);
+ m_UpstreamApply->RegisterEndpoint(std::move(HordeUpstreamEndpoint));
+ m_UpstreamApply->Initialize();
+ }};
+
+ m_Router.AddPattern("job", "([[:digit:]]+)");
+ m_Router.AddPattern("worker", "([[:xdigit:]]{40})");
+ m_Router.AddPattern("action", "([[:xdigit:]]{40})");
+
+ m_Router.RegisterRoute(
+ "ready",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ return HttpReq.WriteResponse(m_UpstreamApply->IsHealthy() ? HttpResponseCode::OK : HttpResponseCode::ServiceUnavailable);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "workers/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ else
+ {
+ const WorkerDesc& Desc = It->second;
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor);
+ }
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ CbObject FunctionSpec = HttpReq.ReadPayloadObject();
+
+ // Determine which pieces are missing and need to be transmitted to populate CAS
+
+ HashKeySet ChunkSet;
+
+ FunctionSpec.IterateAttachments([&](CbFieldView Field) {
+ const IoHash Hash = Field.AsHash();
+ ChunkSet.AddHashToSet(Hash);
+ });
+
+ // Note that we store executables uncompressed to make it
+ // more straightforward and efficient to materialize them, hence
+ // the CAS lookup here instead of CID for the input payloads
+
+ m_CidStore.FilterChunks(ChunkSet);
+
+ if (ChunkSet.IsEmpty())
+ {
+ RwLock::ExclusiveLockScope _(m_WorkerLock);
+
+ m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{FunctionSpec});
+
+ ZEN_DEBUG("worker {}: all attachments already available", WorkerId);
+
+ return HttpReq.WriteResponse(HttpResponseCode::NoContent);
+ }
+ else
+ {
+ CbObjectWriter ResponseWriter;
+ ResponseWriter.BeginArray("need");
+
+ ChunkSet.IterateHashes([&](const IoHash& Hash) {
+ ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash);
+
+ ResponseWriter.AddHash(Hash);
+ });
+
+ ResponseWriter.EndArray();
+
+ ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize());
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save());
+ }
+ }
+ break;
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage FunctionSpec = HttpReq.ReadPayloadPackage();
+
+ CbObject Obj = FunctionSpec.GetObject();
+
+ std::span<const CbAttachment> Attachments = FunctionSpec.GetAttachments();
+
+ int AttachmentCount = 0;
+ int NewAttachmentCount = 0;
+ uint64_t TotalAttachmentBytes = 0;
+ uint64_t TotalNewBytes = 0;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+ TotalAttachmentBytes += Buffer.GetCompressedSize();
+ ++AttachmentCount;
+
+ const CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ TotalNewBytes += Buffer.GetCompressedSize();
+ ++NewAttachmentCount;
+ }
+ }
+
+ ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments",
+ WorkerId,
+ zen::NiceBytes(TotalAttachmentBytes),
+ AttachmentCount,
+ zen::NiceBytes(TotalNewBytes),
+ NewAttachmentCount);
+
+ RwLock::ExclusiveLockScope _(m_WorkerLock);
+
+ m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{.Descriptor = Obj});
+
+ return HttpReq.WriteResponse(HttpResponseCode::NoContent);
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{job}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ break;
+
+ case HttpVerb::kPost:
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{worker}/{action}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+ const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2));
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ CbPackage Output;
+ HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, ActionId, Output);
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "simple/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ WorkerDesc Worker;
+
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ else
+ {
+ Worker = It->second;
+ }
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, Output);
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+ m_WorkerMap.erase(WorkerId);
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstream(Worker, Output);
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ WorkerDesc Worker;
+
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ else
+ {
+ Worker = It->second;
+ }
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ // TODO: return status of all pending or executing jobs
+ break;
+
+ case HttpVerb::kPost:
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ // This operation takes the proposed job spec and identifies which
+ // chunks are not present on this server. This list is then returned in
+ // the "need" list in the response
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject RequestObject = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ RequestObject.IterateAttachments([&](CbFieldView Field) {
+ const IoHash FileHash = Field.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ NeedList.push_back(FileHash);
+ }
+ });
+
+ if (NeedList.empty())
+ {
+ // We already have everything
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstream(Worker, RequestObject, Output);
+
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+ CbObject Response = Cbo.Save();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
+ }
+ break;
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Action = HttpReq.ReadPayloadPackage();
+ CbObject ActionObj = Action.GetObject();
+
+ std::span<const CbAttachment> Attachments = Action.GetAttachments();
+
+ int AttachmentCount = 0;
+ int NewAttachmentCount = 0;
+ uint64_t TotalAttachmentBytes = 0;
+ uint64_t TotalNewBytes = 0;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
+
+ TotalAttachmentBytes += CompressedSize;
+ ++AttachmentCount;
+
+ const CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ TotalNewBytes += CompressedSize;
+ ++NewAttachmentCount;
+ }
+ }
+
+ ZEN_DEBUG("new action: {} in {} attachments. {} new ({} attachments)",
+ zen::NiceBytes(TotalAttachmentBytes),
+ AttachmentCount,
+ zen::NiceBytes(TotalNewBytes),
+ NewAttachmentCount);
+
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstream(Worker, ActionObj, Output);
+
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+
+ default:
+ break;
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost);
+}
+
+HttpFunctionService::~HttpFunctionService()
+{
+}
+
+const char*
+HttpFunctionService::BaseUri() const
+{
+ return "/apply/";
+}
+
+void
+HttpFunctionService::HandleRequest(HttpServerRequest& Request)
+{
+ if (m_Router.HandleRequest(Request) == false)
+ {
+ ZEN_WARN("No route found for {0}", Request.RelativeUri());
+ }
+}
+
+HttpResponseCode
+HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object)
+{
+ const IoHash WorkerId = Worker.Descriptor.GetHash();
+
+ ZEN_INFO("Action {} being processed...", WorkerId.ToHexString());
+
+ auto EnqueueResult = m_UpstreamApply->EnqueueUpstream({.WorkerDescriptor = Worker.Descriptor, .Type = UpstreamApplyType::Simple});
+ if (!EnqueueResult.Success)
+ {
+ ZEN_ERROR("Error enqueuing upstream Action {}", WorkerId.ToHexString());
+ return HttpResponseCode::InternalServerError;
+ }
+
+ CbObjectWriter Writer;
+ Writer.AddHash("worker", WorkerId);
+
+ Object = Writer.Save();
+ return HttpResponseCode::OK;
+}
+
// Poll the result of a "simple" job identified by worker id.
//
// Returns NotFound when no such job is known, Accepted while it is still
// running, and OK once complete - in which case Object is filled with the
// agent, stdout/stderr, exit code, per-stage timestamps and output files.
HttpResponseCode
HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object)
{
    // Simple jobs are queued without an action object, so the action slot is
    // keyed by the hash of an empty compact-binary object
    const static IoHash Empty = CbObject().GetHash();
    auto Status = m_UpstreamApply->GetStatus(WorkerId, Empty);
    if (!Status.Success)
    {
        return HttpResponseCode::NotFound;
    }

    if (Status.Status.State != UpstreamApplyState::Complete)
    {
        return HttpResponseCode::Accepted;
    }

    GetUpstreamApplyResult& Completed = Status.Status.Result;

    if (!Completed.Success)
    {
        ZEN_ERROR("Action {} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}",
                  WorkerId.ToHexString(),
                  Completed.StdOut,
                  Completed.StdErr,
                  Completed.Error.Reason,
                  Completed.Error.ErrorCode);

        // Normalize failure reporting: guarantee a non-zero exit code and
        // surface the failure reason through stderr when stderr is empty
        if (Completed.Error.ErrorCode == 0)
        {
            Completed.Error.ErrorCode = -1;
        }
        if (Completed.StdErr.empty() && !Completed.Error.Reason.empty())
        {
            Completed.StdErr = Completed.Error.Reason;
        }
    }
    else
    {
        ZEN_INFO("Action {} completed with {} files ExitCode={}",
                 WorkerId.ToHexString(),
                 Completed.OutputFiles.size(),
                 Completed.Error.ErrorCode);
    }

    // Serialize the (possibly normalized) result into the response object
    CbObjectWriter ResultObject;

    ResultObject.AddString("agent"sv, Completed.Agent);
    ResultObject.AddString("detail"sv, Completed.Detail);
    ResultObject.AddString("stdout"sv, Completed.StdOut);
    ResultObject.AddString("stderr"sv, Completed.StdErr);
    ResultObject.AddInteger("exitcode"sv, Completed.Error.ErrorCode);
    ResultObject.BeginArray("stats"sv);
    for (const auto& Timepoint : Completed.Timepoints)
    {
        ResultObject.BeginObject();
        ResultObject.AddString("name"sv, Timepoint.first);
        ResultObject.AddDateTimeTicks("time"sv, Timepoint.second);
        ResultObject.EndObject();
    }
    ResultObject.EndArray();

    ResultObject.BeginArray("files"sv);
    for (const auto& File : Completed.OutputFiles)
    {
        ResultObject.BeginObject();
        ResultObject.AddString("name"sv, File.first.string());
        ResultObject.AddBinary("data"sv, Completed.FileData[File.second]);
        ResultObject.EndObject();
    }
    ResultObject.EndArray();

    Object = ResultObject.Save();
    return HttpResponseCode::OK;
}
+
+HttpResponseCode
+HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object)
+{
+ const IoHash WorkerId = Worker.Descriptor.GetHash();
+ const IoHash ActionId = Action.GetHash();
+
+ Action.MakeOwned();
+
+ ZEN_INFO("Action {}/{} being processed...", WorkerId.ToHexString(), ActionId.ToHexString());
+
+ auto EnqueueResult = m_UpstreamApply->EnqueueUpstream(
+ {.WorkerDescriptor = Worker.Descriptor, .Action = std::move(Action), .Type = UpstreamApplyType::Asset});
+
+ if (!EnqueueResult.Success)
+ {
+ ZEN_ERROR("Error enqueuing upstream Action {}/{}", WorkerId.ToHexString(), ActionId.ToHexString());
+ return HttpResponseCode::InternalServerError;
+ }
+
+ CbObjectWriter Writer;
+ Writer.AddHash("worker", WorkerId);
+ Writer.AddHash("action", ActionId);
+
+ Object = Writer.Save();
+ return HttpResponseCode::OK;
+}
+
// Poll the result of an "asset" job identified by (worker, action).
//
// Returns NotFound when unknown, Accepted while still running,
// InternalServerError when the job failed, and OK with the complete output
// package on success.
HttpResponseCode
HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package)
{
    auto Status = m_UpstreamApply->GetStatus(WorkerId, ActionId);
    if (!Status.Success)
    {
        return HttpResponseCode::NotFound;
    }

    if (Status.Status.State != UpstreamApplyState::Complete)
    {
        return HttpResponseCode::Accepted;
    }

    // Unlike the "simple" flow, a non-zero exit code is treated as a failure here
    GetUpstreamApplyResult& Completed = Status.Status.Result;
    if (!Completed.Success || Completed.Error.ErrorCode != 0)
    {
        ZEN_ERROR("Action {}/{} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}",
                  WorkerId.ToHexString(),
                  ActionId.ToHexString(),
                  Completed.StdOut,
                  Completed.StdErr,
                  Completed.Error.Reason,
                  Completed.Error.ErrorCode);

        return HttpResponseCode::InternalServerError;
    }

    ZEN_INFO("Action {}/{} completed with {} attachments ({} compressed, {} uncompressed)",
             WorkerId.ToHexString(),
             ActionId.ToHexString(),
             Completed.OutputPackage.GetAttachments().size(),
             NiceBytes(Completed.TotalAttachmentBytes),
             NiceBytes(Completed.TotalRawAttachmentBytes));

    // Hand the whole output package (object + attachments) to the caller
    Package = std::move(Completed.OutputPackage);
    return HttpResponseCode::OK;
}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/compute/function.h b/src/zenserver/compute/function.h
new file mode 100644
index 000000000..650cee757
--- /dev/null
+++ b/src/zenserver/compute/function.h
@@ -0,0 +1,73 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if !defined(ZEN_WITH_COMPUTE_SERVICES)
+# define ZEN_WITH_COMPUTE_SERVICES 1
+#endif
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
# include <zencore/compactbinary.h>
# include <zencore/iohash.h>
# include <zencore/logging.h>
# include <zenhttp/httpserver.h>

# include <filesystem>
# include <memory>
# include <thread>
# include <unordered_map>
+
+namespace zen {
+
+class CidStore;
+class UpstreamApply;
+class CloudCacheClient;
+class AuthMgr;
+
+struct UpstreamAuthConfig;
+struct CloudCacheClientOptions;
+
/**
 * Lambda style compute function service
 *
 * HTTP front-end for dispatching compute "apply" jobs to an upstream
 * execution backend: workers are registered by uploading a descriptor (plus
 * any attachments missing from the CID store), then jobs are enqueued and
 * polled against a registered worker.
 */
class HttpFunctionService : public HttpService
{
public:
    HttpFunctionService(CidStore& InCidStore,
                        const CloudCacheClientOptions& ComputeOptions,
                        const CloudCacheClientOptions& StorageOptions,
                        const UpstreamAuthConfig& ComputeAuthConfig,
                        const UpstreamAuthConfig& StorageAuthConfig,
                        AuthMgr& Mgr);
    ~HttpFunctionService();

    // HttpService interface
    virtual const char* BaseUri() const override;
    virtual void HandleRequest(HttpServerRequest& Request) override;

private:
    // Background thread performing upstream endpoint registration/initialization.
    // NOTE(review): must be joined before destruction completes - destroying a
    // joinable std::thread terminates the process; verify the destructor does so
    std::thread InitializeThread;
    spdlog::logger& Log() { return m_Log; }
    spdlog::logger& m_Log;
    HttpRequestRouter m_Router;
    CidStore& m_CidStore; // Content-addressed chunk store shared with the server (not owned)
    std::unique_ptr<UpstreamApply> m_UpstreamApply;

    // A registered worker is currently just its compact-binary descriptor
    struct WorkerDesc
    {
        CbObject Descriptor;
    };

    // "Simple" flow: enqueue by worker only / poll by worker id
    [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object);
    [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object);

    // "Asset" flow: enqueue worker + action object / poll by (worker, action)
    [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object);
    [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package);

    RwLock m_WorkerLock; // Guards m_WorkerMap
    std::unordered_map<IoHash, WorkerDesc> m_WorkerMap; // Registered workers keyed by descriptor hash
};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/config.cpp b/src/zenserver/config.cpp
new file mode 100644
index 000000000..cff93d67b
--- /dev/null
+++ b/src/zenserver/config.cpp
@@ -0,0 +1,902 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "config.h"
+
+#include "diag/logging.h"
+
+#include <zencore/crypto.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+#include <zenhttp/zenhttp.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <zencore/logging.h>
+#include <cxxopts.hpp>
+#include <sol/sol.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+#else
+# include <pwd.h>
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+
+// Used for getting My Documents for default data directory
+# include <ShlObj.h>
+# pragma comment(lib, "shell32.lib")
+
+std::filesystem::path
+PickDefaultStateDirectory()
+{
+ // Pick sensible default
+ PWSTR programDataDir = nullptr;
+ HRESULT hRes = SHGetKnownFolderPath(FOLDERID_ProgramData, 0, NULL, &programDataDir);
+
+ if (SUCCEEDED(hRes))
+ {
+ std::filesystem::path finalPath(programDataDir);
+ finalPath /= L"Epic\\Zen\\Data";
+ ::CoTaskMemFree(programDataDir);
+
+ return finalPath;
+ }
+
+ return L"";
+}
+
+#else
+
/**
 * Pick a sensible default data directory on POSIX platforms:
 * "<home>/.zen" for the current user, or an empty path if no home
 * directory can be determined.
 */
std::filesystem::path
PickDefaultStateDirectory()
{
    // getpwuid can return nullptr (e.g. minimal containers or NSS failures);
    // the previous code dereferenced the result unconditionally.
    if (const passwd* Passwd = getpwuid(getuid()); Passwd != nullptr && Passwd->pw_dir != nullptr)
    {
        return std::filesystem::path(Passwd->pw_dir) / ".zen";
    }

    // Fall back to $HOME before giving up entirely.
    if (const char* Home = std::getenv("HOME"))
    {
        return std::filesystem::path(Home) / ".zen";
    }

    return {};
}
+
+#endif
+
+void
+ValidateOptions(ZenServerOptions& ServerOptions)
+{
+ if (ServerOptions.EncryptionKey.empty() == false)
+ {
+ const auto Key = zen::AesKey256Bit::FromString(ServerOptions.EncryptionKey);
+
+ if (Key.IsValid() == false)
+ {
+ throw cxxopts::OptionParseException("Invalid AES encryption key");
+ }
+ }
+
+ if (ServerOptions.EncryptionIV.empty() == false)
+ {
+ const auto IV = zen::AesIV128Bit::FromString(ServerOptions.EncryptionIV);
+
+ if (IV.IsValid() == false)
+ {
+ throw cxxopts::OptionParseException("Invalid AES initialization vector");
+ }
+ }
+}
+
+UpstreamCachePolicy
+ParseUpstreamCachePolicy(std::string_view Options)
+{
+ if (Options == "readonly")
+ {
+ return UpstreamCachePolicy::Read;
+ }
+ else if (Options == "writeonly")
+ {
+ return UpstreamCachePolicy::Write;
+ }
+ else if (Options == "disabled")
+ {
+ return UpstreamCachePolicy::Disabled;
+ }
+ else
+ {
+ return UpstreamCachePolicy::ReadWrite;
+ }
+}
+
+ZenObjectStoreConfig
+ParseBucketConfigs(std::span<std::string> Buckets)
+{
+ using namespace std::literals;
+
+ ZenObjectStoreConfig Cfg;
+
+ // split bucket args in the form of "{BucketName};{LocalPath}"
+ for (std::string_view Bucket : Buckets)
+ {
+ ZenObjectStoreConfig::BucketConfig NewBucket;
+
+ if (auto Idx = Bucket.find_first_of(";"); Idx != std::string_view::npos)
+ {
+ NewBucket.Name = Bucket.substr(0, Idx);
+ NewBucket.Directory = Bucket.substr(Idx + 1);
+ }
+ else
+ {
+ NewBucket.Name = Bucket;
+ }
+
+ Cfg.Buckets.push_back(std::move(NewBucket));
+ }
+
+ return Cfg;
+}
+
+void ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions);
+
+void
+ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions)
+{
+#if ZEN_WITH_HTTPSYS
+ const char* DefaultHttp = "httpsys";
+#else
+ const char* DefaultHttp = "asio";
+#endif
+
+ // Note to those adding future options; std::filesystem::path-type options
+ // must be read into a std::string first. As of cxxopts-3.0.0 it uses a >>
+ // stream operator to convert argv value into the options type. std::fs::path
+ // expects paths in streams to be quoted but argv paths are unquoted. By
+ // going into a std::string first, paths with whitespace parse correctly.
+ std::string DataDir;
+ std::string ContentDir;
+ std::string AbsLogFile;
+ std::string ConfigFile;
+
+ cxxopts::Options options("zenserver", "Zen Server");
+ options.add_options()("dedicated",
+ "Enable dedicated server mode",
+ cxxopts::value<bool>(ServerOptions.IsDedicated)->default_value("false"));
+ options.add_options()("d, debug", "Enable debugging", cxxopts::value<bool>(ServerOptions.IsDebug)->default_value("false"));
+ options.add_options()("help", "Show command line help");
+ options.add_options()("t, test", "Enable test mode", cxxopts::value<bool>(ServerOptions.IsTest)->default_value("false"));
+ options.add_options()("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(ServerOptions.LogId));
+ options.add_options()("data-dir", "Specify persistence root", cxxopts::value<std::string>(DataDir));
+ options.add_options()("content-dir", "Frontend content directory", cxxopts::value<std::string>(ContentDir));
+ options.add_options()("abslog", "Path to log file", cxxopts::value<std::string>(AbsLogFile));
+ options.add_options()("config", "Path to Lua config file", cxxopts::value<std::string>(ConfigFile));
+ options.add_options()("no-sentry",
+ "Disable Sentry crash handler",
+ cxxopts::value<bool>(ServerOptions.NoSentry)->default_value("false"));
+
+ options.add_option("security",
+ "",
+ "encryption-aes-key",
+ "256 bit AES encryption key",
+ cxxopts::value<std::string>(ServerOptions.EncryptionKey),
+ "");
+
+ options.add_option("security",
+ "",
+ "encryption-aes-iv",
+ "128 bit AES encryption initialization vector",
+ cxxopts::value<std::string>(ServerOptions.EncryptionIV),
+ "");
+
+ std::string OpenIdProviderName;
+ options.add_option("security",
+ "",
+ "openid-provider-name",
+ "Open ID provider name",
+ cxxopts::value<std::string>(OpenIdProviderName),
+ "Default");
+
+ std::string OpenIdProviderUrl;
+ options.add_option("security", "", "openid-provider-url", "Open ID provider URL", cxxopts::value<std::string>(OpenIdProviderUrl), "");
+
+ std::string OpenIdClientId;
+ options.add_option("security", "", "openid-client-id", "Open ID client ID", cxxopts::value<std::string>(OpenIdClientId), "");
+
+ options
+ .add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value<int>(ServerOptions.OwnerPid), "<identifier>");
+ options.add_option("lifetime",
+ "",
+ "child-id",
+ "Specify id which can be used to signal parent",
+ cxxopts::value<std::string>(ServerOptions.ChildId),
+ "<identifier>");
+
+#if ZEN_PLATFORM_WINDOWS
+ options.add_option("lifetime",
+ "",
+ "install",
+ "Install zenserver as a Windows service",
+ cxxopts::value<bool>(ServerOptions.InstallService),
+ "");
+ options.add_option("lifetime",
+ "",
+ "uninstall",
+ "Uninstall zenserver as a Windows service",
+ cxxopts::value<bool>(ServerOptions.UninstallService),
+ "");
+#endif
+
+ options.add_option("network",
+ "",
+ "http",
+ "Select HTTP server implementation (asio|httpsys|null)",
+ cxxopts::value<std::string>(ServerOptions.HttpServerClass)->default_value(DefaultHttp),
+ "<http class>");
+
+ options.add_option("network",
+ "p",
+ "port",
+ "Select HTTP port",
+ cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"),
+ "<port number>");
+
+ options.add_option("network",
+ "",
+ "websocket-port",
+ "Websocket server port",
+ cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"),
+ "<port number>");
+
+ options.add_option("network",
+ "",
+ "websocket-threads",
+ "Number of websocket I/O thread(s) (0 == hardware concurrency)",
+ cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"),
+ "");
+
+#if ZEN_WITH_TRACE
+ options.add_option("ue-trace",
+ "",
+ "tracehost",
+ "Hostname to send the trace to",
+ cxxopts::value<std::string>(ServerOptions.TraceHost)->default_value(""),
+ "");
+
+ options.add_option("ue-trace",
+ "",
+ "tracefile",
+ "Path to write a trace to",
+ cxxopts::value<std::string>(ServerOptions.TraceFile)->default_value(""),
+ "");
+#endif // ZEN_WITH_TRACE
+
+ options.add_option("diagnostics",
+ "",
+ "crash",
+ "Simulate a crash",
+ cxxopts::value<bool>(ServerOptions.ShouldCrash)->default_value("false"),
+ "");
+
+ std::string UpstreamCachePolicyOptions;
+ options.add_option("cache",
+ "",
+ "upstream-cache-policy",
+ "",
+ cxxopts::value<std::string>(UpstreamCachePolicyOptions)->default_value(""),
+ "Upstream cache policy (readwrite|readonly|writeonly|disabled)");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-url",
+ "URL to a Jupiter instance",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Url)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-oauth-url",
+ "URL to the OAuth provier",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-oauth-clientid",
+ "The OAuth client ID",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-oauth-clientsecret",
+ "The OAuth client secret",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-openid-provider",
+ "Name of a registered Open ID provider",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-token",
+ "A static authentication token",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-namespace",
+ "The Common Blob Store API namespace",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-namespace-ddc",
+ "The lecacy DDC namespace",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-zen-url",
+ "URL to remote Zen server. Use a comma separated list to choose the one with the best latency.",
+ cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Urls),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-zen-dns",
+ "DNS that resolves to one or more Zen server instance(s)",
+ cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Dns),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-thread-count",
+ "Number of threads used for upstream procsssing",
+ cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.UpstreamThreadCount)->default_value("4"),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-connect-timeout-ms",
+ "Connect timeout in millisecond(s). Default 5000 ms.",
+ cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.ConnectTimeoutMilliseconds)->default_value("5000"),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-timeout-ms",
+ "Timeout in millisecond(s). Default 0 ms",
+ cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.TimeoutMilliseconds)->default_value("0"),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-url",
+ "URL to a Horde instance.",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Url)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-oauth-url",
+ "URL to the OAuth provier",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-oauth-clientid",
+ "The OAuth client ID",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-oauth-clientsecret",
+ "The OAuth client secret",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-openid-provider",
+ "Name of a registered Open ID provider",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-token",
+ "A static authentication token",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-url",
+ "URL to a Horde Storage instance.",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-oauth-url",
+ "URL to the OAuth provier",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-oauth-clientid",
+ "The OAuth client ID",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId)->default_value(""),
+ "");
+
+ options.add_option(
+ "compute",
+ "",
+ "upstream-horde-storage-oauth-clientsecret",
+ "The OAuth client secret",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-openid-provider",
+ "Name of a registered Open ID provider",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-token",
+ "A static authentication token",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-cluster",
+ "The Horde compute cluster id",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-namespace",
+ "The Jupiter namespace to use with Horde compute",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace)->default_value(""),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-enabled",
+ "Whether garbage collection is enabled or not.",
+ cxxopts::value<bool>(ServerOptions.GcConfig.Enabled)->default_value("true"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-small-objects",
+ "Whether garbage collection of small objects is enabled or not.",
+ cxxopts::value<bool>(ServerOptions.GcConfig.CollectSmallObjects)->default_value("true"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-interval-seconds",
+ "Garbage collection interval in seconds. Default set to 3600 (1 hour).",
+ cxxopts::value<int32_t>(ServerOptions.GcConfig.IntervalSeconds)->default_value("3600"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-cache-duration-seconds",
+ "Max duration in seconds before Z$ entries get evicted. Default set to 1209600 (2 weeks)",
+ cxxopts::value<int32_t>(ServerOptions.GcConfig.Cache.MaxDurationSeconds)->default_value("1209600"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "disk-reserve-size",
+ "Size of gc disk reserve in bytes. Default set to 268435456 (256 Mb).",
+ cxxopts::value<uint64_t>(ServerOptions.GcConfig.DiskReserveSize)->default_value("268435456"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-monitor-interval-seconds",
+ "Garbage collection monitoring interval in seconds. Default set to 30 (30 seconds)",
+ cxxopts::value<int32_t>(ServerOptions.GcConfig.MonitorIntervalSeconds)->default_value("30"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-disksize-softlimit",
+ "Garbage collection disk usage soft limit. Default set to 0 (Off).",
+ cxxopts::value<uint64_t>(ServerOptions.GcConfig.Cache.DiskSizeSoftLimit)->default_value("0"),
+ "");
+
+ options.add_option("objectstore",
+ "",
+ "objectstore-enabled",
+ "Whether the object store is enabled or not.",
+ cxxopts::value<bool>(ServerOptions.ObjectStoreEnabled)->default_value("false"),
+ "");
+
+ std::vector<std::string> BucketConfigs;
+ options.add_option("objectstore",
+ "",
+ "objectstore-bucket",
+ "Object store bucket mappings.",
+ cxxopts::value<std::vector<std::string>>(BucketConfigs),
+ "");
+
+ try
+ {
+ auto result = options.parse(argc, argv);
+
+ if (result.count("help"))
+ {
+ zen::logging::ConsoleLog().info("{}", options.help());
+#if ZEN_PLATFORM_WINDOWS
+ zen::logging::ConsoleLog().info("Press any key to exit!");
+ _getch();
+#else
+ // Assume the user's in a terminal on all other platforms and that
+ // they'll use less/more/etc. if need be.
+#endif
+ exit(0);
+ }
+
+ auto MakeSafePath = [](const std::string& Path) {
+#if ZEN_PLATFORM_WINDOWS
+ if (Path.empty())
+ {
+ return Path;
+ }
+
+ std::string FixedPath = Path;
+ std::replace(FixedPath.begin(), FixedPath.end(), '/', '\\');
+ if (!FixedPath.starts_with("\\\\?\\"))
+ {
+ FixedPath.insert(0, "\\\\?\\");
+ }
+ return FixedPath;
+#else
+ return Path;
+#endif
+ };
+
+ ServerOptions.DataDir = MakeSafePath(DataDir);
+ ServerOptions.ContentDir = MakeSafePath(ContentDir);
+ ServerOptions.AbsLogFile = MakeSafePath(AbsLogFile);
+ ServerOptions.ConfigFile = MakeSafePath(ConfigFile);
+ ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(UpstreamCachePolicyOptions);
+
+ if (OpenIdProviderUrl.empty() == false)
+ {
+ if (OpenIdClientId.empty())
+ {
+ throw cxxopts::OptionParseException("Invalid OpenID client ID");
+ }
+
+ ServerOptions.AuthConfig.OpenIdProviders.push_back(
+ {.Name = OpenIdProviderName, .Url = OpenIdProviderUrl, .ClientId = OpenIdClientId});
+ }
+
+ ServerOptions.ObjectStoreConfig = ParseBucketConfigs(BucketConfigs);
+
+ if (!ServerOptions.ConfigFile.empty())
+ {
+ ParseConfigFile(ServerOptions.ConfigFile, ServerOptions);
+ }
+ else
+ {
+ ParseConfigFile(ServerOptions.DataDir / "zen_cfg.lua", ServerOptions);
+ }
+
+ ValidateOptions(ServerOptions);
+ }
+ catch (cxxopts::OptionParseException& e)
+ {
+ zen::logging::ConsoleLog().error("Error parsing zenserver arguments: {}\n\n{}", e.what(), options.help());
+
+ throw;
+ }
+
+ if (ServerOptions.DataDir.empty())
+ {
+ ServerOptions.DataDir = PickDefaultStateDirectory();
+ }
+
+ if (ServerOptions.AbsLogFile.empty())
+ {
+ ServerOptions.AbsLogFile = ServerOptions.DataDir / "logs" / "zenserver.log";
+ }
+}
+
+void
+ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions)
+{
+ zen::IoBuffer LuaScript = zen::IoBufferBuilder::MakeFromFile(Path);
+
+ if (LuaScript)
+ {
+ sol::state lua;
+
+ lua.open_libraries(sol::lib::base);
+
+ lua.set_function("getenv", [&](const std::string env) -> sol::object {
+#if ZEN_PLATFORM_WINDOWS
+ std::wstring EnvVarValue;
+ size_t RequiredSize = 0;
+ std::wstring EnvWide = zen::Utf8ToWide(env);
+ _wgetenv_s(&RequiredSize, nullptr, 0, EnvWide.c_str());
+
+ if (RequiredSize == 0)
+ return sol::make_object(lua, sol::lua_nil);
+
+ EnvVarValue.resize(RequiredSize);
+ _wgetenv_s(&RequiredSize, EnvVarValue.data(), RequiredSize, EnvWide.c_str());
+ return sol::make_object(lua, zen::WideToUtf8(EnvVarValue.c_str()));
+#else
+ ZEN_UNUSED(env);
+ return sol::make_object(lua, sol::lua_nil);
+#endif
+ });
+
+ try
+ {
+ sol::load_result config = lua.load(std::string_view((const char*)LuaScript.Data(), LuaScript.Size()), "zen_cfg");
+
+ if (!config.valid())
+ {
+ sol::error err = config;
+
+ std::string ErrorString = sol::to_string(config.status());
+
+ throw std::runtime_error(fmt::format("{} error: {}", ErrorString, err.what()));
+ }
+
+ config();
+ }
+ catch (std::exception& e)
+ {
+ throw std::runtime_error(fmt::format("failed to load config script ('{}'): {}", Path, e.what()).c_str());
+ }
+
+ if (sol::optional<sol::table> ServerConfig = lua["server"])
+ {
+ if (ServerOptions.DataDir.empty())
+ {
+ if (sol::optional<std::string> Opt = ServerConfig.value()["datadir"])
+ {
+ ServerOptions.DataDir = Opt.value();
+ }
+ }
+
+ if (ServerOptions.ContentDir.empty())
+ {
+ if (sol::optional<std::string> Opt = ServerConfig.value()["contentdir"])
+ {
+ ServerOptions.ContentDir = Opt.value();
+ }
+ }
+
+ if (ServerOptions.AbsLogFile.empty())
+ {
+ if (sol::optional<std::string> Opt = ServerConfig.value()["abslog"])
+ {
+ ServerOptions.AbsLogFile = Opt.value();
+ }
+ }
+
+ ServerOptions.IsDebug = ServerConfig->get_or("debug", ServerOptions.IsDebug);
+ }
+
+ if (sol::optional<sol::table> NetworkConfig = lua["network"])
+ {
+ if (sol::optional<std::string> Opt = NetworkConfig.value()["httpserverclass"])
+ {
+ ServerOptions.HttpServerClass = Opt.value();
+ }
+
+ ServerOptions.BasePort = NetworkConfig->get_or<int>("port", ServerOptions.BasePort);
+ }
+
+ auto UpdateStringValueFromConfig = [](const sol::table& Table, std::string_view Key, std::string& OutValue) {
+ // Update the specified config value unless it has been set, i.e. from command line
+ if (auto MaybeValue = Table.get<sol::optional<std::string>>(Key); MaybeValue.has_value() && OutValue.empty())
+ {
+ OutValue = MaybeValue.value();
+ }
+ };
+
+ if (sol::optional<sol::table> StructuredCacheConfig = lua["cache"])
+ {
+ ServerOptions.StructuredCacheEnabled = StructuredCacheConfig->get_or("enable", ServerOptions.StructuredCacheEnabled);
+
+ if (auto UpstreamConfig = StructuredCacheConfig->get<sol::optional<sol::table>>("upstream"))
+ {
+ std::string Policy = UpstreamConfig->get_or("policy", std::string());
+ ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(Policy);
+ ServerOptions.UpstreamCacheConfig.UpstreamThreadCount =
+ UpstreamConfig->get_or("upstreamthreadcount", ServerOptions.UpstreamCacheConfig.UpstreamThreadCount);
+
+ if (auto JupiterConfig = UpstreamConfig->get<sol::optional<sol::table>>("jupiter"))
+ {
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("name"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.Name);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("url"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.Url);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("oauthprovider"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("oauthclientid"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("oauthclientsecret"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("openidprovider"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("token"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("namespace"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("ddcnamespace"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace);
+ };
+
+ if (auto ZenConfig = UpstreamConfig->get<sol::optional<sol::table>>("zen"))
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Name = ZenConfig.value().get_or("name", std::string("Zen"));
+
+ if (auto Url = ZenConfig.value().get<sol::optional<std::string>>("url"))
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Url.value());
+ }
+ else if (auto Urls = ZenConfig.value().get<sol::optional<sol::table>>("url"))
+ {
+ for (const auto& Kv : Urls.value())
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Kv.second.as<std::string>());
+ }
+ }
+
+ if (auto Dns = ZenConfig.value().get<sol::optional<std::string>>("dns"))
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Dns.value());
+ }
+ else if (auto DnsArray = ZenConfig.value().get<sol::optional<sol::table>>("dns"))
+ {
+ for (const auto& Kv : DnsArray.value())
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Kv.second.as<std::string>());
+ }
+ }
+ }
+ }
+ }
+
+ if (sol::optional<sol::table> ExecConfig = lua["exec"])
+ {
+ ServerOptions.ExecServiceEnabled = ExecConfig->get_or("enable", ServerOptions.ExecServiceEnabled);
+ }
+
+ if (sol::optional<sol::table> ComputeConfig = lua["compute"])
+ {
+ ServerOptions.ComputeServiceEnabled = ComputeConfig->get_or("enable", ServerOptions.ComputeServiceEnabled);
+
+ if (auto UpstreamConfig = ComputeConfig->get<sol::optional<sol::table>>("upstream"))
+ {
+ if (auto HordeConfig = UpstreamConfig->get<sol::optional<sol::table>>("horde"))
+ {
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("name"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Name);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("url"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Url);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("oauthprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("oauthclientid"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("oauthclientsecret"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("openidprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("token"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("cluster"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("namespace"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace);
+ };
+
+ if (auto StorageConfig = UpstreamConfig->get<sol::optional<sol::table>>("storage"))
+ {
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("url"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("oauthprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("oauthclientid"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("oauthclientsecret"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("openidprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("token"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken);
+ };
+ }
+ }
+
+ if (sol::optional<sol::table> GcConfig = lua["gc"])
+ {
+ ServerOptions.GcConfig.MonitorIntervalSeconds = GcConfig.value().get_or("monitorintervalseconds", 30);
+ ServerOptions.GcConfig.IntervalSeconds = GcConfig.value().get_or("intervalseconds", 0);
+ ServerOptions.GcConfig.DiskReserveSize = GcConfig.value().get_or("diskreservesize", uint64_t(1u << 28));
+
+ if (sol::optional<sol::table> CacheGcConfig = GcConfig.value()["cache"])
+ {
+ ServerOptions.GcConfig.Cache.MaxDurationSeconds = CacheGcConfig.value().get_or("maxdurationseconds", int32_t(0));
+ ServerOptions.GcConfig.Cache.DiskSizeLimit = CacheGcConfig.value().get_or("disksizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cache.MemorySizeLimit = CacheGcConfig.value().get_or("memorysizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cache.DiskSizeSoftLimit = CacheGcConfig.value().get_or("disksizesoftlimit", 0);
+ }
+
+ if (sol::optional<sol::table> CasGcConfig = GcConfig.value()["cas"])
+ {
+ ServerOptions.GcConfig.Cas.LargeStrategySizeLimit = CasGcConfig.value().get_or("largestrategysizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cas.SmallStrategySizeLimit = CasGcConfig.value().get_or("smallstrategysizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cas.TinyStrategySizeLimit = CasGcConfig.value().get_or("tinystrategysizelimit", ~uint64_t(0));
+ }
+ }
+
+ if (sol::optional<sol::table> SecurityConfig = lua["security"])
+ {
+ if (sol::optional<sol::table> OpenIdProviders = SecurityConfig.value()["openidproviders"])
+ {
+ for (const auto& Kv : OpenIdProviders.value())
+ {
+ if (sol::optional<sol::table> OpenIdProvider = Kv.second.as<sol::table>())
+ {
+ std::string Name = OpenIdProvider.value().get_or("name", std::string("Default"));
+ std::string Url = OpenIdProvider.value().get_or("url", std::string());
+ std::string ClientId = OpenIdProvider.value().get_or("clientid", std::string());
+
+ ServerOptions.AuthConfig.OpenIdProviders.push_back(
+ {.Name = std::move(Name), .Url = std::move(Url), .ClientId = std::move(ClientId)});
+ }
+ }
+ }
+
+ ServerOptions.EncryptionKey = SecurityConfig.value().get_or("encryptionaeskey", std::string());
+ ServerOptions.EncryptionIV = SecurityConfig.value().get_or("encryptionaesiv", std::string());
+ }
+ }
+}
diff --git a/src/zenserver/config.h b/src/zenserver/config.h
new file mode 100644
index 000000000..8a5c6de4e
--- /dev/null
+++ b/src/zenserver/config.h
@@ -0,0 +1,158 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+#include <filesystem>
+#include <string>
+#include <vector>
+
// Connection/auth settings for an upstream Jupiter cache instance.
// Auth can come from OAuth client credentials, a registered Open ID
// provider, or a static access token (see ParseCliOptions "cache" options).
struct ZenUpstreamJupiterConfig
{
    std::string Name;
    std::string Url;
    std::string OAuthUrl;
    std::string OAuthClientId;
    std::string OAuthClientSecret;
    std::string OpenIdProvider;
    std::string AccessToken;
    std::string Namespace;     // Common Blob Store API namespace
    std::string DdcNamespace;  // legacy DDC namespace
};
+
// Connection/auth settings for an upstream Horde compute instance plus its
// associated Horde Storage endpoint (separate "Storage*" credentials).
struct ZenUpstreamHordeConfig
{
    std::string Name;
    std::string Url;
    std::string OAuthUrl;
    std::string OAuthClientId;
    std::string OAuthClientSecret;
    std::string OpenIdProvider;
    std::string AccessToken;

    // Horde Storage endpoint and its (independent) credentials.
    std::string StorageUrl;
    std::string StorageOAuthUrl;
    std::string StorageOAuthClientId;
    std::string StorageOAuthClientSecret;
    std::string StorageOpenIdProvider;
    std::string StorageAccessToken;

    std::string Cluster;    // Horde compute cluster id
    std::string Namespace;  // Jupiter namespace used with Horde compute
};
+
// Upstream Zen server endpoints. Multiple URLs/DNS entries may be given;
// per the CLI help, the instance with the best latency is chosen.
struct ZenUpstreamZenConfig
{
    std::string Name;
    std::vector<std::string> Urls;
    std::vector<std::string> Dns;  // DNS names resolving to Zen instance(s)
};
+
// Bit-flag policy controlling upstream cache traffic direction.
// Parsed from "readwrite|readonly|writeonly|disabled" by
// ParseUpstreamCachePolicy (unknown strings map to ReadWrite).
enum class UpstreamCachePolicy : uint8_t
{
    Disabled = 0,
    Read = 1 << 0,
    Write = 1 << 1,
    ReadWrite = Read | Write
};
+
// Aggregate upstream configuration: one slot per upstream kind plus the
// shared transport/threading knobs.
struct ZenUpstreamCacheConfig
{
    ZenUpstreamJupiterConfig JupiterConfig;
    ZenUpstreamHordeConfig HordeConfig;
    ZenUpstreamZenConfig ZenConfig;
    int32_t UpstreamThreadCount = 4;          // worker threads for upstream processing
    int32_t ConnectTimeoutMilliseconds = 5000;
    int32_t TimeoutMilliseconds = 0;          // 0 = no overall timeout
    UpstreamCachePolicy CachePolicy = UpstreamCachePolicy::ReadWrite;
};
+
// Eviction limits for the structured cache (Z$) garbage collector.
struct ZenCacheEvictionPolicy
{
    uint64_t DiskSizeLimit = ~uint64_t(0);        // hard disk limit; max = unlimited
    uint64_t MemorySizeLimit = 1024 * 1024 * 1024;
    // NOTE(review): defaults here differ from the CLI defaults in config.cpp
    // (gc-cache-duration-seconds defaults to 1209600 there) -- confirm which
    // default is authoritative when options are constructed without parsing.
    int32_t MaxDurationSeconds = 24 * 60 * 60;
    uint64_t DiskSizeSoftLimit = 0;               // 0 = soft limit disabled
    bool Enabled = true;
};
+
// Eviction limits for the CAS garbage collector, split per size-class
// strategy (large/small/tiny). Max value = unlimited.
struct ZenCasEvictionPolicy
{
    uint64_t LargeStrategySizeLimit = ~uint64_t(0);
    uint64_t SmallStrategySizeLimit = ~uint64_t(0);
    uint64_t TinyStrategySizeLimit = ~uint64_t(0);
    bool Enabled = true;
};
+
// Top-level garbage collection configuration.
struct ZenGcConfig
{
    ZenCasEvictionPolicy Cas;
    ZenCacheEvictionPolicy Cache;
    int32_t MonitorIntervalSeconds = 30;
    // NOTE(review): the CLI default for gc-interval-seconds in config.cpp is
    // 3600, while this struct defaults to 0 -- confirm the intended default
    // when the struct is used without CLI parsing.
    int32_t IntervalSeconds = 0;
    bool CollectSmallObjects = true;
    bool Enabled = true;
    uint64_t DiskReserveSize = 1ul << 28;  // 256 MiB disk reserve
};
+
// One registered Open ID provider (referenced by name from the upstream
// "OpenIdProvider" settings).
struct ZenOpenIdProviderConfig
{
    std::string Name;
    std::string Url;
    std::string ClientId;
};
+
// Authentication configuration: the set of registered Open ID providers.
struct ZenAuthConfig
{
    std::vector<ZenOpenIdProviderConfig> OpenIdProviders;
};
+
// Object store configuration: named buckets, each optionally mapped to a
// local directory (parsed from "{BucketName};{LocalPath}" CLI arguments).
struct ZenObjectStoreConfig
{
    struct BucketConfig
    {
        std::string Name;
        std::filesystem::path Directory;  // empty = no explicit local path
    };

    std::vector<BucketConfig> Buckets;
};
+
// All zenserver settings, populated by ParseCliOptions (command line +
// optional Lua config file). Defaults here apply when a setting is supplied
// by neither source.
struct ZenServerOptions
{
    ZenUpstreamCacheConfig UpstreamCacheConfig;
    ZenGcConfig GcConfig;
    ZenAuthConfig AuthConfig;
    ZenObjectStoreConfig ObjectStoreConfig;
    std::filesystem::path DataDir;     // Root directory for state (used for testing)
    std::filesystem::path ContentDir;  // Root directory for serving frontend content (experimental)
    std::filesystem::path AbsLogFile;  // Absolute path to main log file
    std::filesystem::path ConfigFile;  // Path to Lua config file
    std::string ChildId;               // Id assigned by parent process (used for lifetime management)
    std::string LogId;                 // Id for tagging log output
    std::string HttpServerClass;       // Choice of HTTP server implementation (asio|httpsys|null)
    std::string EncryptionKey;         // 256 bit AES encryption key
    std::string EncryptionIV;          // 128 bit AES initialization vector
    int BasePort = 1337;               // Service listen port (used for both UDP and TCP)
    int OwnerPid = 0;                  // Parent process id (zero for standalone)
    int WebSocketPort = 0;             // Web socket port (Zero = disabled)
    int WebSocketThreads = 0;          // Websocket I/O threads (0 = hardware concurrency)
    bool InstallService = false;       // Flag used to initiate service install (temporary)
    bool UninstallService = false;     // Flag used to initiate service uninstall (temporary)
    bool IsDebug = false;              // Enable debugging
    bool IsTest = false;               // Enable test mode
    bool IsDedicated = false;          // Indicates a dedicated/shared instance, with larger resource requirements
    bool StructuredCacheEnabled = true;
    bool ExecServiceEnabled = true;
    bool ComputeServiceEnabled = true;
    bool ShouldCrash = false;          // Option for testing crash handling
    bool IsFirstRun = false;
    bool NoSentry = false;             // Disable Sentry crash handler
    bool ObjectStoreEnabled = false;
#if ZEN_WITH_TRACE
    std::string TraceHost;  // Host name or IP address to send trace data to
    std::string TraceFile;  // Path of a file to write a trace
#endif
};
+
+void ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions);
diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp
new file mode 100644
index 000000000..29ad5c3dd
--- /dev/null
+++ b/src/zenserver/diag/diagsvcs.cpp
@@ -0,0 +1,127 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "diagsvcs.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/config.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <fstream>
+#include <sstream>
+
+#include <json11.hpp>
+
+namespace zen {
+
+using namespace std::literals;
+
+bool
+ReadFile(const std::string& Path, StringBuilderBase& Out)
+{
+ try
+ {
+ constexpr auto ReadSize = std::size_t{4096};
+ auto FileStream = std::ifstream{Path};
+
+ std::string Buf(ReadSize, '\0');
+ while (FileStream.read(&Buf[0], ReadSize))
+ {
+ Out.Append(std::string_view(&Buf[0], FileStream.gcount()));
+ }
+ Out.Append(std::string_view(&Buf[0], FileStream.gcount()));
+
+ return true;
+ }
+ catch (std::exception&)
+ {
+ Out.Reset();
+ return false;
+ }
+}
+
+HttpHealthService::HttpHealthService()
+{
+ m_Router.RegisterRoute(
+ "",
+ [](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "info",
+ [this](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+
+ CbObjectWriter Writer;
+ Writer << "DataRoot"sv << m_HealthInfo.DataRoot.string();
+ Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string();
+ Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion;
+ Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass;
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "log",
+ [this](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+
+ zen::Log().flush();
+
+ std::filesystem::path Path =
+ m_HealthInfo.AbsLogPath.empty() ? m_HealthInfo.DataRoot / "logs/zenserver.log" : m_HealthInfo.AbsLogPath;
+
+ ExtendableStringBuilder<4096> Sb;
+ if (ReadFile(Path.string(), Sb) && Sb.Size() > 0)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Sb.ToView());
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ },
+ HttpVerb::kGet);
+ m_Router.RegisterRoute(
+ "version",
+ [this](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+ if (HttpReq.GetQueryParams().GetValue("detailed") == "true")
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION_BUILD_STRING_FULL);
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION);
+ }
+ },
+ HttpVerb::kGet);
+}
+
+void
+HttpHealthService::SetHealthInfo(HealthServiceInfo&& Info)
+{
+ m_HealthInfo = std::move(Info);
+}
+
+const char*
+HttpHealthService::BaseUri() const
+{
+ return "/health/";
+}
+
+void
+HttpHealthService::HandleRequest(HttpServerRequest& Request)
+{
+ if (!m_Router.HandleRequest(Request))
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv);
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h
new file mode 100644
index 000000000..bd03f8023
--- /dev/null
+++ b/src/zenserver/diag/diagsvcs.h
@@ -0,0 +1,111 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+#include <zenhttp/httpserver.h>
+
+#include <filesystem>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+class HttpTestService : public HttpService
+{
+ uint32_t LogPoint = 0;
+
+public:
+ HttpTestService() {}
+ ~HttpTestService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+
+ virtual void HandleRequest(HttpServerRequest& Request) override
+ {
+ using namespace std::literals;
+
+ auto Uri = Request.RelativeUri();
+
+ if (Uri == "hello"sv)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"hello world!"sv);
+
+ // OutputLogMessageInternal(&LogPoint, 0, 0);
+ }
+ else if (Uri == "1K"sv)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1k);
+ }
+ else if (Uri == "1M"sv)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1m);
+ }
+ else if (Uri == "1M_1k"sv)
+ {
+ std::vector<IoBuffer> Buffers;
+ Buffers.reserve(1024);
+
+ for (int i = 0; i < 1024; ++i)
+ {
+ Buffers.push_back(m_1k);
+ }
+
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers);
+ }
+ else if (Uri == "1G"sv)
+ {
+ std::vector<IoBuffer> Buffers;
+ Buffers.reserve(1024);
+
+ for (int i = 0; i < 1024; ++i)
+ {
+ Buffers.push_back(m_1m);
+ }
+
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers);
+ }
+ else if (Uri == "1G_1k"sv)
+ {
+ std::vector<IoBuffer> Buffers;
+ Buffers.reserve(1024 * 1024);
+
+ for (int i = 0; i < 1024 * 1024; ++i)
+ {
+ Buffers.push_back(m_1k);
+ }
+
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers);
+ }
+ }
+
+private:
+ IoBuffer m_1m{1024 * 1024};
+ IoBuffer m_1k{m_1m, 0u, 1024};
+};
+
+struct HealthServiceInfo
+{
+ std::filesystem::path DataRoot;
+ std::filesystem::path AbsLogPath;
+ std::string HttpServerClass;
+ std::string BuildVersion;
+};
+
+class HttpHealthService : public HttpService
+{
+public:
+ HttpHealthService();
+ ~HttpHealthService() = default;
+
+ void SetHealthInfo(HealthServiceInfo&& Info);
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override final;
+
+private:
+ HttpRequestRouter m_Router;
+ HealthServiceInfo m_HealthInfo;
+};
+
+} // namespace zen
diff --git a/src/zenserver/diag/formatters.h b/src/zenserver/diag/formatters.h
new file mode 100644
index 000000000..759df58d3
--- /dev/null
+++ b/src/zenserver/diag/formatters.h
@@ -0,0 +1,71 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+template<>
+struct fmt::formatter<cpr::Response>
+{
+ constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); }
+
+ template<typename FormatContext>
+ auto format(const cpr::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out())
+ {
+ using namespace std::literals;
+
+ if (Response.status_code == 200 || Response.status_code == 201)
+ {
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s",
+ Response.url.str(),
+ Response.status_code,
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ Response.elapsed);
+ }
+ else
+ {
+ const auto It = Response.header.find("Content-Type");
+ const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv;
+
+ if (ContentType == "application/x-ue-cb"sv)
+ {
+ zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ zen::CbObjectView Obj(Body.Data());
+ zen::ExtendableStringBuilder<256> Sb;
+ std::string_view Json = Obj.ToJson(Sb).ToView();
+
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Response: '{}', Reason: '{}'",
+ Response.url.str(),
+ Response.status_code,
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ Response.elapsed,
+ Json,
+ Response.reason);
+ }
+ else
+ {
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Reponse: '{}', Reason: '{}'",
+ Response.url.str(),
+ Response.status_code,
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ Response.elapsed,
+ Response.text,
+ Response.reason);
+ }
+ }
+ }
+};
diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp
new file mode 100644
index 000000000..24c7572f4
--- /dev/null
+++ b/src/zenserver/diag/logging.cpp
@@ -0,0 +1,467 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "logging.h"
+
+#include "config.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <spdlog/async.h>
+#include <spdlog/async_logger.h>
+#include <spdlog/pattern_formatter.h>
+#include <spdlog/sinks/ansicolor_sink.h>
+#include <spdlog/sinks/basic_file_sink.h>
+#include <spdlog/sinks/daily_file_sink.h>
+#include <spdlog/sinks/msvc_sink.h>
+#include <spdlog/sinks/rotating_file_sink.h>
+#include <spdlog/sinks/stdout_color_sinks.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <zencore/compactbinary.h>
+#include <zencore/filesystem.h>
+#include <zencore/string.h>
+
+#include <chrono>
+#include <memory>
+
+// Custom logging -- test code, this should be tweaked
+
+namespace logging {
+
+using namespace spdlog;
+using namespace spdlog::details;
+using namespace std::literals;
+
+class full_formatter final : public spdlog::formatter
+{
+public:
+ full_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId) {}
+
+ virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<full_formatter>(m_LogId, m_Epoch); }
+
+ static constexpr bool UseDate = false;
+
+ virtual void format(const details::log_msg& msg, memory_buf_t& dest) override
+ {
+ using std::chrono::duration_cast;
+ using std::chrono::milliseconds;
+ using std::chrono::seconds;
+
+ if constexpr (UseDate)
+ {
+ auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch());
+ if (secs != m_LastLogSecs)
+ {
+ m_CachedTm = os::localtime(log_clock::to_time_t(msg.time));
+ m_LastLogSecs = secs;
+ }
+ }
+
+ const auto& tm_time = m_CachedTm;
+
+ // cache the date/time part for the next second.
+ auto duration = msg.time - m_Epoch;
+ auto secs = duration_cast<seconds>(duration);
+
+ if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0)
+ {
+ m_CachedDatetime.clear();
+ m_CachedDatetime.push_back('[');
+
+ if constexpr (UseDate)
+ {
+ fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
+ m_CachedDatetime.push_back(' ');
+
+ fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
+ }
+ else
+ {
+ int Count = int(secs.count());
+
+ const int LogSecs = Count % 60;
+ Count /= 60;
+
+ const int LogMins = Count % 60;
+ Count /= 60;
+
+ const int LogHours = Count;
+
+ fmt_helper::pad2(LogHours, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+ fmt_helper::pad2(LogMins, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+ fmt_helper::pad2(LogSecs, m_CachedDatetime);
+ }
+
+ m_CachedDatetime.push_back('.');
+
+ m_CacheTimestamp = secs;
+ }
+
+ dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
+
+ auto millis = fmt_helper::time_fraction<milliseconds>(msg.time);
+ fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+
+ if (!m_LogId.empty())
+ {
+ dest.push_back('[');
+ fmt_helper::append_string_view(m_LogId, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ // append logger name if exists
+ if (msg.logger_name.size() > 0)
+ {
+ dest.push_back('[');
+ fmt_helper::append_string_view(msg.logger_name, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ dest.push_back('[');
+ // wrap the level name with color
+ msg.color_range_start = dest.size();
+ fmt_helper::append_string_view(level::to_string_view(msg.level), dest);
+ msg.color_range_end = dest.size();
+ dest.push_back(']');
+ dest.push_back(' ');
+
+ // add source location if present
+ if (!msg.source.empty())
+ {
+ dest.push_back('[');
+ const char* filename = details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename);
+ fmt_helper::append_string_view(filename, dest);
+ dest.push_back(':');
+ fmt_helper::append_int(msg.source.line, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ fmt_helper::append_string_view(msg.payload, dest);
+ fmt_helper::append_string_view("\n"sv, dest);
+ }
+
+private:
+ std::chrono::time_point<std::chrono::system_clock> m_Epoch;
+ std::tm m_CachedTm;
+ std::chrono::seconds m_LastLogSecs;
+ std::chrono::seconds m_CacheTimestamp{0};
+ memory_buf_t m_CachedDatetime;
+ std::string m_LogId;
+};
+
+class json_formatter final : public spdlog::formatter
+{
+public:
+ json_formatter(std::string_view LogId) : m_LogId(LogId) {}
+
+ virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<json_formatter>(m_LogId); }
+
+ virtual void format(const details::log_msg& msg, memory_buf_t& dest) override
+ {
+ using std::chrono::duration_cast;
+ using std::chrono::milliseconds;
+ using std::chrono::seconds;
+
+ auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch());
+ if (secs != m_LastLogSecs)
+ {
+ m_CachedTm = os::localtime(log_clock::to_time_t(msg.time));
+ m_LastLogSecs = secs;
+ }
+
+ const auto& tm_time = m_CachedTm;
+
+ // cache the date/time part for the next second.
+
+ if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0)
+ {
+ m_CachedDatetime.clear();
+
+ fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
+ m_CachedDatetime.push_back(' ');
+
+ fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
+
+ m_CachedDatetime.push_back('.');
+
+ m_CacheTimestamp = secs;
+ }
+ dest.append("{"sv);
+ dest.append("\"time\": \""sv);
+ dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
+ auto millis = fmt_helper::time_fraction<milliseconds>(msg.time);
+ fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
+ dest.append("\", "sv);
+
+ dest.append("\"status\": \""sv);
+ dest.append(level::to_string_view(msg.level));
+ dest.append("\", "sv);
+
+ dest.append("\"source\": \""sv);
+ dest.append("zenserver"sv);
+ dest.append("\", "sv);
+
+ dest.append("\"service\": \""sv);
+ dest.append("zencache"sv);
+ dest.append("\", "sv);
+
+ if (!m_LogId.empty())
+ {
+ dest.append("\"id\": \""sv);
+ dest.append(m_LogId);
+ dest.append("\", "sv);
+ }
+
+ if (msg.logger_name.size() > 0)
+ {
+ dest.append("\"logger.name\": \""sv);
+ dest.append(msg.logger_name);
+ dest.append("\", "sv);
+ }
+
+ if (msg.thread_id != 0)
+ {
+ dest.append("\"logger.thread_name\": \""sv);
+ fmt_helper::pad_uint(msg.thread_id, 0, dest);
+ dest.append("\", "sv);
+ }
+
+ if (!msg.source.empty())
+ {
+ dest.append("\"file\": \""sv);
+ WriteEscapedString(dest, details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename));
+ dest.append("\","sv);
+
+ dest.append("\"line\": \""sv);
+ dest.append(fmt::format("{}", msg.source.line));
+ dest.append("\","sv);
+
+ dest.append("\"logger.method_name\": \""sv);
+ WriteEscapedString(dest, msg.source.funcname);
+ dest.append("\", "sv);
+ }
+
+ dest.append("\"message\": \""sv);
+ WriteEscapedString(dest, msg.payload);
+ dest.append("\""sv);
+
+ dest.append("}\n"sv);
+ }
+
+private:
+ static inline const std::unordered_map<char, std::string_view> SpecialCharacterMap{{'\b', "\\b"sv},
+ {'\f', "\\f"sv},
+ {'\n', "\\n"sv},
+ {'\r', "\\r"sv},
+ {'\t', "\\t"sv},
+ {'"', "\\\""sv},
+ {'\\', "\\\\"sv}};
+
+ static void WriteEscapedString(memory_buf_t& dest, const spdlog::string_view_t& payload)
+ {
+ const char* RangeStart = payload.begin();
+ for (const char* It = RangeStart; It != payload.end(); ++It)
+ {
+ if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end())
+ {
+ if (RangeStart != It)
+ {
+ dest.append(RangeStart, It);
+ }
+ dest.append(SpecialIt->second);
+ RangeStart = It + 1;
+ }
+ }
+ if (RangeStart != payload.end())
+ {
+ dest.append(RangeStart, payload.end());
+ }
+ };
+
+ std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0};
+ std::chrono::seconds m_LastLogSecs{0};
+ std::chrono::seconds m_CacheTimestamp{0};
+ memory_buf_t m_CachedDatetime;
+ std::string m_LogId;
+};
+
+bool
+EnableVTMode()
+{
+#if ZEN_PLATFORM_WINDOWS
+ // Set output mode to handle virtual terminal sequences
+ HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
+ if (hOut == INVALID_HANDLE_VALUE)
+ {
+ return false;
+ }
+
+ DWORD dwMode = 0;
+ if (!GetConsoleMode(hOut, &dwMode))
+ {
+ return false;
+ }
+
+ dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+ if (!SetConsoleMode(hOut, dwMode))
+ {
+ return false;
+ }
+#endif
+
+ return true;
+}
+
+} // namespace logging
+
+void
+InitializeLogging(const ZenServerOptions& GlobalOptions)
+{
+ zen::logging::InitializeLogging();
+ logging::EnableVTMode();
+
+ bool IsAsync = true;
+ spdlog::level::level_enum LogLevel = spdlog::level::info;
+
+ if (GlobalOptions.IsDebug)
+ {
+ LogLevel = spdlog::level::debug;
+ IsAsync = false;
+ }
+
+ if (GlobalOptions.IsTest)
+ {
+ LogLevel = spdlog::level::trace;
+ IsAsync = false;
+ }
+
+ if (IsAsync)
+ {
+ const int QueueSize = 8192;
+ const int ThreadCount = 1;
+ spdlog::init_thread_pool(QueueSize, ThreadCount);
+
+ auto AsyncLogger = spdlog::create_async<spdlog::sinks::ansicolor_stdout_sink_mt>("main");
+ zen::logging::SetDefault(AsyncLogger);
+ }
+
+ // Sinks
+
+ auto ConsoleSink = std::make_shared<spdlog::sinks::ansicolor_stdout_sink_mt>();
+
+    // spdlog can't create directories that start with `\\?\` so we make sure the folder exists before creating the logger instance
+ zen::CreateDirectories(GlobalOptions.AbsLogFile.parent_path());
+
+#if 0
+ auto FileSink = std::make_shared<spdlog::sinks::daily_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile),
+ 0,
+ 0,
+ /* truncate */ false,
+ uint16_t(/* max files */ 14));
+#else
+ auto FileSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile),
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ true);
+#endif
+
+ std::set_terminate([]() { ZEN_CRITICAL("Program exited abnormally via std::terminate()"); });
+
+ // Default
+
+ auto& DefaultLogger = zen::logging::Default();
+ auto& Sinks = DefaultLogger.sinks();
+
+ Sinks.clear();
+ Sinks.push_back(ConsoleSink);
+ Sinks.push_back(FileSink);
+
+#if ZEN_PLATFORM_WINDOWS
+ if (zen::IsDebuggerPresent() && GlobalOptions.IsDebug)
+ {
+ auto DebugSink = std::make_shared<spdlog::sinks::msvc_sink_mt>();
+ DebugSink->set_level(spdlog::level::debug);
+ Sinks.push_back(DebugSink);
+ }
+#endif
+
+ // HTTP server request logging
+
+ std::filesystem::path HttpLogPath = GlobalOptions.DataDir / "logs" / "http.log";
+
+    // spdlog can't create directories that start with `\\?\` so we make sure the folder exists before creating the logger instance
+ zen::CreateDirectories(HttpLogPath.parent_path());
+
+ auto HttpSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(HttpLogPath),
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ true);
+
+ auto HttpLogger = std::make_shared<spdlog::logger>("http_requests", HttpSink);
+ spdlog::register_logger(HttpLogger);
+
+ // Jupiter - only log upstream HTTP traffic to file
+
+ auto JupiterLogger = std::make_shared<spdlog::logger>("jupiter", FileSink);
+ spdlog::register_logger(JupiterLogger);
+
+ // Zen - only log upstream HTTP traffic to file
+
+ auto ZenClientLogger = std::make_shared<spdlog::logger>("zenclient", FileSink);
+ spdlog::register_logger(ZenClientLogger);
+
+ // Configure all registered loggers according to settings
+
+ spdlog::set_level(LogLevel);
+ spdlog::flush_on(spdlog::level::err);
+ spdlog::flush_every(std::chrono::seconds{2});
+ spdlog::set_formatter(std::make_unique<logging::full_formatter>(GlobalOptions.LogId, std::chrono::system_clock::now()));
+
+ if (GlobalOptions.AbsLogFile.extension() == ".json")
+ {
+ FileSink->set_formatter(std::make_unique<logging::json_formatter>(GlobalOptions.LogId));
+ }
+ else
+ {
+ FileSink->set_pattern("[%C-%m-%d.%e %T] [%n] [%l] %v");
+ }
+ DefaultLogger.info("log starting at {}", zen::DateTime::Now().ToIso8601());
+}
+
+void
+ShutdownLogging()
+{
+ auto& DefaultLogger = zen::logging::Default();
+ DefaultLogger.info("log ending at {}", zen::DateTime::Now().ToIso8601());
+ zen::logging::ShutdownLogging();
+}
diff --git a/src/zenserver/diag/logging.h b/src/zenserver/diag/logging.h
new file mode 100644
index 000000000..8df49f842
--- /dev/null
+++ b/src/zenserver/diag/logging.h
@@ -0,0 +1,10 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+struct ZenServerOptions;
+
+void InitializeLogging(const ZenServerOptions& GlobalOptions);
+
+void ShutdownLogging();
diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp
new file mode 100644
index 000000000..149d97924
--- /dev/null
+++ b/src/zenserver/frontend/frontend.cpp
@@ -0,0 +1,128 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "frontend.h"
+
+#include <zencore/endian.h>
+#include <zencore/filesystem.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#if ZEN_PLATFORM_WINDOWS
+# include <Windows.h>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+////////////////////////////////////////////////////////////////////////////////
+HttpFrontendService::HttpFrontendService(std::filesystem::path Directory) : m_Directory(Directory)
+{
+ std::filesystem::path SelfPath = GetRunningExecutablePath();
+
+ // Locate a .zip file appended onto the end of this binary
+ IoBuffer SelfBuffer = IoBufferBuilder::MakeFromFile(SelfPath);
+ m_ZipFs = ZipFs(std::move(SelfBuffer));
+
+#if ZEN_BUILD_DEBUG
+ if (!Directory.empty())
+ {
+ return;
+ }
+
+ std::error_code ErrorCode;
+ auto Path = SelfPath;
+ while (Path.has_parent_path())
+ {
+ auto ParentPath = Path.parent_path();
+ if (ParentPath == Path)
+ {
+ break;
+ }
+ if (std::filesystem::is_regular_file(ParentPath / "xmake.lua", ErrorCode))
+ {
+ if (ErrorCode)
+ {
+ break;
+ }
+
+ auto HtmlDir = ParentPath / "zenserver" / "frontend" / "html";
+ if (std::filesystem::is_directory(HtmlDir, ErrorCode))
+ {
+ m_Directory = HtmlDir;
+ }
+ break;
+ }
+ Path = ParentPath;
+ };
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
+HttpFrontendService::~HttpFrontendService()
+{
+}
+
+////////////////////////////////////////////////////////////////////////////////
+const char*
+HttpFrontendService::BaseUri() const
+{
+ return "/dashboard"; // in order to use the root path we need to remove HttpAddUrlToUrlGroup in HttpSys.cpp
+}
+
+////////////////////////////////////////////////////////////////////////////////
+void
+HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ using namespace std::literals;
+
+ std::string_view Uri = Request.RelativeUriWithExtension();
+ for (; Uri[0] == '/'; Uri = Uri.substr(1))
+ ;
+ if (Uri.empty())
+ {
+ Uri = "index.html"sv;
+ }
+
+ // Dismiss if the URI contains .. anywhere to prevent arbitrary file reads
+ if (Uri.find("..") != Uri.npos)
+ {
+ return Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ // Map the file extension to a MIME type. To keep things constrained, only a
+ // small subset of file extensions is allowed
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+
+ if (const size_t DotIndex = Uri.rfind("."); DotIndex != Uri.npos)
+ {
+ const std::string_view DotExt = Uri.substr(DotIndex + 1);
+
+ ContentType = ParseContentType(DotExt);
+ }
+
+ if (ContentType == HttpContentType::kUnknownContentType)
+ {
+ return Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ // The given content directory overrides any zip-fs discovered in the binary
+ if (!m_Directory.empty())
+ {
+ FileContents File = ReadFile(m_Directory / Uri);
+
+ if (!File.ErrorCode)
+ {
+ return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]);
+ }
+ }
+
+ if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer);
+ }
+
+ Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
+}
+
+} // namespace zen
diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h
new file mode 100644
index 000000000..6eac20620
--- /dev/null
+++ b/src/zenserver/frontend/frontend.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include "zipfs.h"
+
+#include <filesystem>
+
+namespace zen {
+
+class HttpFrontendService final : public zen::HttpService
+{
+public:
+ HttpFrontendService(std::filesystem::path Directory);
+ virtual ~HttpFrontendService();
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ ZipFs m_ZipFs;
+ std::filesystem::path m_Directory;
+};
+
+} // namespace zen
diff --git a/src/zenserver/frontend/html/index.html b/src/zenserver/frontend/html/index.html
new file mode 100644
index 000000000..252ee621e
--- /dev/null
+++ b/src/zenserver/frontend/html/index.html
@@ -0,0 +1,59 @@
+<!DOCTYPE html>
+<html>
+<head>
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-F3w7mX95PdgyTmZZMECAngseQB83DfGTowi0iMjiWaeVhAn4FJkqJByhZMI3AhiU" crossorigin="anonymous">
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js" integrity="sha384-skAcpIdS7UcVUC05LJ9Dxay8AXcDYfBJqt1CJ85S/CFujBsIzCIv+l9liuYLaMQ/" crossorigin="anonymous"></script>
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/font/bootstrap-icons.css">
+ <style type="text/css">
+ body {
+ background-color: #fafafa;
+ }
+ </style>
+ <script type="text/javascript">
+ const getCacheStats = () => {
+ const opts = { headers: { "Accept": "application/json" } };
+ fetch("/stats/z$", opts)
+ .then(response => {
+ if (!response.ok) {
+ throw Error(response.statusText);
+ }
+ return response.json();
+ })
+ .then(json => {
+ document.getElementById("status").innerHTML = "connected"
+ document.getElementById("stats").innerHTML = JSON.stringify(json, null, 4);
+ })
+ .catch(error => {
+ document.getElementById("status").innerHTML = "disconnected"
+ document.getElementById("stats").innerHTML = ""
+ console.log(error);
+ })
+ .finally(() => {
+ window.setTimeout(getCacheStats, 1000);
+ });
+ };
+ getCacheStats();
+ </script>
+</head>
+<body>
+ <div class="container">
+ <div class="row">
+ <div class="text-center mt-5">
+ <pre>
+__________ _________ __
+\____ / ____ ____ / _____/_/ |_ ____ _______ ____
+ / / _/ __ \ / \ \_____ \ \ __\ / _ \ \_ __ \_/ __ \
+ / /_ \ ___/ | | \ / \ | | ( <_> ) | | \/\ ___/
+/_______ \ \___ >|___| //_______ / |__| \____/ |__| \___ >
+ \/ \/ \/ \/ \/
+ </pre>
+ <pre id="status"/>
+ </div>
+ </div>
+ <div class="row">
+ <pre class="mb-0">Z$:</pre>
+ <pre id="stats"></pre>
+ <div>
+ </div>
+</body>
+</html>
diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp
new file mode 100644
index 000000000..f9c2bc8ff
--- /dev/null
+++ b/src/zenserver/frontend/zipfs.cpp
@@ -0,0 +1,169 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zipfs.h"
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+namespace {
+
+#if ZEN_COMPILER_MSC
+# pragma warning(push)
+# pragma warning(disable : 4200)
+#endif
+
+ using ZipInt16 = uint16_t;
+
+ struct ZipInt32
+ {
+ operator uint32_t() const { return *(uint32_t*)Parts; }
+ uint16_t Parts[2];
+ };
+
+ struct EocdRecord
+ {
+ enum : uint32_t
+ {
+ Magic = 0x0605'4b50,
+ };
+ ZipInt32 Signature;
+ ZipInt16 ThisDiskIndex;
+ ZipInt16 CdStartDiskIndex;
+ ZipInt16 CdRecordThisDiskCount;
+ ZipInt16 CdRecordCount;
+ ZipInt32 CdSize;
+ ZipInt32 CdOffset;
+ ZipInt16 CommentSize;
+ char Comment[];
+ };
+
+ struct CentralDirectoryRecord
+ {
+ enum : uint32_t
+ {
+ Magic = 0x0201'4b50,
+ };
+
+ ZipInt32 Signature;
+ ZipInt16 VersionMadeBy;
+ ZipInt16 VersionRequired;
+ ZipInt16 Flags;
+ ZipInt16 CompressionMethod;
+ ZipInt16 LastModTime;
+ ZipInt16 LastModDate;
+ ZipInt32 Crc32;
+ ZipInt32 CompressedSize;
+ ZipInt32 OriginalSize;
+ ZipInt16 FileNameLength;
+ ZipInt16 ExtraFieldLength;
+ ZipInt16 CommentLength;
+ ZipInt16 DiskIndex;
+ ZipInt16 InternalFileAttr;
+ ZipInt32 ExternalFileAttr;
+ ZipInt32 Offset;
+ char FileName[];
+ };
+
+ struct LocalFileHeader
+ {
+ enum : uint32_t
+ {
+ Magic = 0x0403'4b50,
+ };
+
+ ZipInt32 Signature;
+ ZipInt16 VersionRequired;
+ ZipInt16 Flags;
+ ZipInt16 CompressionMethod;
+ ZipInt16 LastModTime;
+ ZipInt16 LastModDate;
+ ZipInt32 Crc32;
+ ZipInt32 CompressedSize;
+ ZipInt32 OriginalSize;
+ ZipInt16 FileNameLength;
+ ZipInt16 ExtraFieldLength;
+ char FileName[];
+ };
+
+#if ZEN_COMPILER_MSC
+# pragma warning(pop)
+#endif
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+ZipFs::ZipFs(IoBuffer&& Buffer)
+{
+ MemoryView View = Buffer.GetView();
+
+ uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize();
+ if (View.GetSize() < sizeof(EocdRecord))
+ {
+ return;
+ }
+
+ const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord));
+
+ // It is more correct to search backwards for EocdRecord::Magic as the
+ // comment can be of a variable length. But here we're not going to support
+ // zip files with comments.
+ if (EocdCursor->Signature != EocdRecord::Magic)
+ {
+ return;
+ }
+
+ // Zip64 isn't supported either
+ if (EocdCursor->ThisDiskIndex == 0xffff)
+ {
+ return;
+ }
+
+ Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize);
+
+ const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset);
+ for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i)
+ {
+ const CentralDirectoryRecord& Cd = *CdCursor;
+
+ bool Acceptable = true;
+ Acceptable &= (Cd.OriginalSize > 0); // has some content
+        Acceptable &= (Cd.CompressionMethod == 0); // is stored uncompressed
+ if (Acceptable)
+ {
+ const uint8_t* Lfh = Cursor + Cd.Offset;
+ if (uintptr_t(Lfh - Cursor) < View.GetSize())
+ {
+ std::string_view FileName(Cd.FileName, Cd.FileNameLength);
+ m_Files.insert(std::make_pair(FileName, FileItem{Lfh, size_t(0)}));
+ }
+ }
+
+ uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength;
+ CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes);
+ }
+
+ m_Buffer = std::move(Buffer);
+}
+
+//////////////////////////////////////////////////////////////////////////
+IoBuffer
+ZipFs::GetFile(const std::string_view& FileName) const
+{
+ FileMap::iterator Iter = m_Files.find(FileName);
+ if (Iter == m_Files.end())
+ {
+ return {};
+ }
+
+ FileItem& Item = Iter->second;
+ if (Item.GetSize() > 0)
+ {
+ return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize());
+ }
+
+ const auto* Lfh = (LocalFileHeader*)(Item.GetData());
+ Item = MemoryView(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength, Lfh->OriginalSize);
+ return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize());
+}
+
+} // namespace zen
diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h
new file mode 100644
index 000000000..e1fa4457c
--- /dev/null
+++ b/src/zenserver/frontend/zipfs.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+
+#include <unordered_map>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+// Read-only view over a zip archive held in memory. The constructor
+// indexes the archive's central directory; GetFile() then serves entries
+// as zero-copy wrappers over the archive buffer.
+class ZipFs
+{
+public:
+    ZipFs() = default;
+    // NOTE(review): intentionally implicit? A single-argument constructor
+    // is usually marked explicit.
+    ZipFs(IoBuffer&& Buffer);
+    IoBuffer GetFile(const std::string_view& FileName) const;
+
+private:
+    using FileItem = MemoryView;
+    // Keys are string_views into m_Buffer, so the map must never outlive
+    // the buffer. Mutable because GetFile() lazily resolves entries.
+    using FileMap = std::unordered_map<std::string_view, FileItem>;
+    FileMap mutable m_Files;
+    IoBuffer m_Buffer;
+};
+
+} // namespace zen
diff --git a/src/zenserver/monitoring/httpstats.cpp b/src/zenserver/monitoring/httpstats.cpp
new file mode 100644
index 000000000..4d985f8c2
--- /dev/null
+++ b/src/zenserver/monitoring/httpstats.cpp
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpstats.h"
+
+namespace zen {
+
+// Resolves the shared "stats" logger once at construction.
+HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats"))
+{
+}
+
+// Nothing to tear down explicitly; providers are owned by their
+// registrants, the map only holds non-owning pointers.
+HttpStatsService::~HttpStatsService()
+{
+}
+
+// All stats endpoints are served under this URI prefix.
+const char*
+HttpStatsService::BaseUri() const
+{
+    return "/stats/";
+}
+
+// Publishes (or replaces) the stats provider registered under Id.
+// The service stores only a non-owning pointer; the caller retains
+// ownership and must unregister before destroying the provider.
+void
+HttpStatsService::RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider)
+{
+    RwLock::ExclusiveLockScope _(m_Lock);
+    std::string ProviderKey(Id);
+    m_Providers[ProviderKey] = &Provider;
+}
+
+// Removes the provider registered under Id. The provider reference only
+// documents caller intent; lookup is purely by id.
+void
+HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider)
+{
+    ZEN_UNUSED(Provider);
+
+    RwLock::ExclusiveLockScope _(m_Lock);
+    m_Providers.erase(std::string(Id));
+}
+
+// Routes GET/HEAD requests to the provider whose id matches the relative
+// URI. Other verbs and unknown ids fall through without writing a response.
+// NOTE(review): the provider runs while the shared lock is held, so a
+// provider must not call (Un)RegisterHandler re-entrantly -- that would
+// self-deadlock on m_Lock.
+void
+HttpStatsService::HandleRequest(HttpServerRequest& Request)
+{
+    using namespace std::literals;
+
+    std::string_view Key = Request.RelativeUri();
+
+    switch (Request.RequestVerb())
+    {
+        case HttpVerb::kHead:
+        case HttpVerb::kGet:
+        {
+            RwLock::SharedLockScope _(m_Lock);
+            if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers))
+            {
+                return It->second->HandleStatsRequest(Request);
+            }
+        }
+
+            [[fallthrough]];
+        default:
+            return;
+    }
+}
+
+} // namespace zen
diff --git a/src/zenserver/monitoring/httpstats.h b/src/zenserver/monitoring/httpstats.h
new file mode 100644
index 000000000..732815a9a
--- /dev/null
+++ b/src/zenserver/monitoring/httpstats.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <map>
+
+namespace zen {
+
+// Interface implemented by subsystems that publish statistics under
+// /stats/<id>. HttpStatsService stores only a non-owning pointer;
+// implementations remain owned by their registrant.
+struct IHttpStatsProvider
+{
+    // Virtual destructor: the type is polymorphic, so deleting an
+    // implementation through the interface pointer must be well-defined.
+    virtual ~IHttpStatsProvider() = default;
+
+    virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
+};
+
+// HTTP service exposing registered stat providers under /stats/<id>.
+class HttpStatsService : public HttpService
+{
+public:
+    HttpStatsService();
+    ~HttpStatsService();
+
+    virtual const char* BaseUri() const override;
+    virtual void HandleRequest(HttpServerRequest& Request) override;
+    // Registration stores a non-owning pointer; callers must unregister
+    // before the provider is destroyed.
+    void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider);
+    void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider);
+
+private:
+    spdlog::logger& m_Log;
+    // NOTE(review): m_Router is never used in the visible implementation --
+    // confirm before removing.
+    HttpRequestRouter m_Router;
+
+    inline spdlog::logger& Log() { return m_Log; }
+
+    // Guards m_Providers: shared for dispatch, exclusive for (un)register.
+    RwLock m_Lock;
+    std::map<std::string, IHttpStatsProvider*> m_Providers;
+};
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zenserver/monitoring/httpstatus.cpp b/src/zenserver/monitoring/httpstatus.cpp
new file mode 100644
index 000000000..8b10601dd
--- /dev/null
+++ b/src/zenserver/monitoring/httpstatus.cpp
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpstatus.h"
+
+namespace zen {
+
+// Resolves the shared "status" logger once at construction.
+HttpStatusService::HttpStatusService() : m_Log(logging::Get("status"))
+{
+}
+
+// Nothing to tear down explicitly; providers are owned by their
+// registrants, the map only holds non-owning pointers.
+HttpStatusService::~HttpStatusService()
+{
+}
+
+// All status endpoints are served under this URI prefix.
+const char*
+HttpStatusService::BaseUri() const
+{
+    return "/status/";
+}
+
+// Publishes (or replaces) the status provider registered under Id.
+// The service stores only a non-owning pointer; the caller retains
+// ownership and must unregister before destroying the provider.
+void
+HttpStatusService::RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider)
+{
+    RwLock::ExclusiveLockScope _(m_Lock);
+    std::string ProviderKey(Id);
+    m_Providers[ProviderKey] = &Provider;
+}
+
+// Removes the provider registered under Id. The provider reference only
+// documents caller intent; lookup is purely by id.
+void
+HttpStatusService::UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider)
+{
+    ZEN_UNUSED(Provider);
+
+    RwLock::ExclusiveLockScope _(m_Lock);
+    m_Providers.erase(std::string(Id));
+}
+
+// Routes GET/HEAD requests to the provider whose id matches the relative
+// URI. Other verbs and unknown ids fall through without writing a response.
+// NOTE(review): the provider runs while the shared lock is held, so a
+// provider must not call (Un)RegisterHandler re-entrantly -- that would
+// self-deadlock on m_Lock.
+void
+HttpStatusService::HandleRequest(HttpServerRequest& Request)
+{
+    using namespace std::literals;
+
+    std::string_view Key = Request.RelativeUri();
+
+    switch (Request.RequestVerb())
+    {
+        case HttpVerb::kHead:
+        case HttpVerb::kGet:
+        {
+            RwLock::SharedLockScope _(m_Lock);
+            if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers))
+            {
+                return It->second->HandleStatusRequest(Request);
+            }
+        }
+
+            [[fallthrough]];
+        default:
+            return;
+    }
+}
+
+} // namespace zen
diff --git a/src/zenserver/monitoring/httpstatus.h b/src/zenserver/monitoring/httpstatus.h
new file mode 100644
index 000000000..b04e45324
--- /dev/null
+++ b/src/zenserver/monitoring/httpstatus.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <map>
+
+namespace zen {
+
+// Interface implemented by subsystems that publish status pages under
+// /status/<id>. HttpStatusService stores only a non-owning pointer;
+// implementations remain owned by their registrant.
+struct IHttpStatusProvider
+{
+    // Virtual destructor: the type is polymorphic, so deleting an
+    // implementation through the interface pointer must be well-defined.
+    virtual ~IHttpStatusProvider() = default;
+
+    virtual void HandleStatusRequest(HttpServerRequest& Request) = 0;
+};
+
+// HTTP service exposing registered status providers under /status/<id>.
+class HttpStatusService : public HttpService
+{
+public:
+    HttpStatusService();
+    ~HttpStatusService();
+
+    virtual const char* BaseUri() const override;
+    virtual void HandleRequest(HttpServerRequest& Request) override;
+    // Registration stores a non-owning pointer; callers must unregister
+    // before the provider is destroyed.
+    void RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider);
+    void UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider);
+
+private:
+    spdlog::logger& m_Log;
+    // NOTE(review): m_Router is never used in the visible implementation --
+    // confirm before removing.
+    HttpRequestRouter m_Router;
+
+    // Guards m_Providers: shared for dispatch, exclusive for (un)register.
+    RwLock m_Lock;
+    std::map<std::string, IHttpStatusProvider*> m_Providers;
+
+    inline spdlog::logger& Log() { return m_Log; }
+};
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zenserver/objectstore/objectstore.cpp b/src/zenserver/objectstore/objectstore.cpp
new file mode 100644
index 000000000..e5739418e
--- /dev/null
+++ b/src/zenserver/objectstore/objectstore.cpp
@@ -0,0 +1,232 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <objectstore/objectstore.h>
+
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include "zencore/compactbinarybuilder.h"
+#include "zenhttp/httpcommon.h"
+#include "zenhttp/httpserver.h"
+
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogObj, "obj"sv);
+
+// Captures the configuration and registers all HTTP routes.
+HttpObjectStoreService::HttpObjectStoreService(ObjectStoreConfig Cfg) : m_Cfg(std::move(Cfg))
+{
+    Inititalize();
+}
+
+// No explicit teardown; members clean up via RAII.
+HttpObjectStoreService::~HttpObjectStoreService()
+{
+}
+
+// All object-store endpoints are served under this URI prefix.
+const char*
+HttpObjectStoreService::BaseUri() const
+{
+    return "/obj/";
+}
+
+// Dispatches to the route table registered in Inititalize(); unmatched
+// URIs get a 404.
+void
+HttpObjectStoreService::HandleRequest(zen::HttpServerRequest& Request)
+{
+    if (m_Router.HandleRequest(Request) == false)
+    {
+        ZEN_LOG_WARN(LogObj, "No route found for {0}", Request.RelativeUri());
+        return Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
+    }
+}
+
+// Logs the configured buckets and registers the HTTP routes for the
+// distribution-point listing and for blob GET/PUT.
+// NOTE(review): the method name carries a typo ("Inititalize"); it is
+// declared as such in the header, so renaming must be coordinated with the
+// class declaration and all callers.
+void
+HttpObjectStoreService::Inititalize()
+{
+    // Fixed log-message typo ("Initialzing" -> "Initializing").
+    ZEN_LOG_INFO(LogObj, "Initializing Object Store in '{}'", m_Cfg.RootDirectory);
+    for (const auto& Bucket : m_Cfg.Buckets)
+    {
+        ZEN_LOG_INFO(LogObj, " - bucket '{}' -> '{}'", Bucket.Name, Bucket.Directory);
+    }
+
+    // Lists the URL(s) clients may fetch bucket content from.
+    m_Router.RegisterRoute(
+        "distributionpoints/{bucket}",
+        [this](zen::HttpRouterRequest& Request) {
+            const std::string BucketName = Request.GetCapture(1);
+
+            StringBuilder<1024> Json;
+            {
+                CbObjectWriter Writer;
+                Writer.BeginArray("distributions");
+                Writer << fmt::format("http://localhost:{}/obj/{}", m_Cfg.ServerPort, BucketName);
+                Writer.EndArray();
+                Writer.Save().ToJson(Json);
+            }
+
+            Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, Json.ToString());
+        },
+        HttpVerb::kGet);
+
+    // Blob download.
+    m_Router.RegisterRoute(
+        "{bucket}/{path}",
+        [this](zen::HttpRouterRequest& Request) { GetBlob(Request); },
+        HttpVerb::kGet);
+
+    // Blob upload.
+    m_Router.RegisterRoute(
+        "{bucket}/{path}",
+        [this](zen::HttpRouterRequest& Request) { PutBlob(Request); },
+        HttpVerb::kPost | HttpVerb::kPut);
+}
+
+// Returns the directory configured for the named bucket, or an empty path
+// when no bucket with that name exists. Thread-safe via BucketsMutex.
+std::filesystem::path
+HttpObjectStoreService::GetBucketDirectory(std::string_view BucketName)
+{
+    std::lock_guard _(BucketsMutex);
+
+    for (const auto& Bucket : m_Cfg.Buckets)
+    {
+        if (Bucket.Name == BucketName)
+        {
+            return Bucket.Directory;
+        }
+    }
+
+    return std::filesystem::path();
+}
+
+// Serves GET/HEAD for a blob at /obj/{bucket}/{path}. Supports at most one
+// HTTP range. Replies: 404 unknown bucket/file, 403 for paths that escape
+// the bucket directory, 400 for unreadable files or invalid ranges.
+void
+HttpObjectStoreService::GetBlob(zen::HttpRouterRequest& Request)
+{
+    namespace fs = std::filesystem;
+
+    const std::string& BucketName = Request.GetCapture(1);
+    const fs::path BucketDir = GetBucketDirectory(BucketName);
+
+    if (BucketDir.empty())
+    {
+        ZEN_LOG_DEBUG(LogObj, "GET - [FAILED], unknown bucket '{}'", BucketName);
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
+    }
+
+    const fs::path RelativeBucketPath = Request.GetCapture(2);
+
+    // Reject absolute paths and any path escaping the bucket root. Testing
+    // only a ".." *prefix* is insufficient: "a/../../b" escapes without
+    // starting with "..", so normalize and inspect every component.
+    bool EscapesBucket = RelativeBucketPath.is_absolute();
+    for (const fs::path& Component : RelativeBucketPath.lexically_normal())
+    {
+        EscapesBucket |= (Component == "..");
+    }
+    if (EscapesBucket)
+    {
+        ZEN_LOG_DEBUG(LogObj, "GET - from bucket '{}' [FAILED], invalid file path", BucketName);
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
+    }
+
+    fs::path FilePath = BucketDir / RelativeBucketPath;
+    if (fs::exists(FilePath) == false)
+    {
+        ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' [FAILED], doesn't exist", BucketName, FilePath);
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
+    }
+
+    zen::HttpRanges Ranges;
+    if (Request.ServerRequest().TryGetRanges(Ranges); Ranges.size() > 1)
+    {
+        // Only a single range is supported
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+    }
+
+    FileContents File = ReadFile(FilePath);
+    if (File.ErrorCode)
+    {
+        ZEN_LOG_WARN(LogObj,
+                     "GET - '{}/{}' [FAILED] ('{}': {})",
+                     BucketName,
+                     FilePath,
+                     File.ErrorCode.category().name(),
+                     File.ErrorCode.value());
+
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+    }
+
+    const IoBuffer& FileBuf = File.Data[0];
+
+    if (Ranges.empty())
+    {
+        // Whole-file response.
+        const uint64_t TotalServed = TotalBytesServed.fetch_add(FileBuf.Size()) + FileBuf.Size();
+
+        ZEN_LOG_DEBUG(LogObj,
+                      "GET - '{}/{}' ({}) [OK] (Served: {})",
+                      BucketName,
+                      RelativeBucketPath,
+                      NiceBytes(FileBuf.Size()),
+                      NiceBytes(TotalServed));
+
+        Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, FileBuf);
+    }
+    else
+    {
+        // Single-range response (206). Mid() clamps, so a short view means
+        // the requested range falls outside the file.
+        const auto Range = Ranges[0];
+        const uint64_t RangeSize = Range.End - Range.Start;
+        const uint64_t TotalServed = TotalBytesServed.fetch_add(RangeSize) + RangeSize;
+
+        ZEN_LOG_DEBUG(LogObj,
+                      "GET - '{}/{}' (Range: {}-{}) ({}/{}) [OK] (Served: {})",
+                      BucketName,
+                      RelativeBucketPath,
+                      Range.Start,
+                      Range.End,
+                      NiceBytes(RangeSize),
+                      NiceBytes(FileBuf.Size()),
+                      NiceBytes(TotalServed));
+
+        MemoryView RangeView = FileBuf.GetView().Mid(Range.Start, RangeSize);
+        if (RangeView.GetSize() != RangeSize)
+        {
+            return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+        }
+
+        IoBuffer RangeBuf = IoBuffer(IoBuffer::Wrap, RangeView.GetData(), RangeView.GetSize());
+        Request.ServerRequest().WriteResponse(HttpResponseCode::PartialContent, HttpContentType::kBinary, RangeBuf);
+    }
+}
+
+// Stores an uploaded blob at /obj/{bucket}/{path}. Replies: 404 unknown
+// bucket, 403 for paths escaping the bucket directory, 400 for empty
+// payloads, 200 on success.
+void
+HttpObjectStoreService::PutBlob(zen::HttpRouterRequest& Request)
+{
+    namespace fs = std::filesystem;
+
+    const std::string& BucketName = Request.GetCapture(1);
+    const fs::path BucketDir = GetBucketDirectory(BucketName);
+
+    if (BucketDir.empty())
+    {
+        ZEN_LOG_DEBUG(LogObj, "PUT - [FAILED], unknown bucket '{}'", BucketName);
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
+    }
+
+    const fs::path RelativeBucketPath = Request.GetCapture(2);
+
+    // Reject absolute paths and any path escaping the bucket root. Testing
+    // only a ".." *prefix* is insufficient: "a/../../b" escapes without
+    // starting with "..", so normalize and inspect every component.
+    bool EscapesBucket = RelativeBucketPath.is_absolute();
+    for (const fs::path& Component : RelativeBucketPath.lexically_normal())
+    {
+        EscapesBucket |= (Component == "..");
+    }
+    if (EscapesBucket)
+    {
+        ZEN_LOG_DEBUG(LogObj, "PUT - bucket '{}' [FAILED], invalid file path", BucketName);
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
+    }
+
+    fs::path FilePath = BucketDir / RelativeBucketPath;
+    const IoBuffer FileBuf = Request.ServerRequest().ReadPayload();
+
+    if (FileBuf.Size() == 0)
+    {
+        ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [FAILED], empty file", BucketName, FilePath);
+        return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+    }
+
+    // NOTE(review): WriteFile's result is not checked -- confirm it throws
+    // on failure (and creates parent directories for nested paths),
+    // otherwise errors are silently reported as OK.
+    WriteFile(FilePath, FileBuf);
+    ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [OK] ({})", BucketName, RelativeBucketPath, NiceBytes(FileBuf.Size()));
+    Request.ServerRequest().WriteResponse(HttpResponseCode::OK);
+}
+
+} // namespace zen
diff --git a/src/zenserver/objectstore/objectstore.h b/src/zenserver/objectstore/objectstore.h
new file mode 100644
index 000000000..eaab57794
--- /dev/null
+++ b/src/zenserver/objectstore/objectstore.h
@@ -0,0 +1,48 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <atomic>
+#include <filesystem>
+#include <mutex>
+
+namespace zen {
+
+class HttpRouterRequest;
+
+// Static configuration for the object store service.
+struct ObjectStoreConfig
+{
+    // Maps a named bucket to the directory it is served from.
+    struct BucketConfig
+    {
+        std::string Name;
+        std::filesystem::path Directory;
+    };
+
+    std::filesystem::path RootDirectory;
+    std::vector<BucketConfig> Buckets;
+    // Port advertised in the distribution-point URLs.
+    uint16_t ServerPort{1337};
+};
+
+// HTTP service serving file-backed blob buckets under /obj/.
+class HttpObjectStoreService final : public zen::HttpService
+{
+public:
+    HttpObjectStoreService(ObjectStoreConfig Cfg);
+    virtual ~HttpObjectStoreService();
+
+    virtual const char* BaseUri() const override;
+    virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+    // Registers routes (distribution points, blob GET/PUT).
+    // NOTE(review): name carries a typo ("Inititalize").
+    void Inititalize();
+    // Resolves a bucket name to its directory; empty path when unknown.
+    std::filesystem::path GetBucketDirectory(std::string_view BucketName);
+    void GetBlob(zen::HttpRouterRequest& Request);
+    void PutBlob(zen::HttpRouterRequest& Request);
+
+    ObjectStoreConfig m_Cfg;
+    // Guards m_Cfg.Buckets lookups.
+    std::mutex BucketsMutex;
+    HttpRequestRouter m_Router;
+    // Running total of payload bytes served by GetBlob.
+    std::atomic_uint64_t TotalBytesServed{0};
+};
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/fileremoteprojectstore.cpp b/src/zenserver/projectstore/fileremoteprojectstore.cpp
new file mode 100644
index 000000000..d7a34a6c2
--- /dev/null
+++ b/src/zenserver/projectstore/fileremoteprojectstore.cpp
@@ -0,0 +1,235 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "fileremoteprojectstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/timer.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+// RemoteProjectStore implementation that exports a project container and
+// its attachments to a local directory tree. Attachments are sharded into
+// subdirectories by hash prefix (aaa/bb/<rest-of-hash>).
+//
+// Fix: all elapsed-time conversions previously divided GetElapsedTimeMs()
+// by 1000.500, skewing every reported ElapsedSeconds; milliseconds to
+// seconds is a division by 1000.0.
+class LocalExportProjectStore : public RemoteProjectStore
+{
+public:
+    LocalExportProjectStore(std::string_view Name,
+                            const std::filesystem::path& FolderPath,
+                            bool ForceDisableBlocks,
+                            bool ForceEnableTempBlocks)
+        : m_Name(Name)
+        , m_OutputPath(FolderPath)
+    {
+        if (ForceDisableBlocks)
+        {
+            m_EnableBlocks = false;
+        }
+        if (ForceEnableTempBlocks)
+        {
+            m_UseTempBlocks = true;
+        }
+    }
+
+    virtual RemoteStoreInfo GetInfo() const override
+    {
+        return {.CreateBlocks = m_EnableBlocks,
+                .UseTempBlockFiles = m_UseTempBlocks,
+                .Description = fmt::format("[file] {}"sv, m_OutputPath)};
+    }
+
+    // Writes the container object to <OutputPath>/<Name> and reports which
+    // referenced attachments are still missing on disk via Result.Needs.
+    virtual SaveResult SaveContainer(const IoBuffer& Payload) override
+    {
+        Stopwatch Timer;
+        SaveResult Result;
+
+        {
+            CbObject ContainerObject = LoadCompactBinaryObject(Payload);
+
+            // Collect attachments the caller still has to upload.
+            ContainerObject.IterateAttachments([&](CbFieldView FieldView) {
+                IoHash AttachmentHash = FieldView.AsBinaryAttachment();
+                std::filesystem::path AttachmentPath = GetAttachmentPath(AttachmentHash);
+                if (!std::filesystem::exists(AttachmentPath))
+                {
+                    Result.Needs.insert(AttachmentHash);
+                }
+            });
+        }
+
+        std::filesystem::path ContainerPath = m_OutputPath;
+        ContainerPath.append(m_Name);
+
+        CreateDirectories(m_OutputPath);
+        BasicFile ContainerFile;
+        ContainerFile.Open(ContainerPath, BasicFile::Mode::kTruncate);
+        std::error_code Ec;
+        ContainerFile.WriteAll(Payload, Ec);
+        if (Ec)
+        {
+            Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+            Result.Reason = Ec.message();
+        }
+        Result.RawHash = IoHash::HashBuffer(Payload);
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+        return Result;
+    }
+
+    // Writes a (compressed) attachment to its sharded path; a no-op when
+    // the file already exists.
+    virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override
+    {
+        Stopwatch Timer;
+        SaveAttachmentResult Result;
+        std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
+        if (!std::filesystem::exists(ChunkPath))
+        {
+            try
+            {
+                CreateDirectories(ChunkPath.parent_path());
+
+                BasicFile ChunkFile;
+                ChunkFile.Open(ChunkPath, BasicFile::Mode::kTruncate);
+                size_t Offset = 0;
+                for (const SharedBuffer& Segment : Payload.GetSegments())
+                {
+                    ChunkFile.Write(Segment.GetView(), Offset);
+                    Offset += Segment.GetSize();
+                }
+            }
+            catch (std::exception& Ex)
+            {
+                Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+                Result.Reason = Ex.what();
+            }
+        }
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+        return Result;
+    }
+
+    // Saves each chunk in turn; stops and reports on the first failure.
+    virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
+    {
+        Stopwatch Timer;
+
+        for (const SharedBuffer& Chunk : Chunks)
+        {
+            CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer());
+            SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash());
+            if (ChunkResult.ErrorCode)
+            {
+                ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+                return SaveAttachmentsResult{ChunkResult};
+            }
+        }
+        SaveAttachmentsResult Result;
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+        return Result;
+    }
+
+    // Nothing to finalize for the file-backed store.
+    virtual Result FinalizeContainer(const IoHash&) override { return {}; }
+
+    // Reads <OutputPath>/<Name> back as a compact binary object.
+    virtual LoadContainerResult LoadContainer() override
+    {
+        Stopwatch Timer;
+        LoadContainerResult Result;
+        std::filesystem::path ContainerPath = m_OutputPath;
+        ContainerPath.append(m_Name);
+        if (!std::filesystem::is_regular_file(ContainerPath))
+        {
+            Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
+            Result.Reason = fmt::format("The file {} does not exist"sv, ContainerPath.string());
+            Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+            return Result;
+        }
+        IoBuffer ContainerPayload;
+        {
+            BasicFile ContainerFile;
+            ContainerFile.Open(ContainerPath, BasicFile::Mode::kRead);
+            ContainerPayload = ContainerFile.ReadAll();
+        }
+        Result.ContainerObject = LoadCompactBinaryObject(ContainerPayload);
+        if (!Result.ContainerObject)
+        {
+            Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+            Result.Reason = fmt::format("The file {} is not formatted as a compact binary object"sv, ContainerPath.string());
+            Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+            return Result;
+        }
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+        return Result;
+    }
+
+    // Reads a single attachment's bytes from its sharded path.
+    virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
+    {
+        Stopwatch Timer;
+        LoadAttachmentResult Result;
+        std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
+        if (!std::filesystem::is_regular_file(ChunkPath))
+        {
+            Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
+            Result.Reason = fmt::format("The file {} does not exist"sv, ChunkPath.string());
+            Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+            return Result;
+        }
+        {
+            BasicFile ChunkFile;
+            ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead);
+            Result.Bytes = ChunkFile.ReadAll();
+        }
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+        return Result;
+    }
+
+    // Loads each attachment in turn; stops and reports on first failure.
+    virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
+    {
+        Stopwatch Timer;
+        LoadAttachmentsResult Result;
+        for (const IoHash& Hash : RawHashes)
+        {
+            LoadAttachmentResult ChunkResult = LoadAttachment(Hash);
+            if (ChunkResult.ErrorCode)
+            {
+                ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+                return LoadAttachmentsResult{ChunkResult};
+            }
+            ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000)));
+            Result.Chunks.emplace_back(
+                std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))});
+        }
+        // Report total elapsed time on success too, matching the other
+        // methods (was previously left at its default).
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+        return Result;
+    }
+
+private:
+    // Builds <OutputPath>/aaa/bb/<remaining 35 hex chars> from the hash.
+    std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const
+    {
+        ExtendablePathBuilder<128> ShardedPath;
+        ShardedPath.Append(m_OutputPath.c_str());
+        ExtendableStringBuilder<64> HashString;
+        RawHash.ToHexString(HashString);
+        const char* str = HashString.c_str();
+        ShardedPath.AppendSeparator();
+        ShardedPath.AppendAsciiRange(str, str + 3);
+
+        ShardedPath.AppendSeparator();
+        ShardedPath.AppendAsciiRange(str + 3, str + 5);
+
+        ShardedPath.AppendSeparator();
+        ShardedPath.AppendAsciiRange(str + 5, str + 40);
+
+        return ShardedPath.ToPath();
+    }
+
+    const std::string m_Name;
+    const std::filesystem::path m_OutputPath;
+    bool m_EnableBlocks = true;
+    bool m_UseTempBlocks = false;
+};
+
+// Factory for the file-backed remote project store. The derived store is
+// returned through the RemoteProjectStore interface.
+std::unique_ptr<RemoteProjectStore>
+CreateFileRemoteStore(const FileRemoteStoreOptions& Options)
+{
+    return std::make_unique<LocalExportProjectStore>(Options.Name,
+                                                     std::filesystem::path(Options.FolderPath),
+                                                     Options.ForceDisableBlocks,
+                                                     Options.ForceEnableTempBlocks);
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/fileremoteprojectstore.h b/src/zenserver/projectstore/fileremoteprojectstore.h
new file mode 100644
index 000000000..68d1eb71e
--- /dev/null
+++ b/src/zenserver/projectstore/fileremoteprojectstore.h
@@ -0,0 +1,19 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "remoteprojectstore.h"
+
+namespace zen {
+
+// Options for creating a file-backed remote project store.
+struct FileRemoteStoreOptions : RemoteStoreOptions
+{
+    std::filesystem::path FolderPath;
+    std::string Name;
+    // Explicit defaults: a default-constructed options struct previously
+    // left these flags indeterminate.
+    bool ForceDisableBlocks = false;
+    bool ForceEnableTempBlocks = false;
+};
+
+std::unique_ptr<RemoteProjectStore> CreateFileRemoteStore(const FileRemoteStoreOptions& Options);
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.cpp b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp
new file mode 100644
index 000000000..66cf3c4f8
--- /dev/null
+++ b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp
@@ -0,0 +1,244 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "jupiterremoteprojectstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/fmtutils.h>
+
+#include <auth/authmgr.h>
+#include <upstream/jupiter.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+// RemoteProjectStore implementation backed by a Jupiter cloud cache
+// service. The container ref lives at <Namespace>/<Bucket>/<Key>;
+// attachments are stored as compressed blobs keyed by raw hash. Each
+// operation is retried up to three times on failure.
+class JupiterRemoteStore : public RemoteProjectStore
+{
+public:
+    JupiterRemoteStore(Ref<CloudCacheClient>&& CloudClient,
+                       std::string_view Namespace,
+                       std::string_view Bucket,
+                       const IoHash& Key,
+                       bool ForceDisableBlocks,
+                       bool ForceDisableTempBlocks)
+        // The client handle is passed by rvalue reference -- move it into
+        // the member instead of copying the handle.
+        : m_CloudClient(std::move(CloudClient))
+        , m_Namespace(Namespace)
+        , m_Bucket(Bucket)
+        , m_Key(Key)
+    {
+        if (ForceDisableBlocks)
+        {
+            m_EnableBlocks = false;
+        }
+        if (ForceDisableTempBlocks)
+        {
+            m_UseTempBlocks = false;
+        }
+    }
+
+    virtual RemoteStoreInfo GetInfo() const override
+    {
+        return {.CreateBlocks = m_EnableBlocks,
+                .UseTempBlockFiles = m_UseTempBlocks,
+                .Description = fmt::format("[cloud] {} as {}/{}/{}"sv, m_CloudClient->ServiceUrl(), m_Namespace, m_Bucket, m_Key)};
+    }
+
+    // Uploads the container ref; the result carries the attachment hashes
+    // the service still needs.
+    virtual SaveResult SaveContainer(const IoBuffer& Payload) override
+    {
+        const int32_t MaxAttempts = 3;
+        PutRefResult Result;
+        {
+            CloudCacheSession Session(m_CloudClient.Get());
+            for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+            {
+                Result = Session.PutRef(m_Namespace, m_Bucket, m_Key, Payload, ZenContentType::kCbObject);
+            }
+        }
+
+        return SaveResult{ConvertResult(Result), {Result.Needs.begin(), Result.Needs.end()} /*, {}*/, IoHash::HashBuffer(Payload)};
+    }
+
+    // Uploads one compressed attachment blob keyed by its raw hash.
+    virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override
+    {
+        const int32_t MaxAttempts = 3;
+        CloudCacheResult Result;
+        {
+            CloudCacheSession Session(m_CloudClient.Get());
+            for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+            {
+                Result = Session.PutCompressedBlob(m_Namespace, RawHash, Payload);
+            }
+        }
+
+        return SaveAttachmentResult{ConvertResult(Result)};
+    }
+
+    // Saves each chunk in turn; stops and reports on the first failure.
+    virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
+    {
+        SaveAttachmentsResult Result;
+        for (const SharedBuffer& Chunk : Chunks)
+        {
+            CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer());
+            SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash());
+            if (ChunkResult.ErrorCode)
+            {
+                return SaveAttachmentsResult{ChunkResult};
+            }
+        }
+        return Result;
+    }
+
+    // Finalizes the uploaded ref once all attachments are in place.
+    virtual Result FinalizeContainer(const IoHash& RawHash) override
+    {
+        const int32_t MaxAttempts = 3;
+        CloudCacheResult Result;
+        {
+            CloudCacheSession Session(m_CloudClient.Get());
+            for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+            {
+                Result = Session.FinalizeRef(m_Namespace, m_Bucket, m_Key, RawHash);
+            }
+        }
+        return ConvertResult(Result);
+    }
+
+    // Fetches the container ref and validates it parses as compact binary.
+    virtual LoadContainerResult LoadContainer() override
+    {
+        const int32_t MaxAttempts = 3;
+        CloudCacheResult Result;
+        {
+            CloudCacheSession Session(m_CloudClient.Get());
+            for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+            {
+                Result = Session.GetRef(m_Namespace, m_Bucket, m_Key, ZenContentType::kCbObject);
+            }
+        }
+
+        if (Result.ErrorCode || !Result.Success)
+        {
+            return LoadContainerResult{ConvertResult(Result)};
+        }
+
+        CbObject ContainerObject = LoadCompactBinaryObject(Result.Response);
+        if (!ContainerObject)
+        {
+            return LoadContainerResult{
+                RemoteProjectStore::Result{
+                    .ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+                    .ElapsedSeconds = Result.ElapsedSeconds,
+                    .Reason = fmt::format("The ref {}/{}/{} is not formatted as a compact binary object"sv, m_Namespace, m_Bucket, m_Key)},
+                std::move(ContainerObject)};
+        }
+
+        return LoadContainerResult{ConvertResult(Result), std::move(ContainerObject)};
+    }
+
+    // Fetches a single compressed attachment blob by raw hash.
+    virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
+    {
+        const int32_t MaxAttempts = 3;
+        CloudCacheResult Result;
+        {
+            CloudCacheSession Session(m_CloudClient.Get());
+            for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+            {
+                Result = Session.GetCompressedBlob(m_Namespace, RawHash);
+            }
+        }
+        return LoadAttachmentResult{ConvertResult(Result), std::move(Result.Response)};
+    }
+
+    // Loads each attachment in turn; stops and reports on first failure.
+    virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
+    {
+        LoadAttachmentsResult Result;
+        for (const IoHash& Hash : RawHashes)
+        {
+            LoadAttachmentResult ChunkResult = LoadAttachment(Hash);
+            if (ChunkResult.ErrorCode)
+            {
+                return LoadAttachmentsResult{ChunkResult};
+            }
+            ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000)));
+            Result.Chunks.emplace_back(
+                std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))});
+        }
+        return Result;
+    }
+
+private:
+    // Maps a CloudCacheResult onto the store's generic Result, extracting
+    // a textual reason from text-typed error responses.
+    static Result ConvertResult(const CloudCacheResult& Response)
+    {
+        std::string Text;
+        int32_t ErrorCode = 0;
+        if (Response.ErrorCode != 0)
+        {
+            ErrorCode = Response.ErrorCode;
+        }
+        else if (!Response.Success)
+        {
+            ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+            if (Response.Response.GetContentType() == ZenContentType::kText)
+            {
+                Text =
+                    std::string(reinterpret_cast<const std::string::value_type*>(Response.Response.GetData()), Response.Response.GetSize());
+            }
+        }
+        return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.ElapsedSeconds, .Reason = Response.Reason, .Text = Text};
+    }
+
+    Ref<CloudCacheClient> m_CloudClient;
+    const std::string m_Namespace;
+    const std::string m_Bucket;
+    const IoHash m_Key;
+    bool m_EnableBlocks = true;
+    bool m_UseTempBlocks = true;
+};
+
+// Factory for the Jupiter-backed remote project store. Builds the cloud
+// client (defaulting bare hosts to https) and selects a token provider:
+// an explicit access token when supplied, otherwise the OpenID provider
+// resolved through the auth manager.
+std::unique_ptr<RemoteProjectStore>
+CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options)
+{
+    std::string Url = Options.Url;
+    if (Url.find("://"sv) == std::string::npos)
+    {
+        // Assume https URL
+        Url = fmt::format("https://{}"sv, Url);
+    }
+    CloudCacheClientOptions ClientOptions{.Name = "Remote store"sv,
+                                          .ServiceUrl = Url,
+                                          .ConnectTimeout = std::chrono::milliseconds(2000),
+                                          .Timeout = std::chrono::milliseconds(60000)};
+    // 1) Access token as parameter in request
+    // 2) Environment variable (different win vs linux/mac)
+    // 3) openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider
+
+    std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+    if (!Options.AccessToken.empty())
+    {
+        // Static token: never expires from the client's point of view.
+        TokenProvider = CloudCacheTokenProvider::CreateFromCallback([AccessToken = Options.AccessToken]() {
+            return CloudCacheAccessToken{.Value = AccessToken, .ExpireTime = GcClock::TimePoint::max()};
+        });
+    }
+    else
+    {
+        // NOTE(review): the callback captures the AuthMgr by reference and
+        // outlives this function inside the token provider -- assumes the
+        // auth manager outlives the store; confirm with callers.
+        TokenProvider =
+            CloudCacheTokenProvider::CreateFromCallback([&AuthManager = Options.AuthManager, OpenIdProvider = Options.OpenIdProvider]() {
+                AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider.empty() ? "Default" : OpenIdProvider);
+                return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+            });
+    }
+
+    Ref<CloudCacheClient> CloudClient(new CloudCacheClient(ClientOptions, std::move(TokenProvider)));
+
+    std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<JupiterRemoteStore>(std::move(CloudClient),
+                                                                                           Options.Namespace,
+                                                                                           Options.Bucket,
+                                                                                           Options.Key,
+                                                                                           Options.ForceDisableBlocks,
+                                                                                           Options.ForceDisableTempBlocks);
+    return RemoteStore;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.h b/src/zenserver/projectstore/jupiterremoteprojectstore.h
new file mode 100644
index 000000000..31548af22
--- /dev/null
+++ b/src/zenserver/projectstore/jupiterremoteprojectstore.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "remoteprojectstore.h"
+
+namespace zen {
+
+class AuthMgr;
+
+// Options for creating a Jupiter-backed remote project store.
+struct JupiterRemoteStoreOptions : RemoteStoreOptions
+{
+    std::string Url;
+    std::string Namespace;
+    std::string Bucket;
+    IoHash Key;
+    std::string OpenIdProvider;
+    // When empty, tokens are fetched via AuthManager/OpenIdProvider.
+    std::string AccessToken;
+    // Non-owning; must outlive the created store's token provider.
+    AuthMgr& AuthManager;
+    // Explicit defaults so partially-initialized options cannot leave the
+    // flags indeterminate.
+    bool ForceDisableBlocks = false;
+    bool ForceDisableTempBlocks = false;
+};
+
+std::unique_ptr<RemoteProjectStore> CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options);
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/projectstore.cpp b/src/zenserver/projectstore/projectstore.cpp
new file mode 100644
index 000000000..847a79a1d
--- /dev/null
+++ b/src/zenserver/projectstore/projectstore.cpp
@@ -0,0 +1,4082 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "projectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenhttp/httpshared.h>
+#include <zenstore/caslog.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/scrubcontext.h>
+#include <zenutil/cache/rpcrecording.h>
+
+#include "fileremoteprojectstore.h"
+#include "jupiterremoteprojectstore.h"
+#include "remoteprojectstore.h"
+#include "zenremoteprojectstore.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <xxh3.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+
+namespace {
+    // Prepare `Dir` for asynchronous deletion by renaming it aside to a
+    // "[dropped]<name>(<n>)" sibling. Returns true with OutDeleteDir set to the
+    // renamed path (or true with OutDeleteDir untouched if Dir never existed);
+    // returns false when the rename cannot succeed (directory busy/locked).
+    bool PrepareDirectoryDelete(const std::filesystem::path& Dir, std::filesystem::path& OutDeleteDir)
+    {
+        int DropIndex = 0;
+        do
+        {
+            if (!std::filesystem::exists(Dir))
+            {
+                return true;
+            }
+
+            std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex);
+            std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName;
+            if (std::filesystem::exists(DroppedBucketPath))
+            {
+                // Target name already taken (e.g. a previous drop) — try the next index.
+                DropIndex++;
+                continue;
+            }
+
+            std::error_code Ec;
+            std::filesystem::rename(Dir, DroppedBucketPath, Ec);
+            if (!Ec)
+            {
+                OutDeleteDir = DroppedBucketPath;
+                return true;
+            }
+            // Rename failed. (The previous `Ec &&` guard here was redundant —
+            // Ec is always truthy on this path.) If nothing appeared at the
+            // target, the source is likely busy; bail instead of spinning.
+            if (!std::filesystem::exists(DroppedBucketPath))
+            {
+                // We can't move our folder, probably because it is busy, bail..
+                return false;
+            }
+            // Someone raced us into the target path; back off briefly and retry.
+            Sleep(100);
+        } while (true);
+    }
+
+    // Construct a RemoteProjectStore from a compact-binary parameter object.
+    // Exactly one of the "file", "cloud" (Jupiter) or "zen" sub-objects is
+    // expected; if several are present the last one checked wins (zen > cloud
+    // > file, since each assignment overwrites RemoteStore).
+    // Returns {store, ""} on success or {nullptr, error-message} on failure.
+    std::pair<std::unique_ptr<RemoteProjectStore>, std::string> CreateRemoteStore(CbObjectView Params,
+                                                                                  AuthMgr& AuthManager,
+                                                                                  size_t MaxBlockSize,
+                                                                                  size_t MaxChunkEmbedSize)
+    {
+        using namespace std::literals;
+
+        std::unique_ptr<RemoteProjectStore> RemoteStore;
+
+        if (CbObjectView File = Params["file"sv].AsObjectView(); File)
+        {
+            std::filesystem::path FolderPath(File["path"sv].AsString());
+            if (FolderPath.empty())
+            {
+                return {nullptr, "Missing file path"};
+            }
+            std::string_view Name(File["name"sv].AsString());
+            if (Name.empty())
+            {
+                return {nullptr, "Missing file name"};
+            }
+            bool ForceDisableBlocks = File["disableblocks"sv].AsBool(false);
+            bool ForceEnableTempBlocks = File["enabletempblocks"sv].AsBool(false);
+
+            FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+                                              FolderPath,
+                                              std::string(Name),
+                                              ForceDisableBlocks,
+                                              ForceEnableTempBlocks};
+            RemoteStore = CreateFileRemoteStore(Options);
+        }
+
+        if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud)
+        {
+            std::string_view CloudServiceUrl = Cloud["url"sv].AsString();
+            if (CloudServiceUrl.empty())
+            {
+                return {nullptr, "Missing service url"};
+            }
+
+            // The URL may arrive percent-encoded from the HTTP layer.
+            std::string Url = cpr::util::urlDecode(std::string(CloudServiceUrl));
+            std::string_view Namespace = Cloud["namespace"sv].AsString();
+            if (Namespace.empty())
+            {
+                return {nullptr, "Missing namespace"};
+            }
+            std::string_view Bucket = Cloud["bucket"sv].AsString();
+            if (Bucket.empty())
+            {
+                return {nullptr, "Missing bucket"};
+            }
+            std::string_view OpenIdProvider = Cloud["openid-provider"sv].AsString();
+            std::string AccessToken = std::string(Cloud["access-token"sv].AsString());
+            if (AccessToken.empty())
+            {
+                // Fall back to reading the token from the named environment variable.
+                std::string_view AccessTokenEnvVariable = Cloud["access-token-env"].AsString();
+                if (!AccessTokenEnvVariable.empty())
+                {
+                    AccessToken = GetEnvVariable(AccessTokenEnvVariable);
+                }
+            }
+            std::string_view KeyParam = Cloud["key"sv].AsString();
+            if (KeyParam.empty())
+            {
+                return {nullptr, "Missing key"};
+            }
+            if (KeyParam.length() != IoHash::StringLength)
+            {
+                return {nullptr, "Invalid key"};
+            }
+            IoHash Key = IoHash::FromHexString(KeyParam);
+            if (Key == IoHash::Zero)
+            {
+                // All-zero hash doubles as FromHexString's parse-failure value.
+                return {nullptr, "Invalid key string"};
+            }
+            bool ForceDisableBlocks = Cloud["disableblocks"sv].AsBool(false);
+            bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false);
+
+            JupiterRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+                                                 Url,
+                                                 std::string(Namespace),
+                                                 std::string(Bucket),
+                                                 Key,
+                                                 std::string(OpenIdProvider),
+                                                 AccessToken,
+                                                 AuthManager,
+                                                 ForceDisableBlocks,
+                                                 ForceDisableTempBlocks};
+            RemoteStore = CreateJupiterRemoteStore(Options);
+        }
+
+        if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen)
+        {
+            // NOTE: url is not validated here, unlike project/oplog — presumably
+            // an empty url selects a default; confirm against CreateZenRemoteStore.
+            std::string_view Url = Zen["url"sv].AsString();
+            std::string_view Project = Zen["project"sv].AsString();
+            if (Project.empty())
+            {
+                return {nullptr, "Missing project"};
+            }
+            std::string_view Oplog = Zen["oplog"sv].AsString();
+            if (Oplog.empty())
+            {
+                return {nullptr, "Missing oplog"};
+            }
+            ZenRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+                                             std::string(Url),
+                                             std::string(Project),
+                                             std::string(Oplog)};
+            RemoteStore = CreateZenRemoteStore(Options);
+        }
+
+        if (!RemoteStore)
+        {
+            return {nullptr, "Unknown remote store type"};
+        }
+
+        return {std::move(RemoteStore), ""};
+    }
+
+    // Translate a remote-store result into an HTTP status code plus message.
+    // ErrorCode 0 maps to 200 OK; otherwise the error code is used verbatim as
+    // the HTTP status, and Text/Reason are combined into the message.
+    std::pair<HttpResponseCode, std::string> ConvertResult(const RemoteProjectStore::Result& Result)
+    {
+        if (Result.ErrorCode == 0)
+        {
+            return {HttpResponseCode::OK, Result.Text};
+        }
+        const HttpResponseCode Status = static_cast<HttpResponseCode>(Result.ErrorCode);
+        std::string Message;
+        if (Result.Reason.empty())
+        {
+            Message = Result.Text;
+        }
+        else if (Result.Text.empty())
+        {
+            Message = Result.Reason;
+        }
+        else
+        {
+            Message = fmt::format("{}. Reason: '{}'", Result.Text, Result.Reason);
+        }
+        return {Status, std::move(Message)};
+    }
+
+    // Emit the CSV header row. The column layout must stay in sync with the
+    // rows produced by CSVWriteOp for the same Details/AttachmentDetails flags.
+    void CSVHeader(bool Details, bool AttachmentDetails, StringBuilderBase& CSVWriter)
+    {
+        if (AttachmentDetails)
+        {
+            CSVWriter << "Project, Oplog, LSN, Key, Cid, Size";
+            return;
+        }
+        if (Details)
+        {
+            CSVWriter << "Project, Oplog, LSN, Key, Size, AttachmentCount, AttachmentsSize";
+            return;
+        }
+        CSVWriter << "Project, Oplog, Key";
+    }
+
+    // Append CSV row(s) for a single op: one row per attachment when
+    // AttachmentDetails is set, one aggregate row when Details is set, else a
+    // minimal Project/Oplog/Key row. Layout must match CSVHeader.
+    void CSVWriteOp(CidStore& CidStore,
+                    std::string_view ProjectId,
+                    std::string_view OplogId,
+                    bool Details,
+                    bool AttachmentDetails,
+                    int LSN,
+                    const Oid& Key,
+                    CbObject Op,
+                    StringBuilderBase& CSVWriter)
+    {
+        StringBuilder<32> KeyStringBuilder;
+        Key.ToString(KeyStringBuilder);
+        const std::string_view KeyString = KeyStringBuilder.ToView();
+
+        // Keep the op's backing buffer alive (and measurable) while streaming.
+        SharedBuffer Buffer = Op.GetBuffer();
+        if (AttachmentDetails)
+        {
+            Op.IterateAttachments([&CidStore, &CSVWriter, &ProjectId, &OplogId, LSN, &KeyString](CbFieldView FieldView) {
+                const IoHash AttachmentHash = FieldView.AsAttachment();
+                IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+                CSVWriter << "\r\n"
+                          << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << AttachmentHash.ToHexString()
+                          << ", " << gsl::narrow<uint64_t>(Attachment.GetSize());
+            });
+        }
+        else if (Details)
+        {
+            // Aggregate attachment count and total (stored) size for the op.
+            uint64_t AttachmentCount = 0;
+            size_t AttachmentsSize = 0;
+            Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) {
+                const IoHash AttachmentHash = FieldView.AsAttachment();
+                AttachmentCount++;
+                IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+                AttachmentsSize += Attachment.GetSize();
+            });
+            CSVWriter << "\r\n"
+                      << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << gsl::narrow<uint64_t>(Buffer.GetSize())
+                      << ", " << AttachmentCount << ", " << gsl::narrow<uint64_t>(AttachmentsSize);
+        }
+        else
+        {
+            CSVWriter << "\r\n" << ProjectId << ", " << OplogId << ", " << KeyString;
+        }
+    } // (stray ';' after this function definition removed)
+
+    // Serialize a single op into CbWriter as an object: always its key, plus
+    // (by flags) its LSN/size, per-attachment details or aggregate attachment
+    // stats, and optionally the raw op fields under "op".
+    void CbWriteOp(CidStore& CidStore,
+                   bool Details,
+                   bool OpDetails,
+                   bool AttachmentDetails,
+                   int LSN,
+                   const Oid& Key,
+                   CbObject Op,
+                   CbObjectWriter& CbWriter)
+    {
+        CbWriter.BeginObject();
+        {
+            SharedBuffer Buffer = Op.GetBuffer();
+            CbWriter.AddObjectId("key", Key);
+            if (Details)
+            {
+                CbWriter.AddInteger("lsn", LSN);
+                CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Buffer.GetSize()));
+            }
+            if (AttachmentDetails)
+            {
+                // One {cid, size} entry per attachment.
+                CbWriter.BeginArray("attachments");
+                Op.IterateAttachments([&CidStore, &CbWriter](CbFieldView FieldView) {
+                    const IoHash AttachmentHash = FieldView.AsAttachment();
+                    CbWriter.BeginObject();
+                    {
+                        IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+                        CbWriter.AddString("cid", AttachmentHash.ToHexString());
+                        CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Attachment.GetSize()));
+                    }
+                    CbWriter.EndObject();
+                });
+                CbWriter.EndArray();
+            }
+            else if (Details)
+            {
+                // Aggregate stats only; omitted entirely when there are no attachments.
+                uint64_t AttachmentCount = 0;
+                size_t AttachmentsSize = 0;
+                Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) {
+                    const IoHash AttachmentHash = FieldView.AsAttachment();
+                    AttachmentCount++;
+                    IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+                    AttachmentsSize += Attachment.GetSize();
+                });
+                if (AttachmentCount > 0)
+                {
+                    CbWriter.AddInteger("attachments", AttachmentCount);
+                    CbWriter.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize));
+                }
+            }
+            if (OpDetails)
+            {
+                // Copy the op's fields verbatim, preserving unnamed fields.
+                CbWriter.BeginObject("op");
+                for (const CbFieldView& Field : Op)
+                {
+                    if (!Field.HasName())
+                    {
+                        CbWriter.AddField(Field);
+                        continue;
+                    }
+                    std::string_view FieldName = Field.GetName();
+                    CbWriter.AddField(FieldName, Field);
+                }
+                CbWriter.EndObject();
+            }
+        }
+        CbWriter.EndObject();
+    } // (stray ';' after this function definition removed)
+
+    // Serialize every live op of an oplog into an "ops" array on Cbo.
+    void CbWriteOplogOps(CidStore& CidStore,
+                         ProjectStore::Oplog& Oplog,
+                         bool Details,
+                         bool OpDetails,
+                         bool AttachmentDetails,
+                         CbObjectWriter& Cbo)
+    {
+        Cbo.BeginArray("ops");
+        auto WriteOne = [&Cbo, &CidStore, Details, OpDetails, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) {
+            CbWriteOp(CidStore, Details, OpDetails, AttachmentDetails, LSN, Key, Op, Cbo);
+        };
+        Oplog.IterateOplogWithKey(std::move(WriteOne));
+        Cbo.EndArray();
+    }
+
+    // Serialize one oplog as an object: its name plus its ops array.
+    void CbWriteOplog(CidStore& CidStore,
+                      ProjectStore::Oplog& Oplog,
+                      bool Details,
+                      bool OpDetails,
+                      bool AttachmentDetails,
+                      CbObjectWriter& Cbo)
+    {
+        Cbo.BeginObject();
+        Cbo.AddString("name", Oplog.OplogId());
+        CbWriteOplogOps(CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo);
+        Cbo.EndObject();
+    }
+
+    // Serialize the named oplogs of a project into an "oplogs" array;
+    // unknown oplog ids are silently skipped.
+    // (OpLogs is now taken by const reference — the previous by-value copy of
+    // the string vector was unnecessary; call sites are unaffected.)
+    void CbWriteOplogs(CidStore& CidStore,
+                       ProjectStore::Project& Project,
+                       const std::vector<std::string>& OpLogs,
+                       bool Details,
+                       bool OpDetails,
+                       bool AttachmentDetails,
+                       CbObjectWriter& Cbo)
+    {
+        Cbo.BeginArray("oplogs");
+        {
+            for (const std::string& OpLogId : OpLogs)
+            {
+                ProjectStore::Oplog* Oplog = Project.OpenOplog(OpLogId);
+                if (Oplog != nullptr)
+                {
+                    CbWriteOplog(CidStore, *Oplog, Details, OpDetails, AttachmentDetails, Cbo);
+                }
+            }
+        }
+        Cbo.EndArray();
+    }
+
+    // Serialize a project as an object: its name plus the requested oplogs.
+    // (OpLogs is now taken by const reference — the previous by-value copy of
+    // the string vector was unnecessary; call sites are unaffected.)
+    void CbWriteProject(CidStore& CidStore,
+                        ProjectStore::Project& Project,
+                        const std::vector<std::string>& OpLogs,
+                        bool Details,
+                        bool OpDetails,
+                        bool AttachmentDetails,
+                        CbObjectWriter& Cbo)
+    {
+        Cbo.BeginObject();
+        {
+            Cbo.AddString("name", Project.Identifier);
+            CbWriteOplogs(CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo);
+        }
+        Cbo.EndObject();
+    }
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+
+// Derive the canonical Oid for an op-key string.
+// The key is first serialized as a compact-binary "key" field so the hash is
+// computed over the same byte representation used when ops are appended
+// (see Oplog::AppendNewOplogEntry), then the XXH3-128 digest is truncated
+// into the Oid's bits.
+Oid
+OpKeyStringAsOId(std::string_view OpKey)
+{
+    using namespace std::literals;
+
+    CbObjectWriter Writer;
+    Writer << "key"sv << OpKey;
+
+    XXH3_128Stream KeyHasher;
+    Writer.Save()["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); });
+    XXH3_128 KeyHash = KeyHasher.GetHash();
+
+    Oid OpId;
+    // Truncate the 128-bit hash into the Oid storage.
+    memcpy(OpId.OidBits, &KeyHash, sizeof(OpId.OidBits));
+
+    return OpId;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// On-disk storage for an oplog: an append-only index file ("ops.zlog", one
+// fixed-size OplogEntry per op) plus a blob file ("ops.zops") holding the
+// serialized op cores at m_OpsAlign-aligned offsets.
+struct ProjectStore::OplogStorage : public RefCounted
+{
+    OplogStorage(ProjectStore::Oplog* OwnerOplog, std::filesystem::path BasePath) : m_OwnerOplog(OwnerOplog), m_OplogStoragePath(BasePath)
+    {
+    }
+
+    ~OplogStorage()
+    {
+        ZEN_INFO("closing oplog storage at {}", m_OplogStoragePath);
+        Flush();
+    }
+
+    // Storage exists only when BOTH the log and blob files are present.
+    [[nodiscard]] bool Exists() { return Exists(m_OplogStoragePath); }
+    [[nodiscard]] static bool Exists(std::filesystem::path BasePath)
+    {
+        return std::filesystem::exists(BasePath / "ops.zlog") && std::filesystem::exists(BasePath / "ops.zops");
+    }
+
+    static bool Delete(std::filesystem::path BasePath) { return DeleteDirectories(BasePath); }
+
+    // Current high-water mark of the blob file (bytes reserved so far).
+    uint64_t OpBlobsSize() const
+    {
+        RwLock::SharedLockScope _(m_RwLock);
+        return m_NextOpsOffset;
+    }
+
+    // Open (or, when IsCreate, wipe and recreate) the backing files.
+    void Open(bool IsCreate)
+    {
+        using namespace std::literals;
+
+        ZEN_INFO("initializing oplog storage at '{}'", m_OplogStoragePath);
+
+        if (IsCreate)
+        {
+            DeleteDirectories(m_OplogStoragePath);
+            CreateDirectories(m_OplogStoragePath);
+        }
+
+        m_Oplog.Open(m_OplogStoragePath / "ops.zlog"sv, IsCreate ? CasLogFile::Mode::kTruncate : CasLogFile::Mode::kWrite);
+        m_Oplog.Initialize();
+
+        m_OpBlobs.Open(m_OplogStoragePath / "ops.zops"sv, IsCreate ? BasicFile::Mode::kTruncate : BasicFile::Mode::kWrite);
+
+        // Offsets in OplogEntry are stored divided by m_OpsAlign, so alignment
+        // must be a power of two and the write cursor must stay aligned.
+        ZEN_ASSERT(IsPow2(m_OpsAlign));
+        ZEN_ASSERT(!(m_NextOpsOffset & (m_OpsAlign - 1)));
+    }
+
+    // Sequentially replay the whole log, invoking Handler for each valid
+    // entry. Also recovers m_NextOpsOffset and m_MaxLsn from the entries seen.
+    // Entries with zero size or a mismatching checksum are skipped.
+    void ReplayLog(std::function<void(CbObject, const OplogEntry&)>&& Handler)
+    {
+        ZEN_TRACE_CPU("ProjectStore::OplogStorage::ReplayLog");
+
+        // This could use memory mapping or do something clever but for now it just reads the file sequentially
+
+        ZEN_INFO("replaying log for '{}'", m_OplogStoragePath);
+
+        Stopwatch Timer;
+
+        uint64_t InvalidEntries = 0;
+
+        // OpBuffer is grown as needed and REUSED across entries; the CbObject
+        // handed to Handler views this buffer, so handlers must copy what
+        // they keep (the in-memory maps copy into std::string/IoHash).
+        IoBuffer OpBuffer;
+        m_Oplog.Replay(
+            [&](const OplogEntry& LogEntry) {
+                if (LogEntry.OpCoreSize == 0)
+                {
+                    ++InvalidEntries;
+
+                    return;
+                }
+
+                if (OpBuffer.GetSize() < LogEntry.OpCoreSize)
+                {
+                    OpBuffer = IoBuffer(LogEntry.OpCoreSize);
+                }
+
+                // Stored offset is in alignment units; convert to bytes.
+                const uint64_t OpFileOffset = LogEntry.OpCoreOffset * m_OpsAlign;
+
+                m_OpBlobs.Read((void*)OpBuffer.Data(), LogEntry.OpCoreSize, OpFileOffset);
+
+                // Verify checksum, ignore op data if incorrect
+                const auto OpCoreHash = uint32_t(XXH3_64bits(OpBuffer.Data(), LogEntry.OpCoreSize) & 0xffffFFFF);
+
+                if (OpCoreHash != LogEntry.OpCoreHash)
+                {
+                    ZEN_WARN("skipping oplog entry with bad checksum!");
+                    return;
+                }
+
+                CbObject Op(SharedBuffer::MakeView(OpBuffer.Data(), LogEntry.OpCoreSize));
+
+                // NOTE(review): load+store Max is not an atomic RMW — safe only
+                // because replay runs single-threaded before the store is shared.
+                m_NextOpsOffset =
+                    Max(m_NextOpsOffset.load(std::memory_order_relaxed), RoundUp(OpFileOffset + LogEntry.OpCoreSize, m_OpsAlign));
+                m_MaxLsn = Max(m_MaxLsn.load(std::memory_order_relaxed), LogEntry.OpLsn);
+
+                Handler(Op, LogEntry);
+            },
+            0);
+
+        if (InvalidEntries)
+        {
+            ZEN_WARN("ignored {} zero-sized oplog entries", InvalidEntries);
+        }
+
+        ZEN_INFO("Oplog replay completed in {} - Max LSN# {}, Next offset: {}",
+                 NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+                 m_MaxLsn,
+                 m_NextOpsOffset);
+    }
+
+    // Replay only the given entry addresses (random access, no checksum check).
+    void ReplayLog(const std::vector<OplogEntryAddress>& Entries, std::function<void(CbObject)>&& Handler)
+    {
+        for (const OplogEntryAddress& Entry : Entries)
+        {
+            CbObject Op = GetOp(Entry);
+            Handler(Op);
+        }
+    }
+
+    // Read a single op core from the blob file into a freshly owned buffer.
+    CbObject GetOp(const OplogEntryAddress& Entry)
+    {
+        IoBuffer OpBuffer(Entry.Size);
+
+        const uint64_t OpFileOffset = Entry.Offset * m_OpsAlign;
+        m_OpBlobs.Read((void*)OpBuffer.Data(), Entry.Size, OpFileOffset);
+
+        return CbObject(SharedBuffer(std::move(OpBuffer)));
+    }
+
+    // Reserve space + LSN under the exclusive lock, then append the index
+    // entry and write the blob outside of it (offsets never overlap, so
+    // concurrent appends can write their own ranges safely).
+    OplogEntry AppendOp(SharedBuffer Buffer, uint32_t OpCoreHash, XXH3_128 KeyHash)
+    {
+        ZEN_TRACE_CPU("ProjectStore::OplogStorage::AppendOp");
+
+        using namespace std::literals;
+
+        uint64_t WriteSize = Buffer.GetSize();
+
+        RwLock::ExclusiveLockScope Lock(m_RwLock);
+        const uint64_t WriteOffset = m_NextOpsOffset;
+        const uint32_t OpLsn = ++m_MaxLsn;
+        m_NextOpsOffset = RoundUp(WriteOffset + WriteSize, m_OpsAlign);
+        Lock.ReleaseNow();
+
+        ZEN_ASSERT(IsMultipleOf(WriteOffset, m_OpsAlign));
+
+        OplogEntry Entry = {.OpLsn = OpLsn,
+                            .OpCoreOffset = gsl::narrow_cast<uint32_t>(WriteOffset / m_OpsAlign),
+                            .OpCoreSize = uint32_t(Buffer.GetSize()),
+                            .OpCoreHash = OpCoreHash,
+                            .OpKeyHash = KeyHash};
+
+        m_Oplog.Append(Entry);
+        m_OpBlobs.Write(Buffer.GetData(), WriteSize, WriteOffset);
+
+        return Entry;
+    }
+
+    void Flush()
+    {
+        m_Oplog.Flush();
+        m_OpBlobs.Flush();
+    }
+
+    spdlog::logger& Log() { return m_OwnerOplog->Log(); }
+
+private:
+    ProjectStore::Oplog* m_OwnerOplog;           // Owning oplog (non-owning back-pointer).
+    std::filesystem::path m_OplogStoragePath;    // Directory holding ops.zlog / ops.zops.
+    mutable RwLock m_RwLock;                     // Guards offset/LSN reservation.
+    TCasLogFile<OplogEntry> m_Oplog;             // Append-only index of OplogEntry records.
+    BasicFile m_OpBlobs;                         // Blob file with the op cores.
+    std::atomic<uint64_t> m_NextOpsOffset{0};    // Next aligned write offset in m_OpBlobs.
+    uint64_t m_OpsAlign = 32;                    // Blob alignment; must be a power of two.
+    std::atomic<uint32_t> m_MaxLsn{0};           // Highest LSN issued so far (LSNs start at 1).
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+// Construct an oplog rooted at BasePath. Opens (or creates, when the backing
+// files are absent) the on-disk storage and resets the oplog's temp folder.
+// MarkerPath is the GC marker file whose absence marks this oplog expired.
+ProjectStore::Oplog::Oplog(std::string_view Id,
+                           Project* Project,
+                           CidStore& Store,
+                           std::filesystem::path BasePath,
+                           const std::filesystem::path& MarkerPath)
+: m_OuterProject(Project)
+, m_CidStore(Store)
+, m_BasePath(BasePath)
+, m_MarkerPath(MarkerPath)
+, m_OplogId(Id)
+{
+    using namespace std::literals;
+
+    m_Storage = new OplogStorage(this, m_BasePath);
+    const bool StoreExists = m_Storage->Exists();
+    // Create from scratch when either backing file is missing.
+    m_Storage->Open(/* IsCreate */ !StoreExists);
+
+    m_TempPath = m_BasePath / "temp"sv;
+
+    // Leftover temp files from a previous run are stale; clear them.
+    CleanDirectory(m_TempPath);
+}
+
+ProjectStore::Oplog::~Oplog()
+{
+    // Persist any buffered log/blob data before tearing down; storage may
+    // already be detached if PrepareForDelete ran.
+    if (!m_Storage)
+    {
+        return;
+    }
+    Flush();
+}
+
+void
+ProjectStore::Oplog::Flush()
+{
+    // Callers must not flush after the storage has been detached.
+    ZEN_ASSERT(m_Storage);
+    m_Storage->Flush();
+}
+
+void
+ProjectStore::Oplog::Scrub(ScrubContext& Ctx) const
+{
+    // Intentionally a no-op for now; the parameter is kept referenced to
+    // satisfy the interface without warnings.
+    ZEN_UNUSED(Ctx);
+}
+
+void
+ProjectStore::Oplog::GatherReferences(GcContext& GcCtx)
+{
+    // Report every CAS hash this oplog references (chunk + meta maps) so the
+    // garbage collector retains the corresponding chunks.
+    RwLock::SharedLockScope _(m_OplogLock);
+
+    std::vector<IoHash> Hashes;
+    // One scratch vector sized for the larger of the two passes.
+    Hashes.reserve(Max(m_ChunkMap.size(), m_MetaMap.size()));
+
+    for (const auto& [ChunkId, ChunkHash] : m_ChunkMap)
+    {
+        Hashes.push_back(ChunkHash);
+    }
+    GcCtx.AddRetainedCids(Hashes);
+
+    Hashes.clear();
+    for (const auto& [MetaId, MetaHash] : m_MetaMap)
+    {
+        Hashes.push_back(MetaHash);
+    }
+    GcCtx.AddRetainedCids(Hashes);
+}
+
+uint64_t
+ProjectStore::Oplog::TotalSize() const
+{
+    // Size of the op-blob file, or 0 once the oplog's storage is detached.
+    RwLock::SharedLockScope _(m_OplogLock);
+    return m_Storage ? m_Storage->OpBlobsSize() : 0;
+}
+
+bool
+ProjectStore::Oplog::IsExpired() const
+{
+    // An oplog with a GC marker path expires when that marker file vanishes;
+    // oplogs without a marker never expire.
+    if (!m_MarkerPath.empty())
+    {
+        return !std::filesystem::exists(m_MarkerPath);
+    }
+    return false;
+}
+
+// Detach this oplog for deletion: drop all in-memory maps and release the
+// storage reference (new appends will then fail with the sentinel LSN).
+// When MoveFolder is set, the on-disk folder is renamed aside for later
+// deletion; returns that renamed path, or an empty path when nothing needs
+// (or could) be moved.
+std::filesystem::path
+ProjectStore::Oplog::PrepareForDelete(bool MoveFolder)
+{
+    RwLock::ExclusiveLockScope _(m_OplogLock);
+    m_ChunkMap.clear();
+    m_MetaMap.clear();
+    m_FileMap.clear();
+    m_OpAddressMap.clear();
+    m_LatestOpMap.clear();
+    // Dropping the last ref flushes and closes the backing files.
+    m_Storage = {};
+    if (!MoveFolder)
+    {
+        return {};
+    }
+    std::filesystem::path MovedDir;
+    if (PrepareDirectoryDelete(m_BasePath, MovedDir))
+    {
+        return MovedDir;
+    }
+    return {};
+}
+
+bool
+ProjectStore::Oplog::ExistsAt(std::filesystem::path BasePath)
+{
+    // An oplog is considered present where its persisted state file lives.
+    using namespace std::literals;
+
+    const std::filesystem::path StateFilePath = BasePath / "oplog.zcb"sv;
+    return std::filesystem::is_regular_file(StateFilePath);
+}
+
+// Load the oplog's persisted config ("oplog.zcb") if present — currently just
+// the GC marker path — then rebuild the in-memory maps by replaying the log.
+// A missing config file is treated as a legacy store; a corrupt one aborts
+// without replaying.
+void
+ProjectStore::Oplog::Read()
+{
+    using namespace std::literals;
+
+    std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv;
+    if (std::filesystem::is_regular_file(StateFilePath))
+    {
+        ZEN_INFO("reading config for oplog '{}' in project '{}' from {}", m_OplogId, m_OuterProject->Identifier, StateFilePath);
+
+        BasicFile Blob;
+        Blob.Open(StateFilePath, BasicFile::Mode::kRead);
+
+        IoBuffer Obj = Blob.ReadAll();
+        CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All);
+
+        if (ValidationError != CbValidateError::None)
+        {
+            // NOTE: returns WITHOUT replaying the log when the config is corrupt.
+            ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), StateFilePath);
+            return;
+        }
+
+        CbObject Cfg = LoadCompactBinaryObject(Obj);
+
+        m_MarkerPath = Cfg["gcpath"sv].AsString();
+    }
+    else
+    {
+        ZEN_INFO("config for oplog '{}' in project '{}' not found at {}. Assuming legacy store",
+                 m_OplogId,
+                 m_OuterProject->Identifier,
+                 StateFilePath);
+    }
+    ReplayLog();
+}
+
+// Persist the oplog's config (currently only the GC marker path) to
+// "oplog.zcb", truncating any previous contents.
+void
+ProjectStore::Oplog::Write()
+{
+    using namespace std::literals;
+
+    BinaryWriter Mem;
+
+    CbObjectWriter Cfg;
+
+    Cfg << "gcpath"sv << PathToUtf8(m_MarkerPath);
+
+    Cfg.Save(Mem);
+
+    std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv;
+
+    ZEN_INFO("persisting config for oplog '{}' in project '{}' to {}", m_OplogId, m_OuterProject->Identifier, StateFilePath);
+
+    BasicFile Blob;
+    Blob.Open(StateFilePath, BasicFile::Mode::kTruncate);
+    Blob.Write(Mem.Data(), Mem.Size(), 0);
+    Blob.Flush();
+}
+
+// Rebuild all in-memory lookup maps from the on-disk log. Holds the oplog
+// lock exclusively for the whole replay; a no-op once storage is detached.
+void
+ProjectStore::Oplog::ReplayLog()
+{
+    RwLock::ExclusiveLockScope OplogLock(m_OplogLock);
+    if (!m_Storage)
+    {
+        return;
+    }
+    m_Storage->ReplayLog(
+        [&](CbObject Op, const OplogEntry& OpEntry) { RegisterOplogEntry(OplogLock, GetMapping(Op), OpEntry, kUpdateReplay); });
+}
+
+// Resolve a chunk id to its data, checking (in order) the CAS-backed chunk
+// map, the on-disk file map, then the meta map. Returns an empty buffer when
+// the id is unknown or the oplog has been deleted. The shared lock is released
+// before any I/O; only map values copied out beforehand are used afterwards.
+IoBuffer
+ProjectStore::Oplog::FindChunk(Oid ChunkId)
+{
+    RwLock::SharedLockScope OplogLock(m_OplogLock);
+    if (!m_Storage)
+    {
+        return IoBuffer{};
+    }
+
+    if (auto ChunkIt = m_ChunkMap.find(ChunkId); ChunkIt != m_ChunkMap.end())
+    {
+        IoHash ChunkHash = ChunkIt->second;
+        OplogLock.ReleaseNow();
+
+        IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash);
+        Chunk.SetContentType(ZenContentType::kCompressedBinary);
+
+        return Chunk;
+    }
+
+    if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end())
+    {
+        // Build the absolute path while still under the lock; the map entry
+        // must not be touched after ReleaseNow.
+        std::filesystem::path FilePath = m_OuterProject->RootDir / FileIt->second.ServerPath;
+
+        OplogLock.ReleaseNow();
+
+        IoBuffer FileChunk = IoBufferBuilder::MakeFromFile(FilePath);
+        FileChunk.SetContentType(ZenContentType::kBinary);
+
+        return FileChunk;
+    }
+
+    if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end())
+    {
+        IoHash ChunkHash = MetaIt->second;
+        OplogLock.ReleaseNow();
+
+        IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash);
+        Chunk.SetContentType(ZenContentType::kCompressedBinary);
+
+        return Chunk;
+    }
+
+    return {};
+}
+
+// Invoke Fn for every file mapping (id, server-relative path, client path).
+// Holds the shared lock for the whole iteration; no-op once deleted.
+void
+ProjectStore::Oplog::IterateFileMap(
+    std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn)
+{
+    RwLock::SharedLockScope _(m_OplogLock);
+    if (!m_Storage)
+    {
+        return;
+    }
+
+    for (const auto& [FileId, Entry] : m_FileMap)
+    {
+        Fn(FileId, Entry.ServerPath, Entry.ClientPath);
+    }
+}
+
+// Invoke Handler for the latest op of every key, in ascending blob-file
+// offset order (sequential reads). Earlier, superseded versions of a key are
+// not visited. Holds the shared lock for the duration.
+void
+ProjectStore::Oplog::IterateOplog(std::function<void(CbObject)>&& Handler)
+{
+    RwLock::SharedLockScope _(m_OplogLock);
+    if (!m_Storage)
+    {
+        return;
+    }
+
+    std::vector<OplogEntryAddress> Entries;
+    Entries.reserve(m_LatestOpMap.size());
+
+    for (const auto& Kv : m_LatestOpMap)
+    {
+        // Every latest-LSN must have a corresponding address entry.
+        const auto AddressEntry = m_OpAddressMap.find(Kv.second);
+        ZEN_ASSERT(AddressEntry != m_OpAddressMap.end());
+
+        Entries.push_back(AddressEntry->second);
+    }
+
+    // Sort by file offset so the replay reads the blob file sequentially.
+    std::sort(Entries.begin(), Entries.end(), [](const OplogEntryAddress& Lhs, const OplogEntryAddress& Rhs) {
+        return Lhs.Offset < Rhs.Offset;
+    });
+
+    m_Storage->ReplayLog(Entries, [&](CbObject Op) { Handler(Op); });
+}
+
+// Like IterateOplog, but also hands the op's LSN and key to the handler.
+// Parallel vectors are built under the lock, then an index permutation sorts
+// the reads by blob offset while keeping (LSN, Key, entry) triples aligned.
+void
+ProjectStore::Oplog::IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Handler)
+{
+    RwLock::SharedLockScope _(m_OplogLock);
+    if (!m_Storage)
+    {
+        return;
+    }
+
+    std::vector<size_t> EntryIndexes;
+    std::vector<OplogEntryAddress> Entries;
+    std::vector<Oid> Keys;
+    std::vector<int> LSNs;
+    Entries.reserve(m_LatestOpMap.size());
+    EntryIndexes.reserve(m_LatestOpMap.size());
+    Keys.reserve(m_LatestOpMap.size());
+    LSNs.reserve(m_LatestOpMap.size());
+
+    for (const auto& Kv : m_LatestOpMap)
+    {
+        const auto AddressEntry = m_OpAddressMap.find(Kv.second);
+        ZEN_ASSERT(AddressEntry != m_OpAddressMap.end());
+
+        Entries.push_back(AddressEntry->second);
+        Keys.push_back(Kv.first);
+        LSNs.push_back(Kv.second);
+        EntryIndexes.push_back(EntryIndexes.size());
+    }
+
+    // Sort the permutation by file offset for sequential blob reads.
+    std::sort(EntryIndexes.begin(), EntryIndexes.end(), [&Entries](const size_t& Lhs, const size_t& Rhs) {
+        const OplogEntryAddress& LhsEntry = Entries[Lhs];
+        const OplogEntryAddress& RhsEntry = Entries[Rhs];
+        return LhsEntry.Offset < RhsEntry.Offset;
+    });
+    std::vector<OplogEntryAddress> SortedEntries;
+    SortedEntries.reserve(EntryIndexes.size());
+    for (size_t Index : EntryIndexes)
+    {
+        SortedEntries.push_back(Entries[Index]);
+    }
+
+    // NOTE: LSNs/Keys are indexed by the SORTED position via EntryIndex, so
+    // each callback receives the (LSN, Key) matching the op it is given only
+    // through the EntryIndexes permutation established above.
+    size_t EntryIndex = 0;
+    m_Storage->ReplayLog(SortedEntries, [&](CbObject Op) {
+        Handler(LSNs[EntryIndex], Keys[EntryIndex], Op);
+        EntryIndex++;
+    });
+}
+
+// Return the LSN of the most recent op recorded for Key, or -1 when the key
+// is unknown or the oplog has been deleted.
+int
+ProjectStore::Oplog::GetOpIndexByKey(const Oid& Key)
+{
+    RwLock::SharedLockScope _(m_OplogLock);
+    if (!m_Storage)
+    {
+        // BUGFIX: previously `return {};` yielded 0 here, which is not the
+        // not-found sentinel used below; report -1 consistently.
+        return -1;
+    }
+    if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end())
+    {
+        return LatestOp->second;
+    }
+    return -1;
+}
+
+// Fetch the most recent op recorded for Key, or nullopt when the key is
+// unknown or the oplog has been deleted. Reads the op from disk while holding
+// the shared lock.
+std::optional<CbObject>
+ProjectStore::Oplog::GetOpByKey(const Oid& Key)
+{
+    RwLock::SharedLockScope _(m_OplogLock);
+    if (!m_Storage)
+    {
+        return {};
+    }
+
+    if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end())
+    {
+        // Map the key's latest LSN to its on-disk address.
+        const auto AddressEntry = m_OpAddressMap.find(LatestOp->second);
+        ZEN_ASSERT(AddressEntry != m_OpAddressMap.end());
+
+        return m_Storage->GetOp(AddressEntry->second);
+    }
+
+    return {};
+}
+
+// Fetch the op stored under the given LSN, or nullopt when the LSN is unknown
+// or the oplog has been deleted.
+std::optional<CbObject>
+ProjectStore::Oplog::GetOpByIndex(int Index)
+{
+    RwLock::SharedLockScope _(m_OplogLock);
+    if (!m_Storage)
+    {
+        return std::nullopt;
+    }
+
+    const auto AddressEntryIt = m_OpAddressMap.find(Index);
+    if (AddressEntryIt == m_OpAddressMap.end())
+    {
+        return std::nullopt;
+    }
+    return m_Storage->GetOp(AddressEntryIt->second);
+}
+
+// Record a file-backed chunk mapping. The unused lock-scope parameter
+// documents that the caller must hold the oplog lock exclusively.
+// A zero Hash means the file has no CAS-backed content to register.
+void
+ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&,
+                                    Oid FileId,
+                                    IoHash Hash,
+                                    std::string_view ServerPath,
+                                    std::string_view ClientPath)
+{
+    if (Hash != IoHash::Zero)
+    {
+        m_ChunkMap.insert_or_assign(FileId, Hash);
+    }
+
+    FileMapEntry Entry;
+    Entry.ServerPath = ServerPath;
+    Entry.ClientPath = ClientPath;
+
+    m_FileMap[FileId] = std::move(Entry);
+
+    // BUGFIX: an identical `if (Hash != IoHash::Zero) insert_or_assign(...)`
+    // block was duplicated here; one insertion is sufficient.
+}
+
+// Map a chunk id to its CAS hash (caller holds the oplog lock exclusively,
+// as witnessed by the unused lock-scope parameter). Overwrites any prior hash.
+void
+ProjectStore::Oplog::AddChunkMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash)
+{
+    m_ChunkMap.insert_or_assign(ChunkId, Hash);
+}
+
+// Map a meta chunk id to its CAS hash (caller holds the oplog lock
+// exclusively). Overwrites any prior hash for the id.
+void
+ProjectStore::Oplog::AddMetaMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash)
+{
+    m_MetaMap.insert_or_assign(ChunkId, Hash);
+}
+
+// Extract every id→hash mapping referenced by an op core: the package object,
+// the packagedata/bulkdata arrays, file entries and meta entries. Pure
+// parsing — no maps are mutated here (see RegisterOplogEntry).
+ProjectStore::Oplog::OplogEntryMapping
+ProjectStore::Oplog::GetMapping(CbObject Core)
+{
+    using namespace std::literals;
+
+    OplogEntryMapping Result;
+
+    // Update chunk id maps
+    CbObjectView PackageObj = Core["package"sv].AsObjectView();
+    CbArrayView BulkDataArray = Core["bulkdata"sv].AsArrayView();
+    CbArrayView PackageDataArray = Core["packagedata"sv].AsArrayView();
+    // BUGFIX: was `PackageObj ? 1 : 0 + ... + ...` which, due to ?: having the
+    // lowest precedence, reserved only 1 when a package object was present.
+    Result.Chunks.reserve((PackageObj ? 1 : 0) + BulkDataArray.Num() + PackageDataArray.Num());
+
+    if (PackageObj)
+    {
+        Oid Id = PackageObj["id"sv].AsObjectId();
+        IoHash Hash = PackageObj["data"sv].AsBinaryAttachment();
+        Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+        ZEN_DEBUG("package data {} -> {}", Id, Hash);
+    }
+
+    for (CbFieldView& Entry : PackageDataArray)
+    {
+        CbObjectView PackageDataObj = Entry.AsObjectView();
+        Oid Id = PackageDataObj["id"sv].AsObjectId();
+        IoHash Hash = PackageDataObj["data"sv].AsBinaryAttachment();
+        Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+        ZEN_DEBUG("package {} -> {}", Id, Hash);
+    }
+
+    for (CbFieldView& Entry : BulkDataArray)
+    {
+        CbObjectView BulkObj = Entry.AsObjectView();
+        Oid Id = BulkObj["id"sv].AsObjectId();
+        IoHash Hash = BulkObj["data"sv].AsBinaryAttachment();
+        Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+        ZEN_DEBUG("bulkdata {} -> {}", Id, Hash);
+    }
+
+    CbArrayView FilesArray = Core["files"sv].AsArrayView();
+    Result.Files.reserve(FilesArray.Num());
+    for (CbFieldView& Entry : FilesArray)
+    {
+        CbObjectView FileObj = Entry.AsObjectView();
+
+        std::string_view ServerPath = FileObj["serverpath"sv].AsString();
+        std::string_view ClientPath = FileObj["clientpath"sv].AsString();
+        if (ServerPath.empty() || ClientPath.empty())
+        {
+            // Malformed file entries are skipped rather than failing the op.
+            ZEN_WARN("invalid file");
+            continue;
+        }
+
+        Oid Id = FileObj["id"sv].AsObjectId();
+        IoHash Hash = FileObj["data"sv].AsBinaryAttachment();
+        Result.Files.emplace_back(
+            OplogEntryMapping::FileMapping{OplogEntryMapping::Mapping{Id, Hash}, std::string(ServerPath), std::string(ClientPath)});
+        ZEN_DEBUG("file {} -> {}, ServerPath: {}, ClientPath: {}", Id, Hash, ServerPath, ClientPath);
+    }
+
+    CbArrayView MetaArray = Core["meta"sv].AsArrayView();
+    Result.Meta.reserve(MetaArray.Num());
+    for (CbFieldView& Entry : MetaArray)
+    {
+        CbObjectView MetaObj = Entry.AsObjectView();
+        Oid Id = MetaObj["id"sv].AsObjectId();
+        IoHash Hash = MetaObj["data"sv].AsBinaryAttachment();
+        Result.Meta.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+        auto NameString = MetaObj["name"sv].AsString();
+        ZEN_DEBUG("meta data ({}) {} -> {}", NameString, Id, Hash);
+    }
+
+    return Result;
+}
+
+// Apply an op's extracted mappings to the in-memory lookup maps and record
+// the op's address + latest-LSN-per-key entry. Caller holds the exclusive
+// oplog lock (passed through to the Add* helpers). Returns the op's LSN.
+uint32_t
+ProjectStore::Oplog::RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock,
+                                        const OplogEntryMapping& OpMapping,
+                                        const OplogEntry& OpEntry,
+                                        UpdateType TypeOfUpdate)
+{
+    ZEN_TRACE_CPU("ProjectStore::Oplog::RegisterOplogEntry");
+
+    // TypeOfUpdate (replay vs new entry) is currently not acted upon.
+    ZEN_UNUSED(TypeOfUpdate);
+
+    // For now we're assuming the update is all in-memory so we can hold an exclusive lock without causing
+    // too many problems. Longer term we'll probably want to ensure we can do concurrent updates however
+
+    using namespace std::literals;
+
+    // Update chunk id maps
+    for (const OplogEntryMapping::Mapping& Chunk : OpMapping.Chunks)
+    {
+        AddChunkMapping(OplogLock, Chunk.Id, Chunk.Hash);
+    }
+
+    for (const OplogEntryMapping::FileMapping& File : OpMapping.Files)
+    {
+        AddFileMapping(OplogLock, File.Id, File.Hash, File.ServerPath, File.ClientPath);
+    }
+
+    for (const OplogEntryMapping::Mapping& Meta : OpMapping.Meta)
+    {
+        AddMetaMapping(OplogLock, Meta.Id, Meta.Hash);
+    }
+
+    // Later ops for the same key overwrite the latest-op entry; the address
+    // map keeps every LSN so older ops stay addressable by index.
+    m_OpAddressMap.emplace(OpEntry.OpLsn, OplogEntryAddress{.Offset = OpEntry.OpCoreOffset, .Size = OpEntry.OpCoreSize});
+    m_LatestOpMap[OpEntry.OpKeyAsOId()] = OpEntry.OpLsn;
+
+    return OpEntry.OpLsn;
+}
+
+// Append an op together with its attachments. The op core is written first so
+// the GC never sees attachments without a referencing oplog entry. Returns the
+// new LSN, or 0xffffffff when the oplog has been deleted (attachments are then
+// intentionally dropped).
+uint32_t
+ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage)
+{
+    ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry");
+
+    const CbObject& Core = OpPackage.GetObject();
+    const uint32_t EntryId = AppendNewOplogEntry(Core);
+    if (EntryId == 0xffffffffu)
+    {
+        // The oplog has been deleted so just drop this
+        return EntryId;
+    }
+
+    // Persist attachments after oplog entry so GC won't find attachments without references
+
+    uint64_t AttachmentBytes = 0;
+    uint64_t NewAttachmentBytes = 0;
+
+    auto Attachments = OpPackage.GetAttachments();
+
+    for (const auto& Attach : Attachments)
+    {
+        ZEN_ASSERT(Attach.IsCompressedBinary());
+
+        CompressedBuffer AttachmentData = Attach.AsCompressedBinary();
+        // Sizes reported below are the DECODED sizes, not on-disk sizes.
+        const uint64_t AttachmentSize = AttachmentData.DecodeRawSize();
+        CidStore::InsertResult InsertResult = m_CidStore.AddChunk(AttachmentData.GetCompressed().Flatten().AsIoBuffer(), Attach.GetHash());
+
+        if (InsertResult.New)
+        {
+            NewAttachmentBytes += AttachmentSize;
+        }
+        AttachmentBytes += AttachmentSize;
+    }
+
+    ZEN_DEBUG("oplog entry #{} attachments: {} new, {} total", EntryId, NiceBytes(NewAttachmentBytes), NiceBytes(AttachmentBytes));
+
+    return EntryId;
+}
+
+// Append a single op core to the storage and register its mappings.
+// Returns the new LSN, or 0xffffffff when the oplog has been deleted.
+uint32_t
+ProjectStore::Oplog::AppendNewOplogEntry(CbObject Core)
+{
+    ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry");
+
+    using namespace std::literals;
+
+    OplogEntryMapping Mapping = GetMapping(Core);
+
+    SharedBuffer Buffer = Core.GetBuffer();
+    const uint64_t WriteSize = Buffer.GetSize();
+    // Truncated XXH3 checksum of the op core, verified on replay.
+    const auto OpCoreHash = uint32_t(XXH3_64bits(Buffer.GetData(), WriteSize) & 0xffffFFFF);
+
+    ZEN_ASSERT(WriteSize != 0);
+
+    // Hash the serialized "key" field; this is what OpKeyStringAsOId mirrors.
+    XXH3_128Stream KeyHasher;
+    Core["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); });
+    XXH3_128 KeyHash = KeyHasher.GetHash();
+
+    // Snapshot the storage ref under the lock so a concurrent
+    // PrepareForDelete (which clears m_Storage) can't pull it out from
+    // under us while we append.
+    RefPtr<OplogStorage> Storage;
+    {
+        RwLock::SharedLockScope _(m_OplogLock);
+        Storage = m_Storage;
+    }
+    if (!Storage)
+    {
+        return 0xffffffffu;
+    }
+    // BUGFIX: previously checked and used m_Storage here without the lock,
+    // making the snapshot pointless and racing deletion; use the snapshot.
+    const OplogEntry OpEntry = Storage->AppendOp(Buffer, OpCoreHash, KeyHash);
+
+    RwLock::ExclusiveLockScope OplogLock(m_OplogLock);
+    const uint32_t EntryId = RegisterOplogEntry(OplogLock, Mapping, OpEntry, kUpdateNewEntry);
+
+    return EntryId;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Construct a project rooted at BasePath; members identifying the project
+// (Identifier, RootDir, ...) are populated later by Read().
+ProjectStore::Project::Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath)
+: m_ProjectStore(PrjStore)
+, m_CidStore(Store)
+, m_OplogStoragePath(BasePath)
+{
+}
+
// No explicit teardown needed: owned oplogs and paths release themselves.
ProjectStore::Project::~Project()
{
}
+
+bool
+ProjectStore::Project::Exists(std::filesystem::path BasePath)
+{
+ return std::filesystem::exists(BasePath / "Project.zcb");
+}
+
+void
+ProjectStore::Project::Read()
+{
+ using namespace std::literals;
+
+ std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv;
+
+ ZEN_INFO("reading config for project '{}' from {}", Identifier, ProjectStateFilePath);
+
+ BasicFile Blob;
+ Blob.Open(ProjectStateFilePath, BasicFile::Mode::kRead);
+
+ IoBuffer Obj = Blob.ReadAll();
+ CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All);
+
+ if (ValidationError == CbValidateError::None)
+ {
+ CbObject Cfg = LoadCompactBinaryObject(Obj);
+
+ Identifier = Cfg["id"sv].AsString();
+ RootDir = Cfg["root"sv].AsString();
+ ProjectRootDir = Cfg["project"sv].AsString();
+ EngineRootDir = Cfg["engine"sv].AsString();
+ ProjectFilePath = Cfg["projectfile"sv].AsString();
+ }
+ else
+ {
+ ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), ProjectStateFilePath);
+ }
+}
+
+void
+ProjectStore::Project::Write()
+{
+ using namespace std::literals;
+
+ BinaryWriter Mem;
+
+ CbObjectWriter Cfg;
+ Cfg << "id"sv << Identifier;
+ Cfg << "root"sv << PathToUtf8(RootDir);
+ Cfg << "project"sv << ProjectRootDir;
+ Cfg << "engine"sv << EngineRootDir;
+ Cfg << "projectfile"sv << ProjectFilePath;
+
+ Cfg.Save(Mem);
+
+ CreateDirectories(m_OplogStoragePath);
+
+ std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv;
+
+ ZEN_INFO("persisting config for project '{}' to {}", Identifier, ProjectStateFilePath);
+
+ BasicFile Blob;
+ Blob.Open(ProjectStateFilePath, BasicFile::Mode::kTruncate);
+ Blob.Write(Mem.Data(), Mem.Size(), 0);
+ Blob.Flush();
+}
+
// Projects share their owning store's logger rather than carrying one each.
spdlog::logger&
ProjectStore::Project::Log()
{
    return m_ProjectStore->Log();
}
+
+std::filesystem::path
+ProjectStore::Project::BasePathForOplog(std::string_view OplogId)
+{
+ return m_OplogStoragePath / OplogId;
+}
+
// Creates (and registers) a new oplog named OplogId under this project and
// persists its initial state via Write(). Returns nullptr if construction or
// the initial write throws. MarkerPath is forwarded to the Oplog constructor
// (presumably used for expiry tracking -- confirm against the Oplog ctor).
ProjectStore::Oplog*
ProjectStore::Project::NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath)
{
    RwLock::ExclusiveLockScope _(m_ProjectLock);

    std::filesystem::path OplogBasePath = BasePathForOplog(OplogId);

    try
    {
        // try_emplace: if the id is already registered, the existing entry is
        // returned and the freshly constructed oplog is discarded.
        Oplog* Log = m_Oplogs
                         .try_emplace(std::string{OplogId},
                                      std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, MarkerPath))
                         .first->second.get();

        Log->Write();
        return Log;
    }
    catch (std::exception&)
    {
        // In case of failure we need to ensure there's no half constructed entry around
        //
        // (This is probably already ensured by the try_emplace implementation?)
        //
        // NOTE(review): if an oplog with this id already existed before the
        // call and Write() threw, this erase also removes the pre-existing
        // (valid) entry -- confirm whether that is intended.

        m_Oplogs.erase(std::string{OplogId});

        return nullptr;
    }
}
+
+ProjectStore::Oplog*
+ProjectStore::Project::OpenOplog(std::string_view OplogId)
+{
+ {
+ RwLock::SharedLockScope _(m_ProjectLock);
+
+ auto OplogIt = m_Oplogs.find(std::string(OplogId));
+
+ if (OplogIt != m_Oplogs.end())
+ {
+ return OplogIt->second.get();
+ }
+ }
+
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ std::filesystem::path OplogBasePath = BasePathForOplog(OplogId);
+
+ if (Oplog::ExistsAt(OplogBasePath))
+ {
+ // Do open of existing oplog
+
+ try
+ {
+ Oplog* Log =
+ m_Oplogs
+ .try_emplace(std::string{OplogId},
+ std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, std::filesystem::path{}))
+ .first->second.get();
+ Log->Read();
+
+ return Log;
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("failed to open oplog '{}' @ '{}': {}", OplogId, OplogBasePath, ex.what());
+
+ m_Oplogs.erase(std::string{OplogId});
+ }
+ }
+
+ return nullptr;
+}
+
+void
+ProjectStore::Project::DeleteOplog(std::string_view OplogId)
+{
+ std::filesystem::path DeletePath;
+ {
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ auto OplogIt = m_Oplogs.find(std::string(OplogId));
+
+ if (OplogIt != m_Oplogs.end())
+ {
+ std::unique_ptr<Oplog>& Oplog = OplogIt->second;
+ DeletePath = Oplog->PrepareForDelete(true);
+ m_DeletedOplogs.emplace_back(std::move(Oplog));
+ m_Oplogs.erase(OplogIt);
+ }
+ }
+
+ // Erase content on disk
+ if (!DeletePath.empty())
+ {
+ OplogStorage::Delete(DeletePath);
+ }
+}
+
+std::vector<std::string>
+ProjectStore::Project::ScanForOplogs() const
+{
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_OplogStoragePath, DirectoryContent::IncludeDirsFlag, DirContent);
+ std::vector<std::string> Oplogs;
+ Oplogs.reserve(DirContent.Directories.size());
+ for (const std::filesystem::path& DirPath : DirContent.Directories)
+ {
+ Oplogs.push_back(DirPath.filename().string());
+ }
+ return Oplogs;
+}
+
+void
+ProjectStore::Project::IterateOplogs(std::function<void(const Oplog&)>&& Fn) const
+{
+ RwLock::SharedLockScope _(m_ProjectLock);
+
+ for (auto& Kv : m_Oplogs)
+ {
+ Fn(*Kv.second);
+ }
+}
+
+void
+ProjectStore::Project::IterateOplogs(std::function<void(Oplog&)>&& Fn)
+{
+ RwLock::SharedLockScope _(m_ProjectLock);
+
+ for (auto& Kv : m_Oplogs)
+ {
+ Fn(*Kv.second);
+ }
+}
+
+void
+ProjectStore::Project::Flush()
+{
+ // We only need to flush oplogs that we have already loaded
+ IterateOplogs([&](Oplog& Ops) { Ops.Flush(); });
+}
+
+void
+ProjectStore::Project::Scrub(ScrubContext& Ctx)
+{
+ // Scrubbing needs to check all existing oplogs
+ std::vector<std::string> OpLogs = ScanForOplogs();
+ for (const std::string& OpLogId : OpLogs)
+ {
+ OpenOplog(OpLogId);
+ }
+ IterateOplogs([&](const Oplog& Ops) {
+ if (!Ops.IsExpired())
+ {
+ Ops.Scrub(Ctx);
+ }
+ });
+}
+
+void
+ProjectStore::Project::GatherReferences(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("ProjectStore::Project::GatherReferences");
+
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_DEBUG("gathered references from project store project {} in {}", Identifier, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ // GatherReferences needs to check all existing oplogs
+ std::vector<std::string> OpLogs = ScanForOplogs();
+ for (const std::string& OpLogId : OpLogs)
+ {
+ OpenOplog(OpLogId);
+ }
+ IterateOplogs([&](Oplog& Ops) {
+ if (!Ops.IsExpired())
+ {
+ Ops.GatherReferences(GcCtx);
+ }
+ });
+}
+
+uint64_t
+ProjectStore::Project::TotalSize() const
+{
+ uint64_t Result = 0;
+ {
+ RwLock::SharedLockScope _(m_ProjectLock);
+ for (const auto& It : m_Oplogs)
+ {
+ Result += It.second->TotalSize();
+ }
+ }
+ return Result;
+}
+
+bool
+ProjectStore::Project::PrepareForDelete(std::filesystem::path& OutDeletePath)
+{
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ for (auto& It : m_Oplogs)
+ {
+ // We don't care about the moved folder
+ It.second->PrepareForDelete(false);
+ m_DeletedOplogs.emplace_back(std::move(It.second));
+ }
+
+ m_Oplogs.clear();
+
+ bool Success = PrepareDirectoryDelete(m_OplogStoragePath, OutDeletePath);
+ if (!Success)
+ {
+ return false;
+ }
+ m_OplogStoragePath.clear();
+ return true;
+}
+
+bool
+ProjectStore::Project::IsExpired() const
+{
+ if (ProjectFilePath.empty())
+ {
+ return false;
+ }
+ return !std::filesystem::exists(ProjectFilePath);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc)
+: GcStorage(Gc)
+, GcContributor(Gc)
+, m_Log(logging::Get("project"))
+, m_CidStore(Store)
+, m_ProjectBasePath(BasePath)
+{
+ ZEN_INFO("initializing project store at '{}'", BasePath);
+ // m_Log.set_level(spdlog::level::debug);
+}
+
// Logs shutdown; owned projects and the logger reference tear themselves down.
ProjectStore::~ProjectStore()
{
    ZEN_INFO("closing project store ('{}')", m_ProjectBasePath);
}
+
+std::filesystem::path
+ProjectStore::BasePathForProject(std::string_view ProjectId)
+{
+ return m_ProjectBasePath / ProjectId;
+}
+
+void
+ProjectStore::DiscoverProjects()
+{
+ if (!std::filesystem::exists(m_ProjectBasePath))
+ {
+ return;
+ }
+
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_ProjectBasePath, DirectoryContent::IncludeDirsFlag, DirContent);
+
+ for (const std::filesystem::path& DirPath : DirContent.Directories)
+ {
+ std::string DirName = PathToUtf8(DirPath.filename());
+ OpenProject(DirName);
+ }
+}
+
+void
+ProjectStore::IterateProjects(std::function<void(Project& Prj)>&& Fn)
+{
+ RwLock::SharedLockScope _(m_ProjectsLock);
+
+ for (auto& Kv : m_Projects)
+ {
+ Fn(*Kv.second.Get());
+ }
+}
+
+void
+ProjectStore::Flush()
+{
+ std::vector<Ref<Project>> Projects;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ Projects.reserve(m_Projects.size());
+
+ for (auto& Kv : m_Projects)
+ {
+ Projects.push_back(Kv.second);
+ }
+ }
+ for (const Ref<Project>& Project : Projects)
+ {
+ Project->Flush();
+ }
+}
+
+void
+ProjectStore::Scrub(ScrubContext& Ctx)
+{
+ DiscoverProjects();
+
+ std::vector<Ref<Project>> Projects;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ Projects.reserve(m_Projects.size());
+
+ for (auto& Kv : m_Projects)
+ {
+ if (Kv.second->IsExpired())
+ {
+ continue;
+ }
+ Projects.push_back(Kv.second);
+ }
+ }
+ for (const Ref<Project>& Project : Projects)
+ {
+ Project->Scrub(Ctx);
+ }
+}
+
+void
+ProjectStore::GatherReferences(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("ProjectStore::GatherReferences");
+
+ size_t ProjectCount = 0;
+ size_t ExpiredProjectCount = 0;
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_DEBUG("gathered references from '{}' in {}, found {} active projects and {} expired projects",
+ m_ProjectBasePath.string(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+ ProjectCount,
+ ExpiredProjectCount);
+ });
+
+ DiscoverProjects();
+
+ std::vector<Ref<Project>> Projects;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ Projects.reserve(m_Projects.size());
+
+ for (auto& Kv : m_Projects)
+ {
+ if (Kv.second->IsExpired())
+ {
+ ExpiredProjectCount++;
+ continue;
+ }
+ Projects.push_back(Kv.second);
+ }
+ }
+ ProjectCount = Projects.size();
+ for (const Ref<Project>& Project : Projects)
+ {
+ Project->GatherReferences(GcCtx);
+ }
+}
+
+void
+ProjectStore::CollectGarbage(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("ProjectStore::CollectGarbage");
+
+ size_t ProjectCount = 0;
+ size_t ExpiredProjectCount = 0;
+
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_DEBUG("garbage collect from '{}' DONE after {}, found {} active projects and {} expired projects",
+ m_ProjectBasePath.string(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+ ProjectCount,
+ ExpiredProjectCount);
+ });
+ std::vector<Ref<Project>> ExpiredProjects;
+ std::vector<Ref<Project>> Projects;
+
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ for (auto& Kv : m_Projects)
+ {
+ if (Kv.second->IsExpired())
+ {
+ ExpiredProjects.push_back(Kv.second);
+ ExpiredProjectCount++;
+ continue;
+ }
+ Projects.push_back(Kv.second);
+ ProjectCount++;
+ }
+ }
+
+ if (!GcCtx.IsDeletionMode())
+ {
+ ZEN_DEBUG("garbage collect DISABLED, for '{}' ", m_ProjectBasePath.string());
+ return;
+ }
+
+ for (const Ref<Project>& Project : Projects)
+ {
+ std::vector<std::string> ExpiredOplogs;
+ {
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+ Project->IterateOplogs([&ExpiredOplogs](ProjectStore::Oplog& Oplog) {
+ if (Oplog.IsExpired())
+ {
+ ExpiredOplogs.push_back(Oplog.OplogId());
+ }
+ });
+ }
+ for (const std::string& OplogId : ExpiredOplogs)
+ {
+ ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected oplog '{}' in project '{}'. Removing storage on disk",
+ OplogId,
+ Project->Identifier);
+ Project->DeleteOplog(OplogId);
+ }
+ }
+
+ if (ExpiredProjects.empty())
+ {
+ ZEN_DEBUG("garbage collect for '{}', no expired projects found", m_ProjectBasePath.string());
+ return;
+ }
+
+ for (const Ref<Project>& Project : ExpiredProjects)
+ {
+ std::filesystem::path PathToRemove;
+ std::string ProjectId;
+ {
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+ if (!Project->IsExpired())
+ {
+ ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project no longer expired.", ProjectId);
+ continue;
+ }
+ bool Success = Project->PrepareForDelete(PathToRemove);
+ if (!Success)
+ {
+ ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project folder is locked.", ProjectId);
+ continue;
+ }
+ m_Projects.erase(Project->Identifier);
+ ProjectId = Project->Identifier;
+ }
+
+ ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected project '{}'. Removing storage on disk", ProjectId);
+ if (PathToRemove.empty())
+ {
+ continue;
+ }
+
+ DeleteDirectories(PathToRemove);
+ }
+}
+
+GcStorageSize
+ProjectStore::StorageSize() const
+{
+ GcStorageSize Result;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ for (auto& Kv : m_Projects)
+ {
+ const Ref<Project>& Project = Kv.second;
+ Result.DiskSize += Project->TotalSize();
+ }
+ }
+ return Result;
+}
+
+Ref<ProjectStore::Project>
+ProjectStore::OpenProject(std::string_view ProjectId)
+{
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+
+ auto ProjIt = m_Projects.find(std::string{ProjectId});
+
+ if (ProjIt != m_Projects.end())
+ {
+ return ProjIt->second;
+ }
+ }
+
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+
+ std::filesystem::path BasePath = BasePathForProject(ProjectId);
+
+ if (Project::Exists(BasePath))
+ {
+ try
+ {
+ ZEN_INFO("opening project {} @ {}", ProjectId, BasePath);
+
+ Ref<Project>& Prj =
+ m_Projects
+ .try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath)))
+ .first->second;
+ Prj->Identifier = ProjectId;
+ Prj->Read();
+ return Prj;
+ }
+ catch (std::exception& e)
+ {
+ ZEN_WARN("failed to open {} @ {} ({})", ProjectId, BasePath, e.what());
+ m_Projects.erase(std::string{ProjectId});
+ }
+ }
+
+ return {};
+}
+
+Ref<ProjectStore::Project>
+ProjectStore::NewProject(std::filesystem::path BasePath,
+ std::string_view ProjectId,
+ std::string_view RootDir,
+ std::string_view EngineRootDir,
+ std::string_view ProjectRootDir,
+ std::string_view ProjectFilePath)
+{
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+
+ Ref<Project>& Prj =
+ m_Projects.try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath)))
+ .first->second;
+ Prj->Identifier = ProjectId;
+ Prj->RootDir = RootDir;
+ Prj->EngineRootDir = EngineRootDir;
+ Prj->ProjectRootDir = ProjectRootDir;
+ Prj->ProjectFilePath = ProjectFilePath;
+ Prj->Write();
+
+ return Prj;
+}
+
+bool
+ProjectStore::DeleteProject(std::string_view ProjectId)
+{
+ ZEN_INFO("deleting project {}", ProjectId);
+
+ RwLock::ExclusiveLockScope ProjectsLock(m_ProjectsLock);
+
+ auto ProjIt = m_Projects.find(std::string{ProjectId});
+
+ if (ProjIt == m_Projects.end())
+ {
+ return true;
+ }
+
+ std::filesystem::path DeletePath;
+ bool Success = ProjIt->second->PrepareForDelete(DeletePath);
+
+ if (!Success)
+ {
+ return false;
+ }
+ m_Projects.erase(ProjIt);
+ ProjectsLock.ReleaseNow();
+
+ if (!DeletePath.empty())
+ {
+ DeleteDirectories(DeletePath);
+ }
+ return true;
+}
+
+bool
+ProjectStore::Exists(std::string_view ProjectId)
+{
+ return Project::Exists(BasePathForProject(ProjectId));
+}
+
+CbArray
+ProjectStore::GetProjectsList()
+{
+ using namespace std::literals;
+
+ DiscoverProjects();
+
+ CbWriter Response;
+ Response.BeginArray();
+
+ IterateProjects([&Response](ProjectStore::Project& Prj) {
+ Response.BeginObject();
+ Response << "Id"sv << Prj.Identifier;
+ Response << "RootDir"sv << Prj.RootDir.string();
+ Response << "ProjectRootDir"sv << Prj.ProjectRootDir;
+ Response << "EngineRootDir"sv << Prj.EngineRootDir;
+ Response << "ProjectFilePath"sv << Prj.ProjectFilePath;
+ Response.EndObject();
+ });
+ Response.EndArray();
+ return Response.Save().AsArray();
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetProjectFiles(const std::string_view ProjectId, const std::string_view OplogId, bool FilterClient, CbObject& OutPayload)
+{
+ using namespace std::literals;
+
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Project files request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Project files for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ CbObjectWriter Response;
+ Response.BeginArray("files"sv);
+
+ FoundLog->IterateFileMap([&](const Oid& Id, const std::string_view& ServerPath, const std::string_view& ClientPath) {
+ Response.BeginObject();
+ Response << "id"sv << Id;
+ Response << "clientpath"sv << ClientPath;
+ if (!FilterClient)
+ {
+ Response << "serverpath"sv << ServerPath;
+ }
+ Response.EndObject();
+ });
+
+ Response.EndArray();
+ OutPayload = Response.Save();
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetChunkInfo(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view ChunkId,
+ CbObject& OutPayload)
+{
+ using namespace std::literals;
+
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+ if (ChunkId.size() != 2 * sizeof(Oid::OidBits))
+ {
+ return {HttpResponseCode::BadRequest,
+ fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)};
+ }
+
+ const Oid Obj = Oid::FromHexString(ChunkId);
+
+ IoBuffer Chunk = FoundLog->FindChunk(Obj);
+ if (!Chunk)
+ {
+ return {HttpResponseCode::NotFound, {}};
+ }
+
+ uint64_t ChunkSize = Chunk.GetSize();
+ if (Chunk.GetContentType() == HttpContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ bool IsCompressed = CompressedBuffer::ValidateCompressedHeader(Chunk, RawHash, RawSize);
+ ZEN_ASSERT(IsCompressed);
+ ChunkSize = RawSize;
+ }
+
+ CbObjectWriter Response;
+ Response << "size"sv << ChunkSize;
+ OutPayload = Response.Save();
+ return {HttpResponseCode::OK, {}};
+}
+
// Fetches a chunk (optionally a byte sub-range of it) from an oplog, encoded
// per the client's accepted content type: kBinary yields decompressed bytes,
// anything else yields the stored compressed representation.
std::pair<HttpResponseCode, std::string>
ProjectStore::GetChunkRange(const std::string_view ProjectId,
                            const std::string_view OplogId,
                            const std::string_view ChunkId,
                            uint64_t Offset,
                            uint64_t Size,
                            ZenContentType AcceptType,
                            IoBuffer& OutChunk)
{
    // Offset==0 together with Size==~0 means "the whole chunk"
    bool IsOffset = Offset != 0 || Size != ~(0ull);

    Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
    if (!Project)
    {
        return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)};
    }

    ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);

    if (!FoundLog)
    {
        return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)};
    }

    // Chunk ids are hex-encoded Oids -- two characters per byte
    if (ChunkId.size() != 2 * sizeof(Oid::OidBits))
    {
        return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)};
    }

    const Oid Obj = Oid::FromHexString(ChunkId);

    IoBuffer Chunk = FoundLog->FindChunk(Obj);
    if (!Chunk)
    {
        return {HttpResponseCode::NotFound, {}};
    }

    OutChunk = Chunk;
    HttpContentType ContentType = Chunk.GetContentType();

    if (Chunk.GetContentType() == HttpContentType::kCompressedBinary)
    {
        IoHash RawHash;
        uint64_t RawSize;
        CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunk)), RawHash, RawSize);
        ZEN_ASSERT(!Compressed.IsNull());

        if (IsOffset)
        {
            // Clamp the requested window to the raw (decompressed) size.
            // NOTE(review): Offset + Size can wrap for Offset > 0 with
            // Size == ~0ull, skipping the clamp -- confirm downstream
            // Decompress/CopyRange tolerate an oversized Size.
            if ((Offset + Size) > RawSize)
            {
                Size = RawSize - Offset;
            }

            if (AcceptType == HttpContentType::kBinary)
            {
                OutChunk = Compressed.Decompress(Offset, Size).AsIoBuffer();
                OutChunk.SetContentType(HttpContentType::kBinary);
            }
            else
            {
                // Value will be a range of compressed blocks that covers the requested range
                // The client will have to compensate for any offsets that do not land on an even block size multiple
                OutChunk = Compressed.CopyRange(Offset, Size).GetCompressed().Flatten().AsIoBuffer();
                OutChunk.SetContentType(HttpContentType::kCompressedBinary);
            }
        }
        else
        {
            if (AcceptType == HttpContentType::kBinary)
            {
                OutChunk = Compressed.Decompress().AsIoBuffer();
                OutChunk.SetContentType(HttpContentType::kBinary);
            }
            else
            {
                // Whole chunk in its stored compressed form
                OutChunk = Compressed.GetCompressed().Flatten().AsIoBuffer();
                OutChunk.SetContentType(HttpContentType::kCompressedBinary);
            }
        }
    }
    else if (IsOffset)
    {
        // Uncompressed chunk: serve the requested window directly,
        // clamped to the chunk size.
        if ((Offset + Size) > Chunk.GetSize())
        {
            Size = Chunk.GetSize() - Offset;
        }
        OutChunk = IoBuffer(std::move(Chunk), Offset, Size);
        OutChunk.SetContentType(ContentType);
    }

    return {HttpResponseCode::OK, {}};
}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetChunk(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view Cid,
+ ZenContentType AcceptType,
+ IoBuffer& OutChunk)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (Cid.length() != IoHash::StringLength)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, Cid)};
+ }
+
+ const IoHash Hash = IoHash::FromHexString(Cid);
+ OutChunk = m_CidStore.FindChunkByCid(Hash);
+
+ if (!OutChunk)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("chunk - '{}' MISSING", Cid)};
+ }
+
+ if (AcceptType == ZenContentType::kUnknownContentType || AcceptType == ZenContentType::kBinary)
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(OutChunk));
+ OutChunk = Compressed.Decompress().AsIoBuffer();
+ OutChunk.SetContentType(ZenContentType::kBinary);
+ }
+ else
+ {
+ OutChunk.SetContentType(ZenContentType::kCompressedBinary);
+ }
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::PutChunk(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view Cid,
+ ZenContentType ContentType,
+ IoBuffer&& Chunk)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (Cid.length() != IoHash::StringLength)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk put request for invalid chunk hash '{}'", Cid)};
+ }
+
+ const IoHash Hash = IoHash::FromHexString(Cid);
+
+ if (ContentType != HttpContentType::kCompressedBinary)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid content type for chunk '{}'", Cid)};
+ }
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize);
+ if (RawHash != Hash)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid payload format for chunk '{}'", Cid)};
+ }
+
+ CidStore::InsertResult Result = m_CidStore.AddChunk(Chunk, Hash);
+ return {Result.New ? HttpResponseCode::Created : HttpResponseCode::OK, {}};
+}
+
// Ingests an oplog "container" uploaded by a client. The container object is
// walked by SaveOplogContainer; any blocks/attachments whose content is not
// yet present in the CID store are collected and reported back to the client
// in the "need" array of the response so it can upload them separately.
std::pair<HttpResponseCode, std::string>
ProjectStore::WriteOplog(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload, CbObject& OutResponse)
{
    Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
    if (!Project)
    {
        return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown project '{}'", ProjectId)};
    }

    ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);

    if (!Oplog)
    {
        return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)};
    }

    CbObject ContainerObject = LoadCompactBinaryObject(Payload);
    if (!ContainerObject)
    {
        return {HttpResponseCode::BadRequest, "Invalid payload format"};
    }

    CidStore& ChunkStore = m_CidStore;
    // The callbacks below may run concurrently (presumably from worker
    // threads inside SaveOplogContainer -- confirm), so the needed-attachment
    // set is guarded by its own lock.
    RwLock AttachmentsLock;
    std::unordered_set<IoHash, IoHash::Hasher> Attachments;

    // Does the CID store already hold this attachment?
    auto HasAttachment = [&ChunkStore](const IoHash& RawHash) { return ChunkStore.ContainsChunk(RawHash); };
    // A whole block (or, when no block hash is given, its individual chunks)
    // is missing and must be requested from the client.
    auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector<IoHash>&& ChunkHashes) {
        RwLock::ExclusiveLockScope _(AttachmentsLock);
        if (BlockHash != IoHash::Zero)
        {
            Attachments.insert(BlockHash);
        }
        else
        {
            Attachments.insert(ChunkHashes.begin(), ChunkHashes.end());
        }
    };
    // A single attachment is missing and must be requested from the client.
    auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) {
        RwLock::ExclusiveLockScope _(AttachmentsLock);
        Attachments.insert(RawHash);
    };

    RemoteProjectStore::Result RemoteResult = SaveOplogContainer(*Oplog, ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment);

    if (RemoteResult.ErrorCode)
    {
        return ConvertResult(RemoteResult);
    }

    // Tell the client which hashes it still needs to upload
    CbObjectWriter Cbo;
    Cbo.BeginArray("need");
    {
        for (const IoHash& Hash : Attachments)
        {
            ZEN_DEBUG("Need attachment {}", Hash);
            Cbo << Hash;
        }
    }
    Cbo.EndArray(); // "need"

    OutResponse = Cbo.Save();
    return {HttpResponseCode::OK, {}};
}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::ReadOplog(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const HttpServerRequest::QueryParams& Params,
+ CbObject& OutResponse)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);
+
+ if (!Oplog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ size_t MaxBlockSize = 128u * 1024u * 1024u;
+ if (auto Param = Params.GetValue("maxblocksize"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<size_t>(Param))
+ {
+ MaxBlockSize = Value.value();
+ }
+ }
+ size_t MaxChunkEmbedSize = 1024u * 1024u;
+ if (auto Param = Params.GetValue("maxchunkembedsize"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<size_t>(Param))
+ {
+ MaxChunkEmbedSize = Value.value();
+ }
+ }
+
+ CidStore& ChunkStore = m_CidStore;
+
+ RemoteProjectStore::LoadContainerResult ContainerResult = BuildContainer(
+ ChunkStore,
+ *Oplog,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ false,
+ [](CompressedBuffer&&, const IoHash) {},
+ [](const IoHash&) {},
+ [](const std::unordered_set<IoHash, IoHash::Hasher>) {});
+
+ OutResponse = std::move(ContainerResult.ContainerObject);
+ return ConvertResult(ContainerResult);
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::WriteBlock(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);
+
+ if (!Oplog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (!IterateBlock(std::move(Payload), [this](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) {
+ IoBuffer Compressed = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ m_CidStore.AddChunk(Compressed, AttachmentRawHash);
+ ZEN_DEBUG("Saved attachment {} from block, size {}", AttachmentRawHash, Compressed.GetSize());
+ }))
+ {
+ return {HttpResponseCode::BadRequest, "Invalid chunk in block"};
+ }
+
+ return {HttpResponseCode::OK, {}};
+}
+
// RPC entry point for project/oplog operations. Decodes the request payload
// (JSON, compact-binary object, or compact-binary package), resolves the
// target project and oplog, then dispatches on the "method" field:
// "import" / "export" (remote store sync), "getchunks" (bulk chunk fetch as
// a package), "putchunks" (bulk chunk upload from package attachments).
// All responses are written directly to HttpReq.
void
ProjectStore::Rpc(HttpServerRequest& HttpReq,
                  const std::string_view ProjectId,
                  const std::string_view OplogId,
                  IoBuffer&& Payload,
                  AuthMgr& AuthManager)
{
    using namespace std::literals;
    HttpContentType PayloadContentType = HttpReq.RequestContentType();
    // Package must outlive Cb: for kCbPackage requests Cb views the package's
    // object, and "putchunks" reads the package's attachments later on.
    CbPackage Package;
    CbObject Cb;
    switch (PayloadContentType)
    {
        // Text-ish content types are all treated as JSON
        case HttpContentType::kJSON:
        case HttpContentType::kUnknownContentType:
        case HttpContentType::kText:
        {
            std::string JsonText(reinterpret_cast<const char*>(Payload.GetData()), Payload.GetSize());
            Cb = LoadCompactBinaryFromJson(JsonText).AsObject();
            if (!Cb)
            {
                return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
                                             HttpContentType::kText,
                                             "Content format not supported, expected JSON format");
            }
        }
        break;
        case HttpContentType::kCbObject:
            Cb = LoadCompactBinaryObject(Payload);
            if (!Cb)
            {
                return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
                                             HttpContentType::kText,
                                             "Content format not supported, expected compact binary format");
            }
            break;
        case HttpContentType::kCbPackage:
            Package = ParsePackageMessage(Payload);
            Cb = Package.GetObject();
            if (!Cb)
            {
                return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
                                             HttpContentType::kText,
                                             "Content format not supported, expected package message format");
            }
            break;
        default:
            return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid request content type");
    }

    Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
    if (!Project)
    {
        return HttpReq.WriteResponse(HttpResponseCode::NotFound,
                                     HttpContentType::kText,
                                     fmt::format("Rpc oplog request for unknown project '{}'", ProjectId));
    }

    ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);

    if (!Oplog)
    {
        return HttpReq.WriteResponse(HttpResponseCode::NotFound,
                                     HttpContentType::kText,
                                     fmt::format("Rpc oplog request for unknown oplog '{}/{}'", ProjectId, OplogId));
    }

    std::string_view Method = Cb["method"sv].AsString();

    if (Method == "import")
    {
        std::pair<HttpResponseCode, std::string> Result = Import(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager);
        if (Result.second.empty())
        {
            return HttpReq.WriteResponse(Result.first);
        }
        return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
    }
    else if (Method == "export")
    {
        std::pair<HttpResponseCode, std::string> Result = Export(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager);
        if (Result.second.empty())
        {
            return HttpReq.WriteResponse(Result.first);
        }
        return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
    }
    else if (Method == "getchunks")
    {
        // Look each requested hash up in the CID store; found chunks are
        // returned as package attachments, and their hashes echoed in the
        // "chunks" array. Missing chunks are silently omitted.
        CbPackage ResponsePackage;
        {
            CbArrayView ChunksArray = Cb["chunks"sv].AsArrayView();
            CbObjectWriter ResponseWriter;
            ResponseWriter.BeginArray("chunks"sv);
            for (CbFieldView FieldView : ChunksArray)
            {
                IoHash RawHash = FieldView.AsHash();
                IoBuffer ChunkBuffer = m_CidStore.FindChunkByCid(RawHash);
                if (ChunkBuffer)
                {
                    ResponseWriter.AddHash(RawHash);
                    ResponsePackage.AddAttachment(
                        CbAttachment(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkBuffer)), RawHash));
                }
            }
            ResponseWriter.EndArray();
            ResponsePackage.SetObject(ResponseWriter.Save());
        }
        CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage, FormatFlags::kDefault);
        return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer);
    }
    else if (Method == "putchunks")
    {
        // Store every attachment of the request package into the CID store
        std::span<const CbAttachment> Attachments = Package.GetAttachments();
        for (const CbAttachment& Attachment : Attachments)
        {
            IoHash RawHash = Attachment.GetHash();
            CompressedBuffer Compressed = Attachment.AsCompressedBinary();
            m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), RawHash, CidStore::InsertMode::kCopyOnly);
        }
        return HttpReq.WriteResponse(HttpResponseCode::OK);
    }
    // NOTE(review): unknown methods respond with 200 OK plus an error text --
    // confirm whether a 4xx status was intended here.
    return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("Unknown rpc method '{}'", Method));
}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::Export(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager)
+{
+ using namespace std::literals;
+
+ size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u);
+ size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u);
+ bool Force = Params["force"sv].AsBool(false);
+
+ std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult =
+ CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize);
+
+ if (RemoteStoreResult.first == nullptr)
+ {
+ return {HttpResponseCode::BadRequest, RemoteStoreResult.second};
+ }
+ std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first);
+ RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo();
+
+ ZEN_INFO("Saving oplog '{}/{}' to {}, maxblocksize {}, maxchunkembedsize {}",
+ Project.Identifier,
+ Oplog.OplogId(),
+ StoreInfo.Description,
+ NiceBytes(MaxBlockSize),
+ NiceBytes(MaxChunkEmbedSize));
+
+ RemoteProjectStore::Result Result = SaveOplog(m_CidStore,
+ *RemoteStore,
+ Oplog,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ StoreInfo.CreateBlocks,
+ StoreInfo.UseTempBlockFiles,
+ Force);
+
+ return ConvertResult(Result);
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::Import(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager)
+{
+ using namespace std::literals;
+
+ size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u);
+ size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u);
+ bool Force = Params["force"sv].AsBool(false);
+
+ std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult =
+ CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize);
+
+ if (RemoteStoreResult.first == nullptr)
+ {
+ return {HttpResponseCode::BadRequest, RemoteStoreResult.second};
+ }
+ std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first);
+ RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo();
+
+ ZEN_INFO("Loading oplog '{}/{}' from {}", Project.Identifier, Oplog.OplogId(), StoreInfo.Description);
+ RemoteProjectStore::Result Result = LoadOplog(m_CidStore, *RemoteStore, Oplog, Force);
+ return ConvertResult(Result);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpProjectService::HttpProjectService(CidStore& Store, ProjectStore* Projects, HttpStatsService& StatsService, AuthMgr& AuthMgr)
+: m_Log(logging::Get("project"))
+, m_CidStore(Store)
+, m_ProjectStore(Projects)
+, m_StatsService(StatsService)
+, m_AuthMgr(AuthMgr)
+{
+ using namespace std::literals;
+
+ m_StatsService.RegisterHandler("prj", *this);
+
+ m_Router.AddPattern("project", "([[:alnum:]_.]+)");
+ m_Router.AddPattern("log", "([[:alnum:]_.]+)");
+ m_Router.AddPattern("op", "([[:digit:]]+?)");
+ m_Router.AddPattern("chunk", "([[:xdigit:]]{24})");
+ m_Router.AddPattern("hash", "([[:xdigit:]]{40})");
+
+ m_Router.RegisterRoute(
+ "",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "list",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/batch",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ // Parse Request
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ BinaryReader Reader(Payload);
+
+ struct RequestHeader
+ {
+ enum
+ {
+ kMagic = 0xAAAA'77AC
+ };
+ uint32_t Magic;
+ uint32_t ChunkCount;
+ uint32_t Reserved1;
+ uint32_t Reserved2;
+ };
+
+ struct RequestChunkEntry
+ {
+ Oid ChunkId;
+ uint32_t CorrelationId;
+ uint64_t Offset;
+ uint64_t RequestBytes;
+ };
+
+ if (Payload.Size() <= sizeof(RequestHeader))
+ {
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ RequestHeader RequestHdr;
+ Reader.Read(&RequestHdr, sizeof RequestHdr);
+
+ if (RequestHdr.Magic != RequestHeader::kMagic)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ std::vector<RequestChunkEntry> RequestedChunks;
+ RequestedChunks.resize(RequestHdr.ChunkCount);
+ Reader.Read(RequestedChunks.data(), sizeof(RequestChunkEntry) * RequestHdr.ChunkCount);
+
+ // Make Response
+
+ struct ResponseHeader
+ {
+ uint32_t Magic = 0xbada'b00f;
+ uint32_t ChunkCount;
+ uint32_t Reserved1 = 0;
+ uint32_t Reserved2 = 0;
+ };
+
+ struct ResponseChunkEntry
+ {
+ uint32_t CorrelationId;
+ uint32_t Flags = 0;
+ uint64_t ChunkSize;
+ };
+
+ std::vector<IoBuffer> OutBlobs;
+ OutBlobs.emplace_back(sizeof(ResponseHeader) + RequestHdr.ChunkCount * sizeof(ResponseChunkEntry));
+ for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex)
+ {
+ const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex];
+ IoBuffer FoundChunk = FoundLog->FindChunk(RequestedChunk.ChunkId);
+ if (FoundChunk)
+ {
+ if (RequestedChunk.Offset > 0 || RequestedChunk.RequestBytes < uint64_t(-1))
+ {
+ uint64_t Offset = RequestedChunk.Offset;
+ if (Offset > FoundChunk.Size())
+ {
+ Offset = FoundChunk.Size();
+ }
+ uint64_t Size = RequestedChunk.RequestBytes;
+ if ((Offset + Size) > FoundChunk.Size())
+ {
+ Size = FoundChunk.Size() - Offset;
+ }
+ FoundChunk = IoBuffer(FoundChunk, Offset, Size);
+ }
+ }
+ OutBlobs.emplace_back(std::move(FoundChunk));
+ }
+ uint8_t* ResponsePtr = reinterpret_cast<uint8_t*>(OutBlobs[0].MutableData());
+ ResponseHeader ResponseHdr;
+ ResponseHdr.ChunkCount = RequestHdr.ChunkCount;
+ memcpy(ResponsePtr, &ResponseHdr, sizeof(ResponseHdr));
+ ResponsePtr += sizeof(ResponseHdr);
+ for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex)
+ {
+ // const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex];
+ const IoBuffer& FoundChunk(OutBlobs[ChunkIndex + 1]);
+ ResponseChunkEntry ResponseChunk;
+ ResponseChunk.CorrelationId = ChunkIndex;
+ if (FoundChunk)
+ {
+ ResponseChunk.ChunkSize = FoundChunk.Size();
+ }
+ else
+ {
+ ResponseChunk.ChunkSize = uint64_t(-1);
+ }
+ memcpy(ResponsePtr, &ResponseChunk, sizeof(ResponseChunk));
+ ResponsePtr += sizeof(ResponseChunk);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, OutBlobs);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/files",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ // File manifest fetch, returns the client file list
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+
+ const bool FilterClient = Params.GetValue("filter"sv) == "client"sv;
+
+ CbObject ResponsePayload;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->GetProjectFiles(ProjectId, OplogId, FilterClient, ResponsePayload);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{chunk}/info",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& ChunkId = Req.GetCapture(3);
+
+ CbObject ResponsePayload;
+ std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->GetChunkInfo(ProjectId, OplogId, ChunkId, ResponsePayload);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload);
+ }
+ else if (Result.first == HttpResponseCode::NotFound)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{chunk}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& ChunkId = Req.GetCapture(3);
+
+ uint64_t Offset = 0;
+ uint64_t Size = ~(0ull);
+
+ auto QueryParms = Req.ServerRequest().GetQueryParams();
+
+ if (auto OffsetParm = QueryParms.GetValue("offset"); OffsetParm.empty() == false)
+ {
+ if (auto OffsetVal = ParseInt<uint64_t>(OffsetParm))
+ {
+ Offset = OffsetVal.value();
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ }
+
+ if (auto SizeParm = QueryParms.GetValue("size"); SizeParm.empty() == false)
+ {
+ if (auto SizeVal = ParseInt<uint64_t>(SizeParm))
+ {
+ Size = SizeVal.value();
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ }
+
+ HttpContentType AcceptType = HttpReq.AcceptContentType();
+
+ IoBuffer Chunk;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->GetChunkRange(ProjectId, OplogId, ChunkId, Offset, Size, AcceptType, Chunk);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' '{}'", ProjectId, OplogId, ChunkId, ToString(Chunk.GetContentType()));
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Chunk.GetContentType(), Chunk);
+ }
+ else if (Result.first == HttpResponseCode::NotFound)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{hash}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& Cid = Req.GetCapture(3);
+ HttpContentType AcceptType = HttpReq.AcceptContentType();
+ HttpContentType RequestType = HttpReq.RequestContentType();
+
+ switch (Req.ServerRequest().RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ IoBuffer Value;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->GetChunk(ProjectId, OplogId, Cid, AcceptType, Value);
+
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Value.GetContentType(), Value);
+ }
+ else if (Result.first == HttpResponseCode::NotFound)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, Cid);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ }
+ case HttpVerb::kPost:
+ {
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->PutChunk(ProjectId, OplogId, Cid, RequestType, HttpReq.ReadPayload());
+ if (Result.first == HttpResponseCode::OK || Result.first == HttpResponseCode::Created)
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ }
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/prep",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ // This operation takes a list of referenced hashes and decides which
+ // chunks are not present on this server. This list is then returned in
+ // the "need" list in the response
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject RequestObject = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ for (auto Entry : RequestObject["have"sv])
+ {
+ const IoHash FileHash = Entry.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ ZEN_DEBUG("prep - NEED: {}", FileHash);
+
+ NeedList.push_back(FileHash);
+ }
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+ CbObject Response = Cbo.Save();
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/new",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+
+ bool IsUsingSalt = false;
+ IoHash SaltHash = IoHash::Zero;
+
+ if (std::string_view SaltParam = Params.GetValue("salt"); SaltParam.empty() == false)
+ {
+ const uint32_t Salt = std::stoi(std::string(SaltParam));
+ SaltHash = IoHash::HashBuffer(&Salt, sizeof Salt);
+ IsUsingSalt = true;
+ }
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog& Oplog = *FoundLog;
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+
+ // This will attempt to open files which may not exist for the case where
+ // the prep step rejected the chunk. This should be fixed since there's
+ // a performance cost associated with any file system activity
+
+ bool IsValid = true;
+ std::vector<IoHash> MissingChunks;
+
+ CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer {
+ if (m_CidStore.ContainsChunk(Hash))
+ {
+ // Return null attachment as we already have it, no point in reading it and storing it again
+ return {};
+ }
+
+ IoHash AttachmentId;
+ if (IsUsingSalt)
+ {
+ IoHash AttachmentSpec[]{SaltHash, Hash};
+ AttachmentId = IoHash::HashBuffer(MakeMemoryView(AttachmentSpec));
+ }
+ else
+ {
+ AttachmentId = Hash;
+ }
+
+ std::filesystem::path AttachmentPath = Oplog.TempPath() / AttachmentId.ToHexString();
+ if (IoBuffer Data = IoBufferBuilder::MakeFromTemporaryFile(AttachmentPath))
+ {
+ return SharedBuffer(std::move(Data));
+ }
+ else
+ {
+ IsValid = false;
+ MissingChunks.push_back(Hash);
+
+ return {};
+ }
+ };
+
+ CbPackage Package;
+
+ if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver))
+ {
+ std::filesystem::path BadPackagePath =
+ Oplog.TempPath() / "bad_packages"sv / fmt::format("session{}_request{}"sv, HttpReq.SessionId(), HttpReq.RequestId());
+
+ ZEN_WARN("Received malformed package! Saving payload to '{}'", BadPackagePath);
+
+ WriteFile(BadPackagePath, Payload);
+
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package");
+ }
+
+ if (!IsValid)
+ {
+ // TODO: emit diagnostics identifying missing chunks
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Missing chunk reference");
+ }
+
+ CbObject Core = Package.GetObject();
+
+ if (!Core["key"sv])
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "No oplog entry key specified");
+ }
+
+ // Write core to oplog
+
+ const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Package);
+
+ if (OpLsn == ProjectStore::Oplog::kInvalidOp)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ ZEN_DEBUG("'{}/{}' op #{} ({}) - '{}'", ProjectId, OplogId, OpLsn, NiceBytes(Payload.Size()), Core["key"sv].AsString());
+
+ HttpReq.WriteResponse(HttpResponseCode::Created);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{op}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const std::string& ProjectId = Req.GetCapture(1);
+ const std::string& OplogId = Req.GetCapture(2);
+ const std::string& OpIdString = Req.GetCapture(3);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog& Oplog = *FoundLog;
+
+ if (const std::optional<int32_t> OpId = ParseInt<uint32_t>(OpIdString))
+ {
+ if (std::optional<CbObject> MaybeOp = Oplog.GetOpByIndex(OpId.value()))
+ {
+ CbObject& Op = MaybeOp.value();
+ if (Req.ServerRequest().AcceptContentType() == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+ Package.SetObject(Op);
+
+ Op.IterateAttachments([&](CbFieldView FieldView) {
+ const IoHash AttachmentHash = FieldView.AsAttachment();
+ IoBuffer Payload = m_CidStore.FindChunkByCid(AttachmentHash);
+
+ // We force this for now as content type is not consistently tracked (will
+ // be fixed in CidStore refactor)
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+
+ if (Payload)
+ {
+ switch (Payload.GetContentType())
+ {
+ case ZenContentType::kCbObject:
+ if (CbObject Object = LoadCompactBinaryObject(Payload))
+ {
+ Package.AddAttachment(CbAttachment(Object));
+ }
+ else
+ {
+ // Error - malformed object
+
+ ZEN_WARN("malformed object returned for {}", AttachmentHash);
+ }
+ break;
+
+ case ZenContentType::kCompressedBinary:
+ if (CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Payload)))
+ {
+ Package.AddAttachment(CbAttachment(Compressed, AttachmentHash));
+ }
+ else
+ {
+ // Error - not compressed!
+
+ ZEN_WARN("invalid compressed binary returned for {}", AttachmentHash);
+ }
+ break;
+
+ default:
+ Package.AddAttachment(CbAttachment(SharedBuffer(Payload)));
+ break;
+ }
+ }
+ });
+
+ return HttpReq.WriteResponse(HttpResponseCode::Accepted, Package);
+ }
+ else
+ {
+ // Client cannot accept a package, so we only send the core object
+ return HttpReq.WriteResponse(HttpResponseCode::Accepted, Op);
+ }
+ }
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}",
+ [this](HttpRouterRequest& Req) {
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+
+ if (!Project)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("project {} not found", ProjectId));
+ }
+
+ switch (Req.ServerRequest().RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId);
+
+ if (!OplogIt)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("oplog {} not found in project {}", OplogId, ProjectId));
+ }
+
+ ProjectStore::Oplog& Log = *OplogIt;
+
+ CbObjectWriter Cb;
+ Cb << "id"sv << Log.OplogId() << "project"sv << Project->Identifier << "tempdir"sv << Log.TempPath().c_str()
+ << "markerpath"sv << Log.MarkerPath().c_str() << "totalsize"sv << Log.TotalSize() << "opcount"
+ << Log.OplogCount() << "expired"sv << Log.IsExpired();
+
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cb.Save());
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ std::filesystem::path OplogMarkerPath;
+ if (CbObject Params = Req.ServerRequest().ReadPayloadObject())
+ {
+ OplogMarkerPath = Params["gcpath"sv].AsString();
+ }
+
+ ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId);
+
+ if (!OplogIt)
+ {
+ if (!Project->NewOplog(OplogId, OplogMarkerPath))
+ {
+ // TODO: indicate why the operation failed!
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ ZEN_INFO("established oplog '{}/{}', gc marker file at '{}'", ProjectId, OplogId, OplogMarkerPath);
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::Created);
+ }
+
+ // I guess this should ultimately be used to execute RPCs but for now, it
+ // does absolutely nothing
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ {
+ ZEN_INFO("deleting oplog '{}/{}'", ProjectId, OplogId);
+
+ Project->DeleteOplog(OplogId);
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::OK);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost | HttpVerb::kGet | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/entries",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter Response;
+
+ if (FoundLog->OplogCount() > 0)
+ {
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+
+ if (auto OpKey = Params.GetValue("opkey"); !OpKey.empty())
+ {
+ Oid OpKeyId = OpKeyStringAsOId(OpKey);
+ std::optional<CbObject> Op = FoundLog->GetOpByKey(OpKeyId);
+
+ if (Op.has_value())
+ {
+ Response << "entry"sv << Op.value();
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+ else
+ {
+ Response.BeginArray("entries"sv);
+
+ FoundLog->IterateOplog([&Response](CbObject Op) { Response << Op; });
+
+ Response.EndArray();
+ }
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}",
+ [this](HttpRouterRequest& Req) {
+ const std::string ProjectId = Req.GetCapture(1);
+
+ switch (Req.ServerRequest().RequestVerb())
+ {
+ case HttpVerb::kPost:
+ {
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+ CbObject Params = LoadCompactBinaryObject(Payload);
+ std::string_view Id = Params["id"sv].AsString();
+ std::string_view Root = Params["root"sv].AsString();
+ std::string_view EngineRoot = Params["engine"sv].AsString();
+ std::string_view ProjectRoot = Params["project"sv].AsString();
+ std::string_view ProjectFilePath = Params["projectfile"sv].AsString();
+
+ const std::filesystem::path BasePath = m_ProjectStore->BasePath() / ProjectId;
+ m_ProjectStore->NewProject(BasePath, ProjectId, Root, EngineRoot, ProjectRoot, ProjectFilePath);
+
+ ZEN_INFO("established project - {} (id: '{}', roots: '{}', '{}', '{}', '{}'{})",
+ ProjectId,
+ Id,
+ Root,
+ EngineRoot,
+ ProjectRoot,
+ ProjectFilePath,
+ ProjectFilePath.empty() ? ", project will not be GCd due to empty project file path" : "");
+
+ Req.ServerRequest().WriteResponse(HttpResponseCode::Created);
+ }
+ break;
+
+ case HttpVerb::kGet:
+ {
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+
+ if (!Project)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("project {} not found", ProjectId));
+ }
+
+ std::vector<std::string> OpLogs = Project->ScanForOplogs();
+
+ CbObjectWriter Response;
+ Response << "id"sv << Project->Identifier;
+ Response << "root"sv << PathToUtf8(Project->RootDir);
+ Response << "engine"sv << PathToUtf8(Project->EngineRootDir);
+ Response << "project"sv << PathToUtf8(Project->ProjectRootDir);
+ Response << "projectfile"sv << PathToUtf8(Project->ProjectFilePath);
+
+ Response.BeginArray("oplogs"sv);
+ for (const std::string& OplogId : OpLogs)
+ {
+ Response.BeginObject();
+ Response << "id"sv << OplogId;
+ Response.EndObject();
+ }
+ Response.EndArray(); // oplogs
+
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save());
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ {
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+
+ if (!Project)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("project {} not found", ProjectId));
+ }
+
+ ZEN_INFO("deleting project '{}'", ProjectId);
+ if (!m_ProjectStore->DeleteProject(ProjectId))
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::Locked,
+ HttpContentType::kText,
+ fmt::format("project {} is in use", ProjectId));
+ }
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kDelete);
+
+ // Push a oplog container
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/save",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ if (HttpReq.RequestContentType() != HttpContentType::kCbObject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid content type");
+ }
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+
+ CbObject Response;
+ std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->WriteOplog(ProjectId, OplogId, std::move(Payload), Response);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kPost);
+
+ // Pull a oplog container
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/load",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ if (HttpReq.AcceptContentType() != HttpContentType::kCbObject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid accept content type");
+ }
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+
+ CbObject Response;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->ReadOplog(ProjectId, OplogId, Req.ServerRequest().GetQueryParams(), Response);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet);
+
+ // Do an rpc style operation on project/oplog
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/rpc",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+
+ m_ProjectStore->Rpc(HttpReq, ProjectId, OplogId, std::move(Payload), m_AuthMgr);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "details\\$",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) {
+ Project.IterateOplogs([&](ProjectStore::Oplog& Oplog) {
+ Oplog.IterateOplogWithKey(
+ [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) {
+ CSVWriteOp(m_CidStore,
+ Project.Identifier,
+ Oplog.OplogId(),
+ Details,
+ AttachmentDetails,
+ LSN,
+ Key,
+ Op,
+ CSVWriter);
+ });
+ });
+ });
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("projects");
+ {
+ m_ProjectStore->DiscoverProjects();
+
+ m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) {
+ std::vector<std::string> OpLogs = Project.ScanForOplogs();
+ CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo);
+ });
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "details\\$/{project}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId);
+ if (!FoundProject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ ProjectStore::Project& Project = *FoundProject.Get();
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ FoundProject->IterateOplogs([&](ProjectStore::Oplog& Oplog) {
+ Oplog.IterateOplogWithKey([this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN,
+ const Oid& Key,
+ CbObject Op) {
+ CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter);
+ });
+ });
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ std::vector<std::string> OpLogs = FoundProject->ScanForOplogs();
+ Cbo.BeginArray("projects");
+ {
+ CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "details\\$/{project}/{log}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId);
+ if (!FoundProject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Project& Project = *FoundProject.Get();
+ ProjectStore::Oplog& Oplog = *FoundLog;
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ Oplog.IterateOplogWithKey(
+ [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) {
+ CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter);
+ });
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("oplogs");
+ {
+ CbWriteOplog(m_CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "details\\$/{project}/{log}/{chunk}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& ChunkId = Req.GetCapture(3);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId);
+ if (!FoundProject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ if (ChunkId.size() != 2 * sizeof(Oid::OidBits))
+ {
+ return HttpReq.WriteResponse(
+ HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId));
+ }
+
+ const Oid ObjId = Oid::FromHexString(ChunkId);
+ ProjectStore::Project& Project = *FoundProject.Get();
+ ProjectStore::Oplog& Oplog = *FoundLog;
+
+ int LSN = Oplog.GetOpIndexByKey(ObjId);
+ if (LSN == -1)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ std::optional<CbObject> Op = Oplog.GetOpByIndex(LSN);
+ if (!Op.has_value())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, ObjId, Op.value(), CSVWriter);
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("ops");
+ {
+ CbWriteOp(m_CidStore, Details, OpDetails, AttachmentDetails, LSN, ObjId, Op.value(), Cbo);
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+}
+
+// Tear down the service: detach the "prj" stats handler so the stats service
+// no longer calls back into this (soon to be destroyed) object.
+HttpProjectService::~HttpProjectService()
+{
+	m_StatsService.UnregisterHandler("prj", *this);
+}
+
+// URI prefix owned by this service; requests below "/prj/" are routed to
+// HandleRequest.
+const char*
+HttpProjectService::BaseUri() const
+{
+	return "/prj/";
+}
+
+// Dispatch an incoming request through the route table. An unmatched route is
+// only logged here; no error response is written by this function.
+void
+HttpProjectService::HandleRequest(HttpServerRequest& Request)
+{
+	if (m_Router.HandleRequest(Request) == false)
+	{
+		ZEN_WARN("No route found for {0}", Request.RelativeUri());
+	}
+}
+
+// Serve the "prj" stats payload: a compact-binary document containing the
+// project store's disk/memory usage under "store.size" and the CID store's
+// size breakdown (tiny/small/large/total) under "cid.size".
+void
+HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq)
+{
+	const GcStorageSize StoreSize = m_ProjectStore->StorageSize();
+	const CidStoreSize  CidSize	  = m_CidStore.TotalSize();
+
+	CbObjectWriter Cbo;
+	Cbo.BeginObject("store");
+	{
+		Cbo.BeginObject("size");
+		{
+			Cbo << "disk" << StoreSize.DiskSize;
+			Cbo << "memory" << StoreSize.MemorySize;
+		}
+		Cbo.EndObject();
+	}
+	Cbo.EndObject();
+
+	Cbo.BeginObject("cid");
+	{
+		Cbo.BeginObject("size");
+		{
+			Cbo << "tiny" << CidSize.TinySize;
+			Cbo << "small" << CidSize.SmallSize;
+			Cbo << "large" << CidSize.LargeSize;
+			Cbo << "total" << CidSize.TotalSize;
+		}
+		Cbo.EndObject();
+	}
+	Cbo.EndObject();
+
+	return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+// Helpers shared by the project-store test cases below.
+namespace testutils {
+	using namespace std::literals;
+
+	// Render an Oid as its string form via the type's own ToString.
+	std::string OidAsString(const Oid& Id)
+	{
+		StringBuilder<25> OidStringBuilder;
+		Id.ToString(OidStringBuilder);
+		return OidStringBuilder.ToString();
+	}
+
+	// Build a CbPackage representing one oplog entry: the core object carries
+	// the op key (as a string) and, when attachments are supplied, a
+	// "bulkdata" array with one {id, type, data} record per attachment. Each
+	// attachment is also registered on the package itself.
+	CbPackage CreateOplogPackage(const Oid& Id, const std::span<const std::pair<Oid, CompressedBuffer>>& Attachments)
+	{
+		CbPackage Package;
+		CbObjectWriter Object;
+		Object << "key"sv << OidAsString(Id);
+		if (!Attachments.empty())
+		{
+			Object.BeginArray("bulkdata");
+			for (const auto& Attachment : Attachments)
+			{
+				CbAttachment Attach(Attachment.second, Attachment.second.DecodeRawHash());
+				Object.BeginObject();
+				Object << "id"sv << Attachment.first;
+				Object << "type"sv
+					   << "Standard"sv;
+				Object << "data"sv << Attach;
+				Object.EndObject();
+
+				Package.AddAttachment(Attach);
+			}
+			Object.EndArray();
+		}
+		Package.SetObject(Object.Save());
+		return Package;
+	};
+
+	// Create one compressed attachment per requested size, each paired with a
+	// fresh Oid. Payload bytes follow a deterministic uint16 ramp pattern so
+	// the data is reproducible (and compressible); an odd trailing byte is
+	// filled separately.
+	std::vector<std::pair<Oid, CompressedBuffer>> CreateAttachments(const std::span<const size_t>& Sizes)
+	{
+		std::vector<std::pair<Oid, CompressedBuffer>> Result;
+		Result.reserve(Sizes.size());
+		for (size_t Size : Sizes)
+		{
+			std::vector<uint8_t> Data;
+			Data.resize(Size);
+			uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data());
+			for (size_t Idx = 0; Idx < Size / 2; ++Idx)
+			{
+				DataPtr[Idx] = static_cast<uint16_t>(Idx % 0xffffu);
+			}
+			if (Size & 1)
+			{
+				Data[Size - 1] = static_cast<uint8_t>((Size - 1) & 0xff);
+			}
+			CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()));
+			Result.emplace_back(std::pair<Oid, CompressedBuffer>(Oid::NewOid(), Compressed));
+		}
+		return Result;
+	}
+
+	// Translate a raw (uncompressed) offset into an offset relative to the
+	// start of its compression block (RawOffset % BlockSize). Returns 0 when
+	// the offset is 0 or the buffer's compression parameters are unavailable.
+	uint64 GetCompressedOffset(const CompressedBuffer& Buffer, uint64 RawOffset)
+	{
+		if (RawOffset > 0)
+		{
+			uint64 BlockSize = 0;
+			OodleCompressor Compressor;
+			OodleCompressionLevel CompressionLevel;
+			if (!Buffer.TryGetCompressParameters(Compressor, CompressionLevel, BlockSize))
+			{
+				return 0;
+			}
+			return BlockSize > 0 ? RawOffset % BlockSize : 0;
+		}
+		return 0;
+	}
+
+} // namespace testutils
+
+// Create a project, delete it through the store, and verify that its on-disk
+// state is gone afterwards.
+TEST_CASE("project.store.create")
+{
+	using namespace std::literals;
+
+	ScopedTemporaryDirectory TempDir;
+
+	GcManager Gc;
+	CidStore CidStore(Gc);
+	CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+	CidStore.Initialize(CidConfig);
+
+	std::string_view ProjectName("proj1"sv);
+	std::filesystem::path BasePath = TempDir.Path() / "projectstore";
+	ProjectStore ProjectStore(CidStore, BasePath, Gc);
+	std::filesystem::path RootDir = TempDir.Path() / "root";
+	std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+	std::filesystem::path ProjectRootDir = TempDir.Path() / "game";
+	std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject";
+
+	Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / ProjectName,
+															   ProjectName,
+															   RootDir.string(),
+															   EngineRootDir.string(),
+															   ProjectRootDir.string(),
+															   ProjectFilePath.string()));
+	CHECK(ProjectStore.DeleteProject(ProjectName));
+	CHECK(!Project->Exists(BasePath));
+}
+
+// Exercise object lifetimes around deletion: after PrepareForDelete the oplog
+// can no longer be opened by name, but a previously obtained raw Oplog pointer
+// and the Ref-held Project remain safe to access.
+TEST_CASE("project.store.lifetimes")
+{
+	using namespace std::literals;
+
+	ScopedTemporaryDirectory TempDir;
+
+	GcManager Gc;
+	CidStore CidStore(Gc);
+	CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+	CidStore.Initialize(CidConfig);
+
+	std::filesystem::path BasePath = TempDir.Path() / "projectstore";
+	ProjectStore ProjectStore(CidStore, BasePath, Gc);
+	std::filesystem::path RootDir = TempDir.Path() / "root";
+	std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+	std::filesystem::path ProjectRootDir = TempDir.Path() / "game";
+	std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject";
+
+	Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / "proj1"sv,
+															   "proj1"sv,
+															   RootDir.string(),
+															   EngineRootDir.string(),
+															   ProjectRootDir.string(),
+															   ProjectFilePath.string()));
+	ProjectStore::Oplog* Oplog = Project->NewOplog("oplog1", {});
+	CHECK(Oplog != nullptr);
+
+	std::filesystem::path DeletePath;
+	CHECK(Project->PrepareForDelete(DeletePath));
+	CHECK(!DeletePath.empty());
+	CHECK(Project->OpenOplog("oplog1") == nullptr);
+	// Oplog is now invalid, but pointer can still be accessed since we store old oplog pointers
+	CHECK(Oplog->OplogCount() == 0);
+	// Project is still valid since we have a Ref to it
+	CHECK(Project->Identifier == "proj1"sv);
+}
+
+// Garbage-collection behavior across two projects: while both .uproject files
+// exist, GatherReferences reports both projects' CID references (14 total)
+// and CollectGarbage keeps both projects. After removing proj1's .uproject
+// file, only proj2's references (7) remain and proj1 is collected.
+TEST_CASE("project.store.gc")
+{
+	using namespace std::literals;
+	using namespace testutils;
+
+	ScopedTemporaryDirectory TempDir;
+
+	GcManager Gc;
+	CidStore CidStore(Gc);
+	CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+	CidStore.Initialize(CidConfig);
+
+	std::filesystem::path BasePath = TempDir.Path() / "projectstore";
+	ProjectStore ProjectStore(CidStore, BasePath, Gc);
+	std::filesystem::path RootDir = TempDir.Path() / "root";
+	std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+
+	// Each project gets a real (empty) .uproject file on disk; its presence is
+	// what keeps the project alive across GC below.
+	std::filesystem::path Project1RootDir = TempDir.Path() / "game1";
+	std::filesystem::path Project1FilePath = TempDir.Path() / "game1" / "game.uproject";
+	{
+		CreateDirectories(Project1FilePath.parent_path());
+		BasicFile ProjectFile;
+		ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate);
+	}
+
+	std::filesystem::path Project2RootDir = TempDir.Path() / "game2";
+	std::filesystem::path Project2FilePath = TempDir.Path() / "game2" / "game.uproject";
+	{
+		CreateDirectories(Project2FilePath.parent_path());
+		BasicFile ProjectFile;
+		ProjectFile.Open(Project2FilePath, BasicFile::Mode::kTruncate);
+	}
+
+	{
+		Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv,
+																	"proj1"sv,
+																	RootDir.string(),
+																	EngineRootDir.string(),
+																	Project1RootDir.string(),
+																	Project1FilePath.string()));
+		ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1", {});
+		CHECK(Oplog != nullptr);
+
+		// 7 attachments in total for proj1 (0 + 1 + 4 + 2).
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {}));
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{77})));
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{7123, 583, 690, 99})));
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{55, 122})));
+	}
+
+	{
+		Ref<ProjectStore::Project> Project2(ProjectStore.NewProject(BasePath / "proj2"sv,
+																	"proj2"sv,
+																	RootDir.string(),
+																	EngineRootDir.string(),
+																	Project2RootDir.string(),
+																	Project2FilePath.string()));
+		ProjectStore::Oplog* Oplog = Project2->NewOplog("oplog1", {});
+		CHECK(Oplog != nullptr);
+
+		// 7 attachments in total for proj2 as well.
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {}));
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{177})));
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{9123, 383, 590, 96})));
+		Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{535, 221})));
+	}
+
+	{
+		// GC cutoff one day in the past: nothing is expired yet.
+		GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+		ProjectStore.GatherReferences(GcCtx);
+		size_t RefCount = 0;
+		GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; });
+		CHECK(RefCount == 14);
+		ProjectStore.CollectGarbage(GcCtx);
+		CHECK(ProjectStore.OpenProject("proj1"sv));
+		CHECK(ProjectStore.OpenProject("proj2"sv));
+	}
+
+	// Removing the .uproject file should make proj1 collectable.
+	std::filesystem::remove(Project1FilePath);
+
+	{
+		GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+		ProjectStore.GatherReferences(GcCtx);
+		size_t RefCount = 0;
+		GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; });
+		CHECK(RefCount == 7);
+		ProjectStore.CollectGarbage(GcCtx);
+		CHECK(!ProjectStore.OpenProject("proj1"sv));
+		CHECK(ProjectStore.OpenProject("proj2"sv));
+	}
+}
+
+// Read chunks back out of the store: a full chunk via GetChunk, a full-range
+// read via GetChunkRange (Offset=0, Size=~0), and a partial range whose
+// decompressed bytes are compared against the corresponding offset of the
+// fully decompressed attachment.
+TEST_CASE("project.store.partial.read")
+{
+	using namespace std::literals;
+	using namespace testutils;
+
+	ScopedTemporaryDirectory TempDir;
+
+	GcManager Gc;
+	CidStore CidStore(Gc);
+	CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas"sv, .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+	CidStore.Initialize(CidConfig);
+
+	std::filesystem::path BasePath = TempDir.Path() / "projectstore"sv;
+	ProjectStore ProjectStore(CidStore, BasePath, Gc);
+	std::filesystem::path RootDir = TempDir.Path() / "root"sv;
+	std::filesystem::path EngineRootDir = TempDir.Path() / "engine"sv;
+
+	std::filesystem::path Project1RootDir = TempDir.Path() / "game1"sv;
+	std::filesystem::path Project1FilePath = TempDir.Path() / "game1"sv / "game.uproject"sv;
+	{
+		CreateDirectories(Project1FilePath.parent_path());
+		BasicFile ProjectFile;
+		ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate);
+	}
+
+	// Four ops with varying attachment counts/sizes, keyed by OpIds.
+	std::vector<Oid> OpIds;
+	OpIds.insert(OpIds.end(), {Oid::NewOid(), Oid::NewOid(), Oid::NewOid(), Oid::NewOid()});
+	std::unordered_map<Oid, std::vector<std::pair<Oid, CompressedBuffer>>, Oid::Hasher> Attachments;
+	{
+		Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv,
+																	"proj1"sv,
+																	RootDir.string(),
+																	EngineRootDir.string(),
+																	Project1RootDir.string(),
+																	Project1FilePath.string()));
+		ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1"sv, {});
+		CHECK(Oplog != nullptr);
+		Attachments[OpIds[0]] = {};
+		Attachments[OpIds[1]] = CreateAttachments(std::initializer_list<size_t>{77});
+		Attachments[OpIds[2]] = CreateAttachments(std::initializer_list<size_t>{7123, 9583, 690, 99});
+		Attachments[OpIds[3]] = CreateAttachments(std::initializer_list<size_t>{55, 122});
+		for (auto It : Attachments)
+		{
+			Oplog->AppendNewOplogEntry(CreateOplogPackage(It.first, It.second));
+		}
+	}
+	{
+		// Full read by CAS hash: round-trips with the original raw size.
+		IoBuffer Chunk;
+		CHECK(ProjectStore
+				  .GetChunk("proj1"sv,
+							"oplog1"sv,
+							Attachments[OpIds[1]][0].second.DecodeRawHash().ToHexString(),
+							HttpContentType::kCompressedBinary,
+							Chunk)
+				  .first == HttpResponseCode::OK);
+		IoHash RawHash;
+		uint64_t RawSize;
+		CompressedBuffer Attachment = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize);
+		CHECK(RawSize == Attachments[OpIds[1]][0].second.DecodeRawSize());
+	}
+
+	// Full-range read by chunk id (Offset 0, maximum size).
+	IoBuffer ChunkResult;
+	CHECK(ProjectStore
+			  .GetChunkRange("proj1"sv,
+							 "oplog1"sv,
+							 OidAsString(Attachments[OpIds[2]][1].first),
+							 0,
+							 ~0ull,
+							 HttpContentType::kCompressedBinary,
+							 ChunkResult)
+			  .first == HttpResponseCode::OK);
+	CHECK(ChunkResult);
+	CHECK(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult)).DecodeRawSize() ==
+		  Attachments[OpIds[2]][1].second.DecodeRawSize());
+
+	// Partial range starting at raw offset 5, length 1773.
+	IoBuffer PartialChunkResult;
+	CHECK(ProjectStore
+			  .GetChunkRange("proj1"sv,
+							 "oplog1"sv,
+							 OidAsString(Attachments[OpIds[2]][1].first),
+							 5,
+							 1773,
+							 HttpContentType::kCompressedBinary,
+							 PartialChunkResult)
+			  .first == HttpResponseCode::OK);
+	CHECK(PartialChunkResult);
+	IoHash PartialRawHash;
+	uint64_t PartialRawSize;
+	CompressedBuffer PartialCompressedResult =
+		CompressedBuffer::FromCompressed(SharedBuffer(PartialChunkResult), PartialRawHash, PartialRawSize);
+	CHECK(PartialRawSize >= 1773);
+
+	// The partial result may start at a block boundary before offset 5, so
+	// translate the raw offset into the returned buffer before comparing.
+	uint64_t RawOffsetInPartialCompressed = GetCompressedOffset(PartialCompressedResult, 5);
+	SharedBuffer PartialDecompressed = PartialCompressedResult.Decompress(RawOffsetInPartialCompressed);
+	SharedBuffer FullDecompressed = Attachments[OpIds[2]][1].second.Decompress();
+	const uint8_t* FullDataPtr = &(reinterpret_cast<const uint8_t*>(FullDecompressed.GetView().GetData())[5]);
+	const uint8_t* PartialDataPtr = reinterpret_cast<const uint8_t*>(PartialDecompressed.GetView().GetData());
+	CHECK(FullDataPtr[0] == PartialDataPtr[0]);
+}
+
+// Generate a block from many compressed chunks and verify IterateBlock can
+// walk the resulting buffer back out, chunk by chunk.
+TEST_CASE("project.store.block")
+{
+	using namespace std::literals;
+	using namespace testutils;
+
+	std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489,
+											  7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251,
+											  491, 5464, 4607, 8135, 3767, 4045, 4415, 5007, 8876, 6761, 3359, 8526, 4097, 4855, 8225});
+
+	std::vector<std::pair<Oid, CompressedBuffer>> AttachmentsWithId = CreateAttachments(AttachmentSizes);
+	std::vector<SharedBuffer> Chunks;
+	Chunks.reserve(AttachmentSizes.size());
+	for (const auto& It : AttachmentsWithId)
+	{
+		Chunks.push_back(It.second.GetCompressed().Flatten());
+	}
+	CompressedBuffer Block = GenerateBlock(std::move(Chunks));
+	IoBuffer BlockBuffer = Block.GetCompressed().Flatten().AsIoBuffer();
+	CHECK(IterateBlock(std::move(BlockBuffer), [](CompressedBuffer&&, const IoHash&) {}));
+}
+
+#endif
+
+// Intentionally empty; referenced from elsewhere to force this translation
+// unit (and its registered tests) to be linked in.
+void
+prj_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/projectstore.h b/src/zenserver/projectstore/projectstore.h
new file mode 100644
index 000000000..e4f664b85
--- /dev/null
+++ b/src/zenserver/projectstore/projectstore.h
@@ -0,0 +1,372 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/uid.h>
+#include <zencore/xxhash.h>
+#include <zenhttp/httpserver.h>
+#include <zenstore/gc.h>
+
+#include "monitoring/httpstats.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CbPackage;
+class CidStore;
+class AuthMgr;
+class ScrubContext;
+
+// Fixed-size record describing one op in the oplog; sizeof must stay a power
+// of two (see static_assert below).
+struct OplogEntry
+{
+	uint32_t OpLsn;
+	uint32_t OpCoreOffset; // note: Multiple of alignment!
+	uint32_t OpCoreSize;
+	uint32_t OpCoreHash; // Used as checksum
+	XXH3_128 OpKeyHash;	 // XXH128_canonical_t
+
+	// Reinterpret the 128-bit key hash as an Oid (raw byte copy into OidBits).
+	inline Oid OpKeyAsOId() const
+	{
+		Oid Id;
+		memcpy(Id.OidBits, &OpKeyHash, sizeof Id.OidBits);
+		return Id;
+	}
+};
+
+// Byte offset and size of a serialized op inside the ops blob file.
+struct OplogEntryAddress
+{
+	uint64_t Offset;
+	uint64_t Size;
+};
+
+static_assert(IsPow2(sizeof(OplogEntry)));
+
+/** Project Store
+
+ A project store consists of a number of Projects.
+
+ Each project contains a number of oplogs (short for "operation log"). UE uses
+ one oplog per target platform to store the output of the cook process.
+
+ An oplog consists of a sequence of "op" entries. Each entry is a structured object
+ containing references to attachments. Attachments are typically the serialized
+ package data split into separate chunks for bulk data, exports and header
+ information.
+ */
+class ProjectStore : public RefCounted, public GcStorage, public GcContributor
+{
+	struct OplogStorage;
+
+public:
+	ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc);
+	~ProjectStore();
+
+	struct Project;
+
+	// One operation log within a project. Keeps id -> CAS-hash maps for
+	// chunks, meta chunks and files plus per-key latest-op bookkeeping, all
+	// guarded by m_OplogLock.
+	struct Oplog
+	{
+		Oplog(std::string_view Id,
+			  Project* Project,
+			  CidStore& Store,
+			  std::filesystem::path BasePath,
+			  const std::filesystem::path& MarkerPath);
+		~Oplog();
+
+		[[nodiscard]] static bool ExistsAt(std::filesystem::path BasePath);
+
+		void Read();
+		void Write();
+
+		void IterateFileMap(std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn);
+		void IterateOplog(std::function<void(CbObject)>&& Fn);
+		// Callback receives (LSN, op key, op object).
+		void IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Fn);
+		std::optional<CbObject> GetOpByKey(const Oid& Key);
+		std::optional<CbObject> GetOpByIndex(int Index);
+		// Returns -1 when no op exists for the given key (see route handlers).
+		int GetOpIndexByKey(const Oid& Key);
+
+		IoBuffer FindChunk(Oid ChunkId);
+
+		inline static const uint32_t kInvalidOp = ~0u;
+
+		/** Persist a new oplog entry
+		 *
+		 * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected
+		 */
+		uint32_t AppendNewOplogEntry(CbPackage Op);
+
+		uint32_t AppendNewOplogEntry(CbObject Core);
+
+		enum UpdateType
+		{
+			kUpdateNewEntry,
+			kUpdateReplay
+		};
+
+		const std::string& OplogId() const { return m_OplogId; }
+
+		const std::filesystem::path& TempPath() const { return m_TempPath; }
+		const std::filesystem::path& MarkerPath() const { return m_MarkerPath; }
+
+		// Logging is delegated to the owning project's logger.
+		spdlog::logger& Log() { return m_OuterProject->Log(); }
+		void Flush();
+		void Scrub(ScrubContext& Ctx) const;
+		void GatherReferences(GcContext& GcCtx);
+		uint64_t TotalSize() const;
+
+		// Number of distinct op keys currently tracked (size of the
+		// latest-op map, not the raw entry count).
+		std::size_t OplogCount() const
+		{
+			RwLock::SharedLockScope _(m_OplogLock);
+			return m_LatestOpMap.size();
+		}
+
+		bool IsExpired() const;
+		std::filesystem::path PrepareForDelete(bool MoveFolder);
+
+	private:
+		struct FileMapEntry
+		{
+			std::string ServerPath;
+			std::string ClientPath;
+		};
+
+		template<class V>
+		using OidMap = tsl::robin_map<Oid, V, Oid::Hasher>;
+
+		Project* m_OuterProject = nullptr;
+		CidStore& m_CidStore;
+		std::filesystem::path m_BasePath;
+		std::filesystem::path m_MarkerPath;
+		std::filesystem::path m_TempPath;
+
+		mutable RwLock m_OplogLock;
+		OidMap<IoHash> m_ChunkMap; // output data chunk id -> CAS address
+		OidMap<IoHash> m_MetaMap; // meta chunk id -> CAS address
+		OidMap<FileMapEntry> m_FileMap; // file id -> file map entry
+		int32_t m_ManifestVersion; // File system manifest version
+		tsl::robin_map<int, OplogEntryAddress> m_OpAddressMap; // Index LSN -> op data in ops blob file
+		OidMap<int> m_LatestOpMap; // op key -> latest op LSN for key
+
+		RefPtr<OplogStorage> m_Storage;
+		std::string m_OplogId;
+
+		/** Scan oplog and register each entry, thus updating the in-memory tracking tables
+		 */
+		void ReplayLog();
+
+		// Id -> hash mappings extracted from one op's core object.
+		struct OplogEntryMapping
+		{
+			struct Mapping
+			{
+				Oid Id;
+				IoHash Hash;
+			};
+			struct FileMapping : public Mapping
+			{
+				std::string ServerPath;
+				std::string ClientPath;
+			};
+			std::vector<Mapping> Chunks;
+			std::vector<Mapping> Meta;
+			std::vector<FileMapping> Files;
+		};
+
+		OplogEntryMapping GetMapping(CbObject Core);
+
+		/** Update tracking metadata for a new oplog entry
+		 *
+		 * This is used during replay (and gets called as part of new op append)
+		 *
+		 * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected
+		 */
+		uint32_t RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock,
+									const OplogEntryMapping& OpMapping,
+									const OplogEntry& OpEntry,
+									UpdateType TypeOfUpdate);
+
+		// The lock-scope parameters below document that the caller must
+		// already hold m_OplogLock exclusively.
+		void AddFileMapping(const RwLock::ExclusiveLockScope& OplogLock,
+							Oid FileId,
+							IoHash Hash,
+							std::string_view ServerPath,
+							std::string_view ClientPath);
+		void AddChunkMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash);
+		void AddMetaMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash);
+	};
+
+	// A named project owning a set of oplogs. Ref-counted so callers can keep
+	// a project alive across deletion (see project.store.lifetimes test).
+	struct Project : public RefCounted
+	{
+		std::string Identifier;
+		std::filesystem::path RootDir;
+		std::string EngineRootDir;
+		std::string ProjectRootDir;
+		std::string ProjectFilePath;
+
+		Oplog* NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath);
+		Oplog* OpenOplog(std::string_view OplogId);
+		void DeleteOplog(std::string_view OplogId);
+		void IterateOplogs(std::function<void(const Oplog&)>&& Fn) const;
+		void IterateOplogs(std::function<void(Oplog&)>&& Fn);
+		std::vector<std::string> ScanForOplogs() const;
+		bool IsExpired() const;
+
+		Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath);
+		virtual ~Project();
+
+		void Read();
+		void Write();
+		[[nodiscard]] static bool Exists(std::filesystem::path BasePath);
+		void Flush();
+		void Scrub(ScrubContext& Ctx);
+		spdlog::logger& Log();
+		void GatherReferences(GcContext& GcCtx);
+		uint64_t TotalSize() const;
+		bool PrepareForDelete(std::filesystem::path& OutDeletePath);
+
+	private:
+		ProjectStore* m_ProjectStore;
+		CidStore& m_CidStore;
+		mutable RwLock m_ProjectLock;
+		std::map<std::string, std::unique_ptr<Oplog>> m_Oplogs;
+		// Deleted oplogs are retained here so previously handed-out raw
+		// Oplog* pointers remain dereferenceable after deletion.
+		std::vector<std::unique_ptr<Oplog>> m_DeletedOplogs;
+		std::filesystem::path m_OplogStoragePath;
+
+		std::filesystem::path BasePathForOplog(std::string_view OplogId);
+	};
+
+	// Oplog* OpenProjectOplog(std::string_view ProjectId, std::string_view OplogId);
+
+	Ref<Project> OpenProject(std::string_view ProjectId);
+	Ref<Project> NewProject(std::filesystem::path BasePath,
+							std::string_view ProjectId,
+							std::string_view RootDir,
+							std::string_view EngineRootDir,
+							std::string_view ProjectRootDir,
+							std::string_view ProjectFilePath);
+	bool DeleteProject(std::string_view ProjectId);
+	bool Exists(std::string_view ProjectId);
+	void Flush();
+	void Scrub(ScrubContext& Ctx);
+	void DiscoverProjects();
+	void IterateProjects(std::function<void(Project& Prj)>&& Fn);
+
+	spdlog::logger& Log() { return m_Log; }
+	const std::filesystem::path& BasePath() const { return m_ProjectBasePath; }
+
+	// GcStorage / GcContributor interface.
+	virtual void GatherReferences(GcContext& GcCtx) override;
+	virtual void CollectGarbage(GcContext& GcCtx) override;
+	virtual GcStorageSize StorageSize() const override;
+
+	// HTTP-facing helpers: each returns the response code paired with an
+	// error/status message for the handler to relay.
+	CbArray GetProjectsList();
+	std::pair<HttpResponseCode, std::string> GetProjectFiles(const std::string_view ProjectId,
+															 const std::string_view OplogId,
+															 bool FilterClient,
+															 CbObject& OutPayload);
+	std::pair<HttpResponseCode, std::string> GetChunkInfo(const std::string_view ProjectId,
+														  const std::string_view OplogId,
+														  const std::string_view ChunkId,
+														  CbObject& OutPayload);
+	std::pair<HttpResponseCode, std::string> GetChunkRange(const std::string_view ProjectId,
+														   const std::string_view OplogId,
+														   const std::string_view ChunkId,
+														   uint64_t Offset,
+														   uint64_t Size,
+														   ZenContentType AcceptType,
+														   IoBuffer& OutChunk);
+	std::pair<HttpResponseCode, std::string> GetChunk(const std::string_view ProjectId,
+													  const std::string_view OplogId,
+													  const std::string_view Cid,
+													  ZenContentType AcceptType,
+													  IoBuffer& OutChunk);
+
+	std::pair<HttpResponseCode, std::string> PutChunk(const std::string_view ProjectId,
+													  const std::string_view OplogId,
+													  const std::string_view Cid,
+													  ZenContentType ContentType,
+													  IoBuffer&& Chunk);
+
+	std::pair<HttpResponseCode, std::string> WriteOplog(const std::string_view ProjectId,
+														const std::string_view OplogId,
+														IoBuffer&& Payload,
+														CbObject& OutResponse);
+
+	std::pair<HttpResponseCode, std::string> ReadOplog(const std::string_view ProjectId,
+													   const std::string_view OplogId,
+													   const HttpServerRequest::QueryParams& Params,
+													   CbObject& OutResponse);
+
+	std::pair<HttpResponseCode, std::string> WriteBlock(const std::string_view ProjectId,
+														const std::string_view OplogId,
+														IoBuffer&& Payload);
+
+	void Rpc(HttpServerRequest& HttpReq,
+			 const std::string_view ProjectId,
+			 const std::string_view OplogId,
+			 IoBuffer&& Payload,
+			 AuthMgr& AuthManager);
+
+	std::pair<HttpResponseCode, std::string> Export(ProjectStore::Project& Project,
+													ProjectStore::Oplog& Oplog,
+													CbObjectView&& Params,
+													AuthMgr& AuthManager);
+
+	std::pair<HttpResponseCode, std::string> Import(ProjectStore::Project& Project,
+													ProjectStore::Oplog& Oplog,
+													CbObjectView&& Params,
+													AuthMgr& AuthManager);
+
+private:
+	spdlog::logger& m_Log;
+	CidStore& m_CidStore;
+	std::filesystem::path m_ProjectBasePath;
+	mutable RwLock m_ProjectsLock;
+	std::map<std::string, Ref<Project>> m_Projects;
+
+	std::filesystem::path BasePathForProject(std::string_view ProjectId);
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// {project} a project identifier
+// {target} a variation of the project, typically a build target
+// {lsn} oplog entry sequence number
+//
+// /prj/{project}
+// /prj/{project}/oplog/{target}
+// /prj/{project}/oplog/{target}/{lsn}
+//
+// oplog entry
+//
+// id: {id}
+// key: {}
+// meta: {}
+// data: []
+// refs:
+//
+
+// HTTP front-end for the project store. Serves routes under "/prj/" via an
+// internal request router and contributes a "prj" section to the stats
+// service (registration is undone in the destructor).
+class HttpProjectService : public HttpService, public IHttpStatsProvider
+{
+public:
+	HttpProjectService(CidStore& Store, ProjectStore* InProjectStore, HttpStatsService& StatsService, AuthMgr& AuthMgr);
+	~HttpProjectService();
+
+	virtual const char* BaseUri() const override;
+	virtual void HandleRequest(HttpServerRequest& Request) override;
+
+	virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+
+private:
+	inline spdlog::logger& Log() { return m_Log; }
+
+	spdlog::logger& m_Log;
+	CidStore& m_CidStore;
+	HttpRequestRouter m_Router;
+	Ref<ProjectStore> m_ProjectStore;
+	HttpStatsService& m_StatsService;
+	AuthMgr& m_AuthMgr;
+};
+
+void prj_forcelink();
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/remoteprojectstore.cpp b/src/zenserver/projectstore/remoteprojectstore.cpp
new file mode 100644
index 000000000..1e6ca51a1
--- /dev/null
+++ b/src/zenserver/projectstore/remoteprojectstore.cpp
@@ -0,0 +1,1036 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "remoteprojectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/workthreadpool.h>
+#include <zenstore/cidstore.h>
+
+namespace zen {
+
+/*
+ OplogContainer
+ Binary("ops") // Compressed CompactBinary object to hide attachment references, also makes the oplog smaller
+ {
+ CbArray("ops")
+ {
+ CbObject Op
+ (CbFieldType::BinaryAttachment Attachments[])
+ (OpData)
+ }
+ }
+ CbArray("blocks")
+ CbObject
+ CbFieldType::BinaryAttachment "rawhash" // Optional, only if we are creating blocks (Jupiter/File)
+ CbArray("chunks")
+ CbFieldType::Hash // Chunk hashes
+ CbArray("chunks") // Optional, only if we are not creating blocks (Zen)
+ CbFieldType::BinaryAttachment // Chunk attachment hashes
+
+ CompressedBinary ChunkBlock
+ {
+ VarUInt ChunkCount
+ VarUInt ChunkSizes[ChunkCount]
+        (uint8_t[ChunkSize])[ChunkCount]
+ }
+*/
+
+////////////////////////////// AsyncRemoteResult
+
+struct AsyncRemoteResult
+{
+ void SetError(int32_t ErrorCode, const std::string& ErrorReason, const std::string ErrorText)
+ {
+ int32_t Expected = 0;
+ if (m_ErrorCode.compare_exchange_weak(Expected, ErrorCode ? ErrorCode : -1))
+ {
+ m_ErrorReason = ErrorReason;
+ m_ErrorText = ErrorText;
+ }
+ }
+ bool IsError() const { return m_ErrorCode.load() != 0; }
+ int GetError() const { return m_ErrorCode.load(); };
+ const std::string& GetErrorReason() const { return m_ErrorReason; };
+ const std::string& GetErrorText() const { return m_ErrorText; };
+ RemoteProjectStore::Result ConvertResult(double ElapsedSeconds = 0.0) const
+ {
+ return RemoteProjectStore::Result{m_ErrorCode, ElapsedSeconds, m_ErrorReason, m_ErrorText};
+ }
+
+private:
+ std::atomic<int32_t> m_ErrorCode = 0;
+ std::string m_ErrorReason;
+ std::string m_ErrorText;
+};
+
+// Decompresses a chunk block and invokes Visitor once per chunk it contains.
+// Expected layout (see the file-header comment): VarUInt ChunkCount, then
+// ChunkCount VarUInt chunk sizes, then the concatenated compressed chunk
+// payloads. Returns false if any chunk fails to parse as a CompressedBuffer.
+bool
+IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor)
+{
+	IoBuffer BlockPayload = CompressedBuffer::FromCompressedNoValidate(std::move(CompressedBlock)).Decompress().AsIoBuffer();
+
+	MemoryView BlockView = BlockPayload.GetView();
+	const uint8_t* ReadPtr = reinterpret_cast<const uint8_t*>(BlockView.GetData());
+	uint32_t NumberSize;
+	// Parse the header: chunk count followed by one size per chunk
+	uint64_t ChunkCount = ReadVarUInt(ReadPtr, NumberSize);
+	ReadPtr += NumberSize;
+	std::vector<uint64_t> ChunkSizes;
+	ChunkSizes.reserve(ChunkCount);
+	while (ChunkCount--)
+	{
+		ChunkSizes.push_back(ReadVarUInt(ReadPtr, NumberSize));
+		ReadPtr += NumberSize;
+	}
+	// Sanity check: we must have consumed a non-empty header
+	ptrdiff_t TempBufferLength = std::distance(reinterpret_cast<const uint8_t*>(BlockView.GetData()), ReadPtr);
+	ZEN_ASSERT(TempBufferLength > 0);
+	// Each chunk payload is itself a CompressedBuffer; validating it also
+	// decodes the raw (uncompressed) content hash used as the attachment id.
+	for (uint64_t ChunkSize : ChunkSizes)
+	{
+		IoBuffer Chunk(IoBuffer::Wrap, ReadPtr, ChunkSize);
+		IoHash AttachmentRawHash;
+		uint64_t AttachmentRawSize;
+		CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), AttachmentRawHash, AttachmentRawSize);
+
+		if (!CompressedChunk)
+		{
+			ZEN_ERROR("Invalid chunk in block");
+			return false;
+		}
+		Visitor(std::move(CompressedChunk), AttachmentRawHash);
+		ReadPtr += ChunkSize;
+		ZEN_ASSERT(ReadPtr <= BlockView.GetDataEnd());
+	}
+	return true;
+};
+
+// Serializes Chunks into a single compressed block: a VarUInt chunk count,
+// one VarUInt size per chunk, then the chunk payloads (see file-header layout).
+// The counterpart parser is IterateBlock above.
+CompressedBuffer
+GenerateBlock(std::vector<SharedBuffer>&& Chunks)
+{
+	size_t ChunkCount = Chunks.size();
+	SharedBuffer SizeBuffer;
+	{
+		// A VarUInt encodes in at most 9 bytes. We write ChunkCount + 1 VarUInts
+		// (the count itself plus one size per chunk), so size the scratch buffer
+		// accordingly -- ChunkCount * 9 alone could overflow for max-width values.
+		IoBuffer TempBuffer((ChunkCount + 1) * 9);
+		MutableMemoryView View = TempBuffer.GetMutableView();
+		uint8_t* BufferStartPtr = reinterpret_cast<uint8_t*>(View.GetData());
+		uint8_t* BufferEndPtr = BufferStartPtr;
+		BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(ChunkCount), BufferEndPtr);
+		for (const SharedBuffer& Chunk : Chunks)
+		{
+			BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(Chunk.GetSize()), BufferEndPtr);
+		}
+		ZEN_ASSERT(BufferEndPtr <= View.GetDataEnd());
+		ptrdiff_t TempBufferLength = std::distance(BufferStartPtr, BufferEndPtr);
+		SizeBuffer = SharedBuffer(IoBuffer(TempBuffer, 0, gsl::narrow<size_t>(TempBufferLength)));
+	}
+	// Header followed by the raw chunk payloads, compressed as one unit
+	CompositeBuffer AllBuffers(std::move(SizeBuffer), CompositeBuffer(std::move(Chunks)));
+
+	CompressedBuffer CompressedBlock =
+		CompressedBuffer::Compress(std::move(AllBuffers), OodleCompressor::Mermaid, OodleCompressionLevel::None);
+
+	return CompressedBlock;
+}
+
+// A group of chunk attachments bundled for transfer. BlockHash is only filled
+// in when an actual block payload is generated (BuildBlocks mode); otherwise it
+// remains zero and only the chunk hash list is meaningful.
+struct Block
+{
+	IoHash BlockHash;
+	std::vector<IoHash> ChunksInBlock;
+};
+
+// Schedules asynchronous creation of a compressed block from ChunksInBlock on
+// WorkerPool. On completion the block is handed to AsyncOnBlock and its raw
+// hash is written back into Blocks[BlockIndex]. OpSectionsLatch is incremented
+// here and counted down when the work item finishes (even on early-out), so the
+// caller can wait for all scheduled blocks.
+// NOTE(review): every reference captured by the lambda (Blocks, SectionsLock,
+// the latch, AsyncOnBlock, RemoteResult) must outlive the scheduled work; the
+// caller guarantees this by waiting on the latch before those objects die.
+void
+CreateBlock(WorkerThreadPool& WorkerPool,
+			Latch& OpSectionsLatch,
+			std::vector<SharedBuffer>&& ChunksInBlock,
+			RwLock& SectionsLock,
+			std::vector<Block>& Blocks,
+			size_t BlockIndex,
+			const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+			AsyncRemoteResult& RemoteResult)
+{
+	OpSectionsLatch.AddCount(1);
+	WorkerPool.ScheduleWork(
+		[&Blocks, &SectionsLock, &OpSectionsLatch, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, &RemoteResult]() mutable {
+			// Always count down, even if we bail out early on error
+			auto _ = MakeGuard([&OpSectionsLatch] { OpSectionsLatch.CountDown(); });
+			if (RemoteResult.IsError())
+			{
+				return;
+			}
+			if (!Chunks.empty())
+			{
+				CompressedBuffer CompressedBlock = GenerateBlock(std::move(Chunks)); // Move to callback and return IoHash
+				IoHash BlockHash = CompressedBlock.DecodeRawHash();
+				AsyncOnBlock(std::move(CompressedBlock), BlockHash);
+				{
+					// We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+					RwLock::SharedLockScope __(SectionsLock);
+					Blocks[BlockIndex].BlockHash = BlockHash;
+				}
+			}
+		});
+}
+
+// Appends a fresh, empty block slot to Blocks and returns its index. The
+// exclusive lock is required because worker threads read the vector (under a
+// shared lock) while new slots are being added.
+size_t
+AddBlock(RwLock& BlocksLock, std::vector<Block>& Blocks)
+{
+	RwLock::ExclusiveLockScope _(BlocksLock);
+	const size_t NewIndex = Blocks.size();
+	Blocks.emplace_back();
+	return NewIndex;
+}
+
+// Builds the oplog container object described in the file-header comment:
+// serializes all ops into a compressed "ops" section, partitions their
+// attachments into blocks of roughly MaxBlockSize bytes (attachments larger
+// than MaxChunkEmbedSize are reported individually via OnLargeAttachment), and
+// returns the assembled container object. When BuildBlocks is true, block
+// payloads are generated asynchronously on WorkerPool and emitted through
+// AsyncOnBlock; otherwise only the chunk hash lists are reported via
+// OnBlockChunks. On failure, RemoteResult carries the error and an empty
+// object is returned.
+CbObject
+BuildContainer(CidStore& ChunkStore,
+			   ProjectStore::Oplog& Oplog,
+			   size_t MaxBlockSize,
+			   size_t MaxChunkEmbedSize,
+			   bool BuildBlocks,
+			   WorkerThreadPool& WorkerPool,
+			   const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+			   const std::function<void(const IoHash&)>& OnLargeAttachment,
+			   const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks,
+			   AsyncRemoteResult& RemoteResult)
+{
+	using namespace std::literals;
+
+	std::unordered_set<IoHash, IoHash::Hasher> LargeChunkHashes;
+	CbObjectWriter SectionOpsWriter;
+	SectionOpsWriter.BeginArray("ops"sv);
+
+	size_t OpCount = 0;
+
+	CbObject OplogContainerObject;
+	{
+		RwLock BlocksLock;
+		std::vector<Block> Blocks;
+		CompressedBuffer OpsBuffer; // NOTE(review): appears unused in this function
+
+		// Starts at 1 so the latch cannot drain before all blocks are scheduled;
+		// the matching CountDown happens below, after scheduling is complete.
+		Latch BlockCreateLatch(1);
+
+		std::unordered_set<IoHash, IoHash::Hasher> BlockAttachmentHashes;
+
+		size_t BlockSize = 0;
+		std::vector<SharedBuffer> ChunksInBlock;
+
+		// Serialize every op into the "ops" array and collect the set of
+		// attachments they reference.
+		std::unordered_set<IoHash, IoHash::Hasher> Attachments;
+		Oplog.IterateOplog([&Attachments, &SectionOpsWriter, &OpCount](CbObject Op) {
+			Op.IterateAttachments([&](CbFieldView FieldView) { Attachments.insert(FieldView.AsAttachment()); });
+			(SectionOpsWriter) << Op;
+			OpCount++;
+		});
+
+		// Partition attachments: large ones are reported individually, the rest
+		// accumulate into blocks that are flushed when MaxBlockSize is reached.
+		for (const IoHash& AttachmentHash : Attachments)
+		{
+			IoBuffer Payload = ChunkStore.FindChunkByCid(AttachmentHash);
+			if (!Payload)
+			{
+				RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound),
+									  fmt::format("Failed to find attachment {} for op", AttachmentHash),
+									  {});
+				ZEN_ERROR("Failed to build container ({}). Reason: '{}'", RemoteResult.GetError(), RemoteResult.GetErrorReason());
+				return {};
+			}
+			uint64_t PayloadSize = Payload.GetSize();
+			if (PayloadSize > MaxChunkEmbedSize)
+			{
+				// Too big to embed in a block -- report once and move on
+				if (LargeChunkHashes.insert(AttachmentHash).second)
+				{
+					OnLargeAttachment(AttachmentHash);
+				}
+				continue;
+			}
+
+			// Deduplicate within the current block
+			if (!BlockAttachmentHashes.insert(AttachmentHash).second)
+			{
+				continue;
+			}
+
+			BlockSize += PayloadSize;
+			if (BuildBlocks)
+			{
+				ChunksInBlock.emplace_back(SharedBuffer(std::move(Payload)));
+			}
+			else
+			{
+				Payload = {};
+			}
+
+			// Flush the current block once it has reached the target size
+			if (BlockSize >= MaxBlockSize)
+			{
+				size_t BlockIndex = AddBlock(BlocksLock, Blocks);
+				if (BuildBlocks)
+				{
+					CreateBlock(WorkerPool,
+								BlockCreateLatch,
+								std::move(ChunksInBlock),
+								BlocksLock,
+								Blocks,
+								BlockIndex,
+								AsyncOnBlock,
+								RemoteResult);
+				}
+				else
+				{
+					OnBlockChunks(BlockAttachmentHashes);
+				}
+				{
+					// We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+					RwLock::SharedLockScope _(BlocksLock);
+					Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(),
+															BlockAttachmentHashes.begin(),
+															BlockAttachmentHashes.end());
+				}
+				BlockAttachmentHashes.clear();
+				ChunksInBlock.clear();
+				BlockSize = 0;
+			}
+		}
+		// Flush the final, partially-filled block (same steps as above)
+		if (BlockSize > 0)
+		{
+			size_t BlockIndex = AddBlock(BlocksLock, Blocks);
+			if (BuildBlocks)
+			{
+				CreateBlock(WorkerPool,
+							BlockCreateLatch,
+							std::move(ChunksInBlock),
+							BlocksLock,
+							Blocks,
+							BlockIndex,
+							AsyncOnBlock,
+							RemoteResult);
+			}
+			else
+			{
+				OnBlockChunks(BlockAttachmentHashes);
+			}
+			{
+				// We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+				RwLock::SharedLockScope _(BlocksLock);
+				Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(),
+														BlockAttachmentHashes.begin(),
+														BlockAttachmentHashes.end());
+			}
+			BlockAttachmentHashes.clear();
+			ChunksInBlock.clear();
+			BlockSize = 0;
+		}
+		SectionOpsWriter.EndArray(); // "ops"
+
+		CompressedBuffer CompressedOpsSection = CompressedBuffer::Compress(SectionOpsWriter.Save().GetBuffer());
+		ZEN_DEBUG("Added oplog section {}, {}", CompressedOpsSection.DecodeRawHash(), NiceBytes(CompressedOpsSection.GetCompressedSize()));
+
+		// Release the scheduling guard and wait for all block work to complete
+		BlockCreateLatch.CountDown();
+		while (!BlockCreateLatch.Wait(1000))
+		{
+			ZEN_INFO("Creating blocks, {} remaining...", BlockCreateLatch.Remaining());
+		}
+
+		// Assemble the final container: "ops" + "blocks" + large "chunks"
+		if (!RemoteResult.IsError())
+		{
+			CbObjectWriter OplogContinerWriter;
+			RwLock::SharedLockScope _(BlocksLock);
+			OplogContinerWriter.AddBinary("ops"sv, CompressedOpsSection.GetCompressed().Flatten().AsIoBuffer());
+
+			OplogContinerWriter.BeginArray("blocks"sv);
+			{
+				for (const Block& B : Blocks)
+				{
+					ZEN_ASSERT(!B.ChunksInBlock.empty());
+					if (BuildBlocks)
+					{
+						// Block payload exists: reference it plus the raw chunk hashes
+						ZEN_ASSERT(B.BlockHash != IoHash::Zero);
+
+						OplogContinerWriter.BeginObject();
+						{
+							OplogContinerWriter.AddBinaryAttachment("rawhash"sv, B.BlockHash);
+							OplogContinerWriter.BeginArray("chunks"sv);
+							{
+								for (const IoHash& RawHash : B.ChunksInBlock)
+								{
+									OplogContinerWriter.AddHash(RawHash);
+								}
+							}
+							OplogContinerWriter.EndArray(); // "chunks"
+						}
+						OplogContinerWriter.EndObject();
+						continue;
+					}
+
+					// No block payload: reference each chunk as its own attachment
+					ZEN_ASSERT(B.BlockHash == IoHash::Zero);
+					OplogContinerWriter.BeginObject();
+					{
+						OplogContinerWriter.BeginArray("chunks"sv);
+						{
+							for (const IoHash& RawHash : B.ChunksInBlock)
+							{
+								OplogContinerWriter.AddBinaryAttachment(RawHash);
+							}
+						}
+						OplogContinerWriter.EndArray();
+					}
+					OplogContinerWriter.EndObject();
+				}
+			}
+			OplogContinerWriter.EndArray(); // "blocks"sv
+
+			OplogContinerWriter.BeginArray("chunks"sv);
+			{
+				for (const IoHash& AttachmentHash : LargeChunkHashes)
+				{
+					OplogContinerWriter.AddBinaryAttachment(AttachmentHash);
+				}
+			}
+			OplogContinerWriter.EndArray(); // "chunks"
+
+			OplogContainerObject = OplogContinerWriter.Save();
+		}
+	}
+	return OplogContainerObject;
+}
+
+// Convenience overload: runs the worker-pool BuildContainer above on a
+// transient pool and packages the outcome as a LoadContainerResult.
+RemoteProjectStore::LoadContainerResult
+BuildContainer(CidStore& ChunkStore,
+			   ProjectStore::Oplog& Oplog,
+			   size_t MaxBlockSize,
+			   size_t MaxChunkEmbedSize,
+			   bool BuildBlocks,
+			   const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+			   const std::function<void(const IoHash&)>& OnLargeAttachment,
+			   const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks)
+{
+	// Container builds are bursty, so spin up a short-lived pool rather than
+	// keeping a WorkerThreadPool alive between calls.
+	const size_t PoolSize = Min(std::thread::hardware_concurrency(), 16u);
+	WorkerThreadPool TransientPool(gsl::narrow<int>(PoolSize));
+
+	AsyncRemoteResult Outcome;
+	CbObject Container = BuildContainer(ChunkStore,
+										Oplog,
+										MaxBlockSize,
+										MaxChunkEmbedSize,
+										BuildBlocks,
+										TransientPool,
+										AsyncOnBlock,
+										OnLargeAttachment,
+										OnBlockChunks,
+										Outcome);
+	return RemoteProjectStore::LoadContainerResult{Outcome.ConvertResult(), Container};
+}
+
+// Uploads the contents of Oplog to RemoteStore as an oplog container:
+//  1. Build the container (compressed ops section + block/chunk manifest).
+//  2. Save the container; the remote replies with the attachments it needs.
+//  3. Upload needed large attachments, generated blocks and bulk chunk sets in
+//     parallel on a transient worker pool.
+//  4. Finalize the container on the remote store.
+// UseTempBlocks stages generated blocks in delete-on-close temp files instead
+// of uploading them eagerly; ForceUpload sends attachments even if the remote
+// did not request them.
+RemoteProjectStore::Result
+SaveOplog(CidStore& ChunkStore,
+		  RemoteProjectStore& RemoteStore,
+		  ProjectStore::Oplog& Oplog,
+		  size_t MaxBlockSize,
+		  size_t MaxChunkEmbedSize,
+		  bool BuildBlocks,
+		  bool UseTempBlocks,
+		  bool ForceUpload)
+{
+	using namespace std::literals;
+
+	Stopwatch Timer;
+
+	// We are creating a worker thread pool here since we are uploading a lot of attachments in one go.
+	// Doing upload is a rare and transient occasion so we don't want to keep a WorkerThreadPool alive.
+	size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u);
+	WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount));
+
+	std::filesystem::path AttachmentTempPath;
+	if (UseTempBlocks)
+	{
+		AttachmentTempPath = Oplog.TempPath();
+		AttachmentTempPath.append(".pending");
+		CreateDirectories(AttachmentTempPath);
+	}
+
+	AsyncRemoteResult RemoteResult;
+	RwLock AttachmentsLock; // guards LargeAttachments and CreatedBlocks
+	std::unordered_set<IoHash, IoHash::Hasher> LargeAttachments;
+	std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> CreatedBlocks;
+
+	// Block sink (UseTempBlocks): write the block to a delete-on-close temp file
+	// and remember it for a later, need-filtered upload.
+	auto MakeTempBlock = [AttachmentTempPath, &RemoteResult, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock,
+																							   const IoHash& BlockHash) {
+		std::filesystem::path BlockPath = AttachmentTempPath;
+		BlockPath.append(BlockHash.ToHexString());
+		if (!std::filesystem::exists(BlockPath))
+		{
+			IoBuffer BlockBuffer;
+			try
+			{
+				BasicFile BlockFile;
+				BlockFile.Open(BlockPath, BasicFile::Mode::kTruncateDelete);
+				uint64_t Offset = 0;
+				for (const SharedBuffer& Buffer : CompressedBlock.GetCompressed().GetSegments())
+				{
+					BlockFile.Write(Buffer.GetView(), Offset);
+					Offset += Buffer.GetSize();
+				}
+				// Hand the open file handle over to the IoBuffer
+				void* FileHandle = BlockFile.Detach();
+				BlockBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, Offset);
+			}
+			catch (std::exception& Ex)
+			{
+				RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+									  Ex.what(),
+									  "Unable to create temp block file");
+				return;
+			}
+
+			BlockBuffer.MarkAsDeleteOnClose();
+			{
+				RwLock::ExclusiveLockScope __(AttachmentsLock);
+				CreatedBlocks.insert({BlockHash, std::move(BlockBuffer)});
+			}
+			ZEN_DEBUG("Saved temp block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize()));
+		}
+	};
+
+	// Block sink (eager): upload the block to the remote store immediately.
+	auto UploadBlock = [&RemoteStore, &RemoteResult](CompressedBuffer&& CompressedBlock, const IoHash& BlockHash) {
+		RemoteProjectStore::SaveAttachmentResult Result = RemoteStore.SaveAttachment(CompressedBlock.GetCompressed(), BlockHash);
+		if (Result.ErrorCode)
+		{
+			RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+			// Format is "({code}). Reason: '{reason}'" -- keep arguments in that order
+			ZEN_ERROR("Failed to save attachment ({}). Reason: '{}'", RemoteResult.GetError(), RemoteResult.GetErrorReason());
+			return;
+		}
+		ZEN_DEBUG("Saved block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize()));
+	};
+
+	// When we are not generating block payloads we only track chunk hash lists
+	std::vector<std::vector<IoHash>> BlockChunks;
+	auto OnBlockChunks = [&BlockChunks](const std::unordered_set<IoHash, IoHash::Hasher>& Chunks) {
+		BlockChunks.push_back({Chunks.begin(), Chunks.end()});
+		ZEN_DEBUG("Found {} block chunks", Chunks.size());
+	};
+
+	auto OnLargeAttachment = [&AttachmentsLock, &LargeAttachments](const IoHash& AttachmentHash) {
+		{
+			RwLock::ExclusiveLockScope _(AttachmentsLock);
+			LargeAttachments.insert(AttachmentHash);
+		}
+		ZEN_DEBUG("Found attachment {}", AttachmentHash);
+	};
+
+	std::function<void(CompressedBuffer&&, const IoHash&)> OnBlock;
+	if (UseTempBlocks)
+	{
+		OnBlock = MakeTempBlock;
+	}
+	else
+	{
+		OnBlock = UploadBlock;
+	}
+
+	CbObject OplogContainerObject = BuildContainer(ChunkStore,
+												   Oplog,
+												   MaxBlockSize,
+												   MaxChunkEmbedSize,
+												   BuildBlocks,
+												   WorkerPool,
+												   OnBlock,
+												   OnLargeAttachment,
+												   OnBlockChunks,
+												   RemoteResult);
+
+	if (!RemoteResult.IsError())
+	{
+		uint64_t ChunkCount = OplogContainerObject["chunks"sv].AsArrayView().Num();
+		uint64_t BlockCount = OplogContainerObject["blocks"sv].AsArrayView().Num();
+		ZEN_INFO("Saving oplog container with {} attachments and {} blocks...", ChunkCount, BlockCount);
+		RemoteProjectStore::SaveResult ContainerSaveResult = RemoteStore.SaveContainer(OplogContainerObject.GetBuffer().AsIoBuffer());
+		if (ContainerSaveResult.ErrorCode)
+		{
+			RemoteResult.SetError(ContainerSaveResult.ErrorCode, ContainerSaveResult.Reason, "Failed to save oplog container");
+			ZEN_ERROR("Failed to save oplog container ({}). Reason: '{}'", RemoteResult.GetError(), RemoteResult.GetErrorReason());
+		}
+		ZEN_DEBUG("Saved container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerSaveResult.ElapsedSeconds * 1000)));
+		// The remote tells us which attachments it is missing; upload only those
+		// (or everything, when ForceUpload is set).
+		if (!ContainerSaveResult.Needs.empty())
+		{
+			ZEN_INFO("Filtering needed attachments...");
+			std::vector<IoHash> NeededLargeAttachments;
+			std::unordered_set<IoHash, IoHash::Hasher> NeededOtherAttachments;
+			NeededLargeAttachments.reserve(LargeAttachments.size());
+			NeededOtherAttachments.reserve(CreatedBlocks.size());
+			if (ForceUpload)
+			{
+				NeededLargeAttachments.insert(NeededLargeAttachments.end(), LargeAttachments.begin(), LargeAttachments.end());
+			}
+			else
+			{
+				for (const IoHash& RawHash : ContainerSaveResult.Needs)
+				{
+					if (LargeAttachments.contains(RawHash))
+					{
+						NeededLargeAttachments.push_back(RawHash);
+						continue;
+					}
+					NeededOtherAttachments.insert(RawHash);
+				}
+			}
+
+			// Starts at 1 so the latch cannot drain while we are still scheduling
+			Latch SaveAttachmentsLatch(1);
+			if (!NeededLargeAttachments.empty())
+			{
+				ZEN_INFO("Saving large attachments...");
+				for (const IoHash& RawHash : NeededLargeAttachments)
+				{
+					if (RemoteResult.IsError())
+					{
+						break;
+					}
+					SaveAttachmentsLatch.AddCount(1);
+					WorkerPool.ScheduleWork([&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, RawHash, &CreatedBlocks]() {
+						auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); });
+						if (RemoteResult.IsError())
+						{
+							return;
+						}
+
+						// Prefer a staged temp block; otherwise read from the chunk store
+						IoBuffer Payload;
+						if (auto It = CreatedBlocks.find(RawHash); It != CreatedBlocks.end())
+						{
+							Payload = std::move(It->second);
+						}
+						else
+						{
+							Payload = ChunkStore.FindChunkByCid(RawHash);
+						}
+						if (!Payload)
+						{
+							RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound),
+												  fmt::format("Failed to find attachment {}", RawHash),
+												  {});
+							ZEN_ERROR("Failed to build container ({}). Reason: '{}'",
+									  RemoteResult.GetError(),
+									  RemoteResult.GetErrorReason());
+							return;
+						}
+
+						RemoteProjectStore::SaveAttachmentResult Result =
+							RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash);
+						if (Result.ErrorCode)
+						{
+							RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+							ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'",
+									  RawHash,
+									  NiceBytes(Payload.GetSize()),
+									  RemoteResult.GetError(),
+									  RemoteResult.GetErrorReason());
+							return;
+						}
+						ZEN_DEBUG("Saved attachment {}, {} in {}",
+								  RawHash,
+								  NiceBytes(Payload.GetSize()),
+								  NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+						return;
+					});
+				}
+			}
+
+			if (!CreatedBlocks.empty())
+			{
+				ZEN_INFO("Saving created block attachments...");
+				for (auto& It : CreatedBlocks)
+				{
+					if (RemoteResult.IsError())
+					{
+						break;
+					}
+					const IoHash& RawHash = It.first;
+					if (ForceUpload || NeededOtherAttachments.contains(RawHash))
+					{
+						IoBuffer Payload = It.second;
+						ZEN_ASSERT(Payload);
+						SaveAttachmentsLatch.AddCount(1);
+						WorkerPool.ScheduleWork(
+							[&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, Payload = std::move(Payload), RawHash]() {
+								auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); });
+								if (RemoteResult.IsError())
+								{
+									return;
+								}
+
+								RemoteProjectStore::SaveAttachmentResult Result =
+									RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash);
+								if (Result.ErrorCode)
+								{
+									RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+									ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'",
+											  RawHash,
+											  NiceBytes(Payload.GetSize()),
+											  RemoteResult.GetError(),
+											  RemoteResult.GetErrorReason());
+									return;
+								}
+
+								ZEN_DEBUG("Saved attachment {}, {} in {}",
+										  RawHash,
+										  NiceBytes(Payload.GetSize()),
+										  NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+								return;
+							});
+					}
+					// Release our reference; the lambda owns the payload now
+					It.second = {};
+				}
+			}
+
+			if (!BlockChunks.empty())
+			{
+				ZEN_INFO("Saving chunk block attachments...");
+				for (const std::vector<IoHash>& Chunks : BlockChunks)
+				{
+					if (RemoteResult.IsError())
+					{
+						break;
+					}
+					std::vector<IoHash> NeededChunks;
+					if (ForceUpload)
+					{
+						NeededChunks = Chunks;
+					}
+					else
+					{
+						NeededChunks.reserve(Chunks.size());
+						for (const IoHash& Chunk : Chunks)
+						{
+							if (NeededOtherAttachments.contains(Chunk))
+							{
+								NeededChunks.push_back(Chunk);
+							}
+						}
+						if (NeededChunks.empty())
+						{
+							continue;
+						}
+					}
+					SaveAttachmentsLatch.AddCount(1);
+					WorkerPool.ScheduleWork([&RemoteStore,
+											 &ChunkStore,
+											 &SaveAttachmentsLatch,
+											 &RemoteResult,
+											 &Chunks,
+											 NeededChunks = std::move(NeededChunks),
+											 ForceUpload]() {
+						auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); });
+						std::vector<SharedBuffer> ChunkBuffers;
+						ChunkBuffers.reserve(NeededChunks.size());
+						for (const IoHash& Chunk : NeededChunks)
+						{
+							IoBuffer ChunkPayload = ChunkStore.FindChunkByCid(Chunk);
+							if (!ChunkPayload)
+							{
+								RemoteResult.SetError(static_cast<int32_t>(HttpResponseCode::NotFound),
+													  fmt::format("Missing chunk {}"sv, Chunk),
+													  fmt::format("Unable to fetch attachment {} required by the oplog"sv, Chunk));
+								ChunkBuffers.clear();
+								break;
+							}
+							ChunkBuffers.emplace_back(SharedBuffer(std::move(ChunkPayload)));
+						}
+						RemoteProjectStore::SaveAttachmentsResult Result = RemoteStore.SaveAttachments(ChunkBuffers);
+						if (Result.ErrorCode)
+						{
+							RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+							ZEN_ERROR("Failed to save attachments with {} chunks ({}). Reason: '{}'",
+									  Chunks.size(),
+									  RemoteResult.GetError(),
+									  RemoteResult.GetErrorReason());
+							return;
+						}
+						ZEN_DEBUG("Saved {} bulk attachments in {}",
+								  Chunks.size(),
+								  NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+					});
+				}
+			}
+			// Release the scheduling guard, then wait for all uploads to finish
+			SaveAttachmentsLatch.CountDown();
+			while (!SaveAttachmentsLatch.Wait(1000))
+			{
+				ZEN_INFO("Saving attachments, {} remaining...", SaveAttachmentsLatch.Remaining());
+			}
+			SaveAttachmentsLatch.Wait();
+		}
+
+		if (!RemoteResult.IsError())
+		{
+			ZEN_INFO("Finalizing oplog container...");
+			RemoteProjectStore::Result ContainerFinalizeResult = RemoteStore.FinalizeContainer(ContainerSaveResult.RawHash);
+			if (ContainerFinalizeResult.ErrorCode)
+			{
+				RemoteResult.SetError(ContainerFinalizeResult.ErrorCode, ContainerFinalizeResult.Reason, ContainerFinalizeResult.Text);
+				ZEN_ERROR("Failed to finalize oplog container {} ({}). Reason: '{}'",
+						  ContainerSaveResult.RawHash,
+						  RemoteResult.GetError(),
+						  RemoteResult.GetErrorReason());
+			}
+			ZEN_DEBUG("Finalized container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerFinalizeResult.ElapsedSeconds * 1000)));
+		}
+	}
+
+	RemoteProjectStore::Result Result = RemoteResult.ConvertResult();
+	// Convert stopwatch milliseconds to seconds
+	Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+	ZEN_INFO("Saved oplog {} in {}",
+			 RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE",
+			 NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+	return Result;
+}
+
+// Applies a downloaded oplog container to the local Oplog. For every attachment
+// referenced by the container, HasAttachment is consulted; missing large chunks
+// are reported via OnNeedAttachment and missing block contents via OnNeedBlock
+// (with a zero BlockHash plus an explicit chunk list when the container carries
+// loose chunk attachments instead of a block payload). Finally the compressed
+// "ops" section is decoded and each op is appended to the local oplog.
+RemoteProjectStore::Result
+SaveOplogContainer(ProjectStore::Oplog& Oplog,
+				   const CbObject& ContainerObject,
+				   const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+				   const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
+				   const std::function<void(const IoHash& RawHash)>& OnNeedAttachment)
+{
+	using namespace std::literals;
+
+	Stopwatch Timer;
+
+	// Large chunks are stored as individual attachments
+	CbArrayView LargeChunksArray = ContainerObject["chunks"sv].AsArrayView();
+	for (CbFieldView LargeChunksField : LargeChunksArray)
+	{
+		IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment();
+		if (HasAttachment(AttachmentHash))
+		{
+			continue;
+		}
+		OnNeedAttachment(AttachmentHash);
+	}
+
+	CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView();
+	for (CbFieldView BlockField : BlocksArray)
+	{
+		CbObjectView BlockView = BlockField.AsObjectView();
+		IoHash BlockHash = BlockView["rawhash"sv].AsBinaryAttachment();
+
+		CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView();
+		if (BlockHash == IoHash::Zero)
+		{
+			// No block payload: chunks are loose attachments; request only the
+			// ones we are missing, as one bulk fetch.
+			std::vector<IoHash> NeededChunks;
+			NeededChunks.reserve(ChunksArray.GetSize());
+			for (CbFieldView ChunkField : ChunksArray)
+			{
+				IoHash ChunkHash = ChunkField.AsBinaryAttachment();
+				if (HasAttachment(ChunkHash))
+				{
+					continue;
+				}
+				NeededChunks.emplace_back(ChunkHash);
+			}
+
+			if (!NeededChunks.empty())
+			{
+				OnNeedBlock(IoHash::Zero, std::move(NeededChunks));
+			}
+			continue;
+		}
+
+		// Block payload exists: if any contained chunk is missing locally we
+		// must fetch the whole block (chunks cannot be fetched individually).
+		for (CbFieldView ChunkField : ChunksArray)
+		{
+			IoHash ChunkHash = ChunkField.AsHash();
+			if (HasAttachment(ChunkHash))
+			{
+				continue;
+			}
+
+			OnNeedBlock(BlockHash, {});
+			break;
+		}
+	}
+
+	// Decompress and decode the "ops" section, then replay each op locally
+	MemoryView OpsSection = ContainerObject["ops"sv].AsBinaryView();
+	IoBuffer OpsBuffer(IoBuffer::Wrap, OpsSection.GetData(), OpsSection.GetSize());
+	IoBuffer SectionPayload = CompressedBuffer::FromCompressedNoValidate(std::move(OpsBuffer)).Decompress().AsIoBuffer();
+
+	CbObject SectionObject = LoadCompactBinaryObject(SectionPayload);
+	if (!SectionObject)
+	{
+		ZEN_ERROR("Failed to save oplog container. Reason: '{}'", "Section has unexpected data type");
+		return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest),
+										  Timer.GetElapsedTimeMs() / 1000.0,
+										  "Section has unexpected data type",
+										  "Failed to save oplog container"};
+	}
+
+	CbArrayView OpsArray = SectionObject["ops"sv].AsArrayView();
+	for (CbFieldView OpEntry : OpsArray)
+	{
+		// Re-serialize the op view into an owned buffer before appending
+		CbObjectView Core = OpEntry.AsObjectView();
+		BinaryWriter Writer;
+		Core.CopyTo(Writer);
+		MemoryView OpView = Writer.GetView();
+		IoBuffer OpBuffer(IoBuffer::Wrap, OpView.GetData(), OpView.GetSize());
+		CbObject Op(SharedBuffer(OpBuffer), CbFieldType::HasFieldType);
+		const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Op);
+		if (OpLsn == ProjectStore::Oplog::kInvalidOp)
+		{
+			return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest),
+											  Timer.GetElapsedTimeMs() / 1000.0,
+											  "Failed saving op",
+											  "Failed to save oplog container"};
+		}
+		ZEN_DEBUG("oplog entry #{}", OpLsn);
+	}
+	return RemoteProjectStore::Result{.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0};
+}
+
+// Downloads an oplog container from RemoteStore and replays it into the local
+// Oplog. Missing attachments and blocks are fetched in parallel on a transient
+// worker pool and inserted into ChunkStore; ForceDownload fetches everything
+// regardless of local presence.
+RemoteProjectStore::Result
+LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload)
+{
+	using namespace std::literals;
+
+	Stopwatch Timer;
+
+	// We are creating a worker thread pool here since we are downloading a lot of attachments
+	// in one go and we don't want to keep a WorkerThreadPool alive.
+	size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u);
+	WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount));
+
+	// Deduplicates individual attachment downloads; only touched on the calling
+	// thread (the OnNeedAttachment callback runs synchronously).
+	std::unordered_set<IoHash, IoHash::Hasher> Attachments;
+
+	RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer();
+	if (LoadContainerResult.ErrorCode)
+	{
+		ZEN_WARN("Failed to load oplog container, reason: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode);
+		return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode,
+										  .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0,
+										  .Reason = LoadContainerResult.Reason,
+										  .Text = LoadContainerResult.Text};
+	}
+	ZEN_DEBUG("Loaded container in {}", NiceTimeSpanMs(static_cast<uint64_t>(LoadContainerResult.ElapsedSeconds * 1000)));
+
+	AsyncRemoteResult RemoteResult;
+	// Starts at 1 so the latch cannot drain while downloads are being scheduled
+	Latch AttachmentsWorkLatch(1);
+
+	auto HasAttachment = [&ChunkStore, ForceDownload](const IoHash& RawHash) {
+		return !ForceDownload && ChunkStore.ContainsChunk(RawHash);
+	};
+	// Fetch either a bulk set of loose chunks (BlockHash == Zero) or a whole
+	// block payload, and insert the results into the local chunk store.
+	auto OnNeedBlock = [&RemoteStore, &ChunkStore, &WorkerPool, &AttachmentsWorkLatch, &RemoteResult](const IoHash& BlockHash,
+																									  std::vector<IoHash>&& Chunks) {
+		if (BlockHash == IoHash::Zero)
+		{
+			AttachmentsWorkLatch.AddCount(1);
+			WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &AttachmentsWorkLatch, &RemoteResult, Chunks = std::move(Chunks)]() {
+				auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); });
+				if (RemoteResult.IsError())
+				{
+					return;
+				}
+
+				RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks);
+				if (Result.ErrorCode)
+				{
+					RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+					ZEN_ERROR("Failed to load attachments with {} chunks ({}). Reason: '{}'",
+							  Chunks.size(),
+							  RemoteResult.GetError(),
+							  RemoteResult.GetErrorReason());
+					return;
+				}
+				ZEN_DEBUG("Loaded {} bulk attachments in {}",
+						  Chunks.size(),
+						  NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+				for (const auto& It : Result.Chunks)
+				{
+					ChunkStore.AddChunk(It.second.GetCompressed().Flatten().AsIoBuffer(), It.first, CidStore::InsertMode::kCopyOnly);
+				}
+			});
+			return;
+		}
+		AttachmentsWorkLatch.AddCount(1);
+		WorkerPool.ScheduleWork([&AttachmentsWorkLatch, &ChunkStore, &RemoteStore, BlockHash, &RemoteResult]() {
+			auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); });
+			if (RemoteResult.IsError())
+			{
+				return;
+			}
+			RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash);
+			if (BlockResult.ErrorCode)
+			{
+				RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text);
+				ZEN_ERROR("Failed to load oplog container, missing attachment {} ({}). Reason: '{}'",
+						  BlockHash,
+						  RemoteResult.GetError(),
+						  RemoteResult.GetErrorReason());
+				return;
+			}
+			ZEN_DEBUG("Loaded block attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000)));
+
+			// Unpack the block and store every contained chunk individually
+			if (!IterateBlock(std::move(BlockResult.Bytes), [&ChunkStore](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) {
+					ChunkStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), AttachmentRawHash);
+				}))
+			{
+				RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+									  fmt::format("Invalid format for block {}", BlockHash),
+									  {});
+				ZEN_ERROR("Failed to load oplog container, attachment {} has invalid format ({}). Reason: '{}'",
+						  BlockHash,
+						  RemoteResult.GetError(),
+						  RemoteResult.GetErrorReason());
+				return;
+			}
+		});
+	};
+
+	auto OnNeedAttachment =
+		[&RemoteStore, &ChunkStore, &WorkerPool, &AttachmentsWorkLatch, &RemoteResult, &Attachments](const IoHash& RawHash) {
+			if (!Attachments.insert(RawHash).second)
+			{
+				return;
+			}
+
+			AttachmentsWorkLatch.AddCount(1);
+			WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &RemoteResult, &AttachmentsWorkLatch, RawHash]() {
+				auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); });
+				if (RemoteResult.IsError())
+				{
+					return;
+				}
+				RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash);
+				if (AttachmentResult.ErrorCode)
+				{
+					RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text);
+					ZEN_ERROR("Failed to download attachment {}, reason: '{}', error code: {}",
+							  RawHash,
+							  AttachmentResult.Reason,
+							  AttachmentResult.ErrorCode);
+					return;
+				}
+				ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000)));
+				ChunkStore.AddChunk(AttachmentResult.Bytes, RawHash);
+			});
+		};
+
+	// Replay the container's ops into the local oplog while scheduling the
+	// downloads for every missing attachment/block
+	RemoteProjectStore::Result Result =
+		SaveOplogContainer(Oplog, LoadContainerResult.ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment);
+
+	// Release the scheduling guard and wait for all downloads to finish
+	AttachmentsWorkLatch.CountDown();
+	while (!AttachmentsWorkLatch.Wait(1000))
+	{
+		ZEN_INFO("Loading attachments, {} remaining...", AttachmentsWorkLatch.Remaining());
+	}
+	AttachmentsWorkLatch.Wait();
+	if (Result.ErrorCode == 0)
+	{
+		Result = RemoteResult.ConvertResult();
+	}
+	// Convert stopwatch milliseconds to seconds
+	Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+
+	ZEN_INFO("Loaded oplog {} in {}",
+			 RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE",
+			 NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)));
+
+	return Result;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/remoteprojectstore.h b/src/zenserver/projectstore/remoteprojectstore.h
new file mode 100644
index 000000000..dcabaedd4
--- /dev/null
+++ b/src/zenserver/projectstore/remoteprojectstore.h
@@ -0,0 +1,111 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "projectstore.h"
+
+#include <unordered_set>
+
+namespace zen {
+
+class CidStore;
+class WorkerThreadPool;
+
+class RemoteProjectStore
+{
+public:
+ struct Result
+ {
+ int32_t ErrorCode{};
+ double ElapsedSeconds{};
+ std::string Reason;
+ std::string Text;
+ };
+
+ struct SaveResult : public Result
+ {
+ std::unordered_set<IoHash, IoHash::Hasher> Needs;
+ IoHash RawHash;
+ };
+
+ struct SaveAttachmentResult : public Result
+ {
+ };
+
+ struct SaveAttachmentsResult : public Result
+ {
+ };
+
+ struct LoadAttachmentResult : public Result
+ {
+ IoBuffer Bytes;
+ };
+
+ struct LoadContainerResult : public Result
+ {
+ CbObject ContainerObject;
+ };
+
+ struct LoadAttachmentsResult : public Result
+ {
+ std::vector<std::pair<IoHash, CompressedBuffer>> Chunks;
+ };
+
+ struct RemoteStoreInfo
+ {
+ bool CreateBlocks;
+ bool UseTempBlockFiles;
+ std::string Description;
+ };
+
+ virtual ~RemoteProjectStore() {}
+
+ virtual RemoteStoreInfo GetInfo() const = 0;
+
+ virtual SaveResult SaveContainer(const IoBuffer& Payload) = 0;
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) = 0;
+ virtual Result FinalizeContainer(const IoHash& RawHash) = 0;
+ virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Payloads) = 0;
+
+ virtual LoadContainerResult LoadContainer() = 0;
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0;
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0;
+};
+
+struct RemoteStoreOptions
+{
+ size_t MaxBlockSize = 128u * 1024u * 1024u;
+ size_t MaxChunkEmbedSize = 1024u * 1024u;
+};
+
+RemoteProjectStore::LoadContainerResult BuildContainer(
+ CidStore& ChunkStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+ const std::function<void(const IoHash&)>& OnLargeAttachment,
+ const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks);
+
+RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog,
+ const CbObject& ContainerObject,
+ const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+ const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
+ const std::function<void(const IoHash& RawHash)>& OnNeedAttachment);
+
+RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore,
+ RemoteProjectStore& RemoteStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ bool UseTempBlocks,
+ bool ForceUpload);
+
+RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload);
+
+CompressedBuffer GenerateBlock(std::vector<SharedBuffer>&& Chunks);
+bool IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor);
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/zenremoteprojectstore.cpp b/src/zenserver/projectstore/zenremoteprojectstore.cpp
new file mode 100644
index 000000000..6ff471ae5
--- /dev/null
+++ b/src/zenserver/projectstore/zenremoteprojectstore.cpp
@@ -0,0 +1,341 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenremoteprojectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/fmtutils.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zenhttp/httpshared.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+class ZenRemoteStore : public RemoteProjectStore
+{
+public:
+ ZenRemoteStore(std::string_view HostAddress,
+ std::string_view Project,
+ std::string_view Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize)
+ : m_HostAddress(HostAddress)
+ , m_ProjectStoreUrl(fmt::format("{}/prj"sv, m_HostAddress))
+ , m_Project(Project)
+ , m_Oplog(Oplog)
+ , m_MaxBlockSize(MaxBlockSize)
+ , m_MaxChunkEmbedSize(MaxChunkEmbedSize)
+ {
+ }
+
+ virtual RemoteStoreInfo GetInfo() const override
+ {
+ return {.CreateBlocks = false, .UseTempBlockFiles = false, .Description = fmt::format("[zen] {}"sv, m_HostAddress)};
+ }
+
+ virtual SaveResult SaveContainer(const IoBuffer& Payload) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/save"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+ Session->SetUrl({SaveRequest});
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))}});
+ MemoryView Data(Payload.GetView());
+ Session->SetBody({reinterpret_cast<const char*>(Data.GetData()), Data.GetSize()});
+ cpr::Response Response = Session->Post();
+ SaveResult Result = SaveResult{ConvertResult(Response)};
+
+ if (Result.ErrorCode)
+ {
+            Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+ IoBuffer ResponsePayload(IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ CbObject ResponseObject = LoadCompactBinaryObject(ResponsePayload);
+ if (!ResponseObject)
+ {
+ Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv,
+ m_ProjectStoreUrl,
+ m_Project,
+ m_Oplog);
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+            Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+ CbArrayView NeedsArray = ResponseObject["need"sv].AsArrayView();
+ for (CbFieldView FieldView : NeedsArray)
+ {
+ IoHash ChunkHash = FieldView.AsHash();
+ Result.Needs.insert(ChunkHash);
+ }
+
+ Result.RawHash = IoHash::HashBuffer(Payload);
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash);
+ Session->SetUrl({SaveRequest});
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}});
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+ cpr::Response Response = Session->Post();
+ SaveAttachmentResult Result = SaveAttachmentResult{ConvertResult(Response)};
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+
+ virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
+ {
+ Stopwatch Timer;
+
+ CbPackage RequestPackage;
+ {
+ CbObjectWriter RequestWriter;
+ RequestWriter.AddString("method"sv, "putchunks"sv);
+ RequestWriter.BeginArray("chunks"sv);
+ {
+ for (const SharedBuffer& Chunk : Chunks)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(Chunk, RawHash, RawSize);
+ RequestWriter.AddHash(RawHash);
+ RequestPackage.AddAttachment(CbAttachment(Compressed, RawHash));
+ }
+ }
+ RequestWriter.EndArray(); // "chunks"
+ RequestPackage.SetObject(RequestWriter.Save());
+ }
+ CompositeBuffer Payload = FormatPackageMessageBuffer(RequestPackage, FormatFlags::kDefault);
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+ Session->SetUrl({SaveRequest});
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}});
+
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+ cpr::Response Response = Session->Post();
+ SaveAttachmentsResult Result = SaveAttachmentsResult{ConvertResult(Response)};
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+
+ CbObject Request;
+ {
+ CbObjectWriter RequestWriter;
+ RequestWriter.AddString("method"sv, "getchunks"sv);
+ RequestWriter.BeginArray("chunks"sv);
+ {
+ for (const IoHash& RawHash : RawHashes)
+ {
+ RequestWriter.AddHash(RawHash);
+ }
+ }
+ RequestWriter.EndArray(); // "chunks"
+ Request = RequestWriter.Save();
+ }
+ IoBuffer Payload = Request.GetBuffer().AsIoBuffer();
+ Session->SetBody(cpr::Body{(const char*)Payload.GetData(), Payload.GetSize()});
+ Session->SetUrl(SaveRequest);
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))},
+ {"Accept", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}});
+
+ cpr::Response Response = Session->Post();
+ LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)};
+ if (!Result.ErrorCode)
+ {
+ CbPackage Package = ParsePackageMessage(IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()));
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+ Result.Chunks.reserve(Attachments.size());
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Result.Chunks.emplace_back(
+ std::pair<IoHash, CompressedBuffer>{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()});
+ }
+ }
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ };
+
+ virtual Result FinalizeContainer(const IoHash&) override
+ {
+ Stopwatch Timer;
+
+ RwLock::ExclusiveLockScope _(SessionsLock);
+ Sessions.clear();
+        return {.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0};
+ }
+
+ virtual LoadContainerResult LoadContainer() override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/load"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+ Session->SetUrl(SaveRequest);
+ Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCbObject))}});
+ Session->SetParameters(
+ {{"maxblocksize", fmt::format("{}", m_MaxBlockSize)}, {"maxchunkembedsize", fmt::format("{}", m_MaxChunkEmbedSize)}});
+ cpr::Response Response = Session->Get();
+
+ LoadContainerResult Result = LoadContainerResult{ConvertResult(Response)};
+ if (!Result.ErrorCode)
+ {
+ Result.ContainerObject = LoadCompactBinaryObject(IoBuffer(IoBuffer::Clone, Response.text.data(), Response.text.size()));
+ if (!Result.ContainerObject)
+ {
+ Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv,
+ m_ProjectStoreUrl,
+ m_Project,
+ m_Oplog);
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+                Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+ }
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+
+ std::string LoadRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash);
+ Session->SetUrl({LoadRequest});
+ Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}});
+ cpr::Response Response = Session->Get();
+ LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)};
+ if (!Result.ErrorCode)
+ {
+ Result.Bytes = IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size());
+ }
+        Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ return Result;
+ }
+
+private:
+ std::unique_ptr<cpr::Session> AllocateSession()
+ {
+ RwLock::ExclusiveLockScope _(SessionsLock);
+ if (Sessions.empty())
+ {
+ Sessions.emplace_back(std::make_unique<cpr::Session>());
+ }
+ std::unique_ptr<cpr::Session> Session = std::move(Sessions.back());
+ Sessions.pop_back();
+ return Session;
+ }
+
+ void ReleaseSession(std::unique_ptr<cpr::Session>&& Session)
+ {
+ RwLock::ExclusiveLockScope _(SessionsLock);
+ Sessions.emplace_back(std::move(Session));
+ }
+
+ static Result ConvertResult(const cpr::Response& Response)
+ {
+ std::string Text;
+ std::string Reason = Response.reason;
+ int32_t ErrorCode = 0;
+ if (Response.error.code != cpr::ErrorCode::OK)
+ {
+ ErrorCode = static_cast<int32_t>(Response.error.code);
+ if (!Response.error.message.empty())
+ {
+ Reason = Response.error.message;
+ }
+ }
+ else if (!IsHttpSuccessCode(Response.status_code))
+ {
+ ErrorCode = static_cast<int32_t>(Response.status_code);
+
+ if (auto It = Response.header.find("Content-Type"); It != Response.header.end())
+ {
+ zen::HttpContentType ContentType = zen::ParseContentType(It->second);
+ if (ContentType == zen::HttpContentType::kText)
+ {
+ Text = Response.text;
+ }
+ }
+
+ Reason = fmt::format("{}"sv, Response.status_code);
+ }
+ return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.elapsed, .Reason = Reason, .Text = Text};
+ }
+
+ RwLock SessionsLock;
+ std::vector<std::unique_ptr<cpr::Session>> Sessions;
+
+ const std::string m_HostAddress;
+ const std::string m_ProjectStoreUrl;
+ const std::string m_Project;
+ const std::string m_Oplog;
+ const size_t m_MaxBlockSize;
+ const size_t m_MaxChunkEmbedSize;
+};
+
+std::unique_ptr<RemoteProjectStore>
+CreateZenRemoteStore(const ZenRemoteStoreOptions& Options)
+{
+ std::string Url = Options.Url;
+ if (Url.find("://"sv) == std::string::npos)
+ {
+        // Assume http URL
+ Url = fmt::format("http://{}"sv, Url);
+ }
+ std::unique_ptr<RemoteProjectStore> RemoteStore =
+ std::make_unique<ZenRemoteStore>(Url, Options.ProjectId, Options.OplogId, Options.MaxBlockSize, Options.MaxChunkEmbedSize);
+ return RemoteStore;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/zenremoteprojectstore.h b/src/zenserver/projectstore/zenremoteprojectstore.h
new file mode 100644
index 000000000..ef9dcad8c
--- /dev/null
+++ b/src/zenserver/projectstore/zenremoteprojectstore.h
@@ -0,0 +1,18 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "remoteprojectstore.h"
+
+namespace zen {
+
+struct ZenRemoteStoreOptions : RemoteStoreOptions
+{
+ std::string Url;
+ std::string ProjectId;
+ std::string OplogId;
+};
+
+std::unique_ptr<RemoteProjectStore> CreateZenRemoteStore(const ZenRemoteStoreOptions& Options);
+
+} // namespace zen
diff --git a/src/zenserver/resource.h b/src/zenserver/resource.h
new file mode 100644
index 000000000..f2e3b471b
--- /dev/null
+++ b/src/zenserver/resource.h
@@ -0,0 +1,18 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+//{{NO_DEPENDENCIES}}
+// Microsoft Visual C++ generated include file.
+// Used by zenserver.rc
+//
+#define IDI_ICON1 101
+
+// Next default values for new objects
+//
+#ifdef APSTUDIO_INVOKED
+# ifndef APSTUDIO_READONLY_SYMBOLS
+# define _APS_NEXT_RESOURCE_VALUE 102
+# define _APS_NEXT_COMMAND_VALUE 40001
+# define _APS_NEXT_CONTROL_VALUE 1001
+# define _APS_NEXT_SYMED_VALUE 101
+# endif
+#endif
diff --git a/src/zenserver/targetver.h b/src/zenserver/targetver.h
new file mode 100644
index 000000000..d432d6993
--- /dev/null
+++ b/src/zenserver/targetver.h
@@ -0,0 +1,10 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// Including SDKDDKVer.h defines the highest available Windows platform.
+
+// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and
+// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h.
+
+#include <SDKDDKVer.h>
diff --git a/src/zenserver/testing/httptest.cpp b/src/zenserver/testing/httptest.cpp
new file mode 100644
index 000000000..349a95ab3
--- /dev/null
+++ b/src/zenserver/testing/httptest.cpp
@@ -0,0 +1,207 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httptest.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/timer.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpTestingService::HttpTestingService()
+{
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_slow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Stopwatch Timer;
+ Sleep(1000);
+ Request.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kText,
+ fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())));
+ });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_veryslow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Stopwatch Timer;
+ Sleep(60000);
+ Request.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kText,
+ fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())));
+ });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_throw",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) { throw std::runtime_error("intentional error"); });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_noresponse",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) {}); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "metrics",
+ [this](HttpRouterRequest& Req) {
+ metrics::OperationTiming::Scope _(m_TimingStats);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "get_metrics",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Cbo;
+ EmitSnapshot("requests", m_TimingStats, Cbo);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "json",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Obj.AddInteger("counter", ++m_Counter);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ IoBuffer Body = Req.ServerRequest().ReadPayload();
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "package",
+ [](HttpRouterRequest& Req) {
+ CbPackage Pkg = Req.ServerRequest().ReadPayloadPackage();
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Pkg);
+ },
+ HttpVerb::kPost);
+}
+
+HttpTestingService::~HttpTestingService()
+{
+}
+
+const char*
+HttpTestingService::BaseUri() const
+{
+ return "/testing/";
+}
+
+void
+HttpTestingService::HandleRequest(HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+Ref<IHttpPackageHandler>
+HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
+{
+ RwLock::ExclusiveLockScope _(m_RwLock);
+
+ const uint32_t RequestId = HttpServiceRequest.RequestId();
+
+ if (auto It = m_HandlerMap.find(RequestId); It != m_HandlerMap.end())
+ {
+ Ref<HttpTestingService::PackageHandler> Handler = std::move(It->second);
+
+ m_HandlerMap.erase(It);
+
+ return Handler;
+ }
+
+ auto InsertResult = m_HandlerMap.insert({RequestId, Ref<PackageHandler>()});
+
+ _.ReleaseNow();
+
+ return (InsertResult.first->second = Ref<PackageHandler>(new PackageHandler(*this, RequestId)));
+}
+
+void
+HttpTestingService::RegisterHandlers(WebSocketServer& Server)
+{
+ Server.RegisterRequestHandler("SayHello"sv, *this);
+}
+
+bool
+HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg)
+{
+ CbObjectView Request = RequestMsg.Body().GetObject();
+
+ std::string_view Method = Request["Method"].AsString();
+
+ if (Method != "SayHello"sv)
+ {
+ return false;
+ }
+
+ CbObjectWriter Response;
+ Response.AddString("Result"sv, "Hello Friend!!");
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId());
+ ResponseMsg.SetSocketId(RequestMsg.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SocketServer().SendResponse(std::move(ResponseMsg));
+
+ return true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId)
+{
+}
+
+HttpTestingService::PackageHandler::~PackageHandler()
+{
+}
+
+void
+HttpTestingService::PackageHandler::FilterOffer(std::vector<IoHash>& OfferCids)
+{
+ ZEN_UNUSED(OfferCids);
+ // No-op
+ return;
+}
+void
+HttpTestingService::PackageHandler::OnRequestBegin()
+{
+}
+
+void
+HttpTestingService::PackageHandler::OnRequestComplete()
+{
+}
+
+IoBuffer
+HttpTestingService::PackageHandler::CreateTarget(const IoHash& Cid, uint64_t StorageSize)
+{
+ ZEN_UNUSED(Cid);
+ return IoBuffer{StorageSize};
+}
+
+} // namespace zen
diff --git a/src/zenserver/testing/httptest.h b/src/zenserver/testing/httptest.h
new file mode 100644
index 000000000..57d2d63f3
--- /dev/null
+++ b/src/zenserver/testing/httptest.h
@@ -0,0 +1,55 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+#include <zencore/stats.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+
+#include <atomic>
+
+namespace zen {
+
+/**
+ * Test service to facilitate testing the HTTP framework and client interactions
+ */
+class HttpTestingService : public HttpService, public WebSocketService
+{
+public:
+ HttpTestingService();
+ ~HttpTestingService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override;
+
+ class PackageHandler : public IHttpPackageHandler
+ {
+ public:
+ PackageHandler(HttpTestingService& Svc, uint32_t RequestId);
+ ~PackageHandler();
+
+ virtual void FilterOffer(std::vector<IoHash>& OfferCids) override;
+ virtual void OnRequestBegin() override;
+ virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) override;
+ virtual void OnRequestComplete() override;
+
+ private:
+ HttpTestingService& m_Svc;
+ uint32_t m_RequestId;
+ };
+
+private:
+ virtual void RegisterHandlers(WebSocketServer& Server) override;
+ virtual bool HandleRequest(const WebSocketMessage& Request) override;
+
+ HttpRequestRouter m_Router;
+ std::atomic<uint32_t> m_Counter{0};
+ metrics::OperationTiming m_TimingStats;
+
+ RwLock m_RwLock;
+ std::unordered_map<uint32_t, Ref<PackageHandler>> m_HandlerMap;
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/hordecompute.cpp b/src/zenserver/upstream/hordecompute.cpp
new file mode 100644
index 000000000..64d9fff72
--- /dev/null
+++ b/src/zenserver/upstream/hordecompute.cpp
@@ -0,0 +1,1457 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamapply.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "jupiter.h"
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compactbinaryvalidation.h>
+# include <zencore/fmtutils.h>
+# include <zencore/session.h>
+# include <zencore/stream.h>
+# include <zencore/thread.h>
+# include <zencore/timer.h>
+# include <zencore/workthreadpool.h>
+
+# include <zenstore/cidstore.h>
+
+# include <auth/authmgr.h>
+# include <upstream/upstreamcache.h>
+
+# include "cache/structuredcachestore.h"
+# include "diag/logging.h"
+
+# include <fmt/format.h>
+
+# include <algorithm>
+# include <atomic>
+# include <set>
+# include <stack>
+
+namespace zen {
+
+using namespace std::literals;
+
+static const IoBuffer EmptyBuffer;
+static const IoHash EmptyBufferId = IoHash::HashBuffer(EmptyBuffer);
+
+namespace detail {
+
+ class HordeUpstreamApplyEndpoint final : public UpstreamApplyEndpoint
+ {
+ public:
+ HordeUpstreamApplyEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr)
+ : m_Log(logging::Get("upstream-apply"))
+ , m_CidStore(CidStore)
+ , m_AuthMgr(Mgr)
+ {
+ m_DisplayName = fmt::format("{} - '{}'+'{}'", ComputeOptions.Name, ComputeOptions.ServiceUrl, StorageOptions.ServiceUrl);
+ m_ChannelId = fmt::format("zen-{}", zen::GetSessionIdString());
+
+ {
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (ComputeAuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = ComputeAuthConfig.OAuthUrl,
+ .ClientId = ComputeAuthConfig.OAuthClientId,
+ .ClientSecret = ComputeAuthConfig.OAuthClientSecret});
+ }
+ else if (ComputeAuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(ComputeAuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(ComputeAuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_Client = new CloudCacheClient(ComputeOptions, std::move(TokenProvider));
+ }
+
+ {
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (StorageAuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = StorageAuthConfig.OAuthUrl,
+ .ClientId = StorageAuthConfig.OAuthClientId,
+ .ClientSecret = StorageAuthConfig.OAuthClientSecret});
+ }
+ else if (StorageAuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(StorageAuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(StorageAuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_StorageClient = new CloudCacheClient(StorageOptions, std::move(TokenProvider));
+ }
+ }
+
+ virtual ~HordeUpstreamApplyEndpoint() = default;
+
+ virtual UpstreamEndpointHealth Initialize() override { return CheckHealth(); }
+
+ virtual bool IsHealthy() const override { return m_HealthOk.load(); }
+
+ virtual UpstreamEndpointHealth CheckHealth() override
+ {
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result = Session.Authenticate();
+
+ m_HealthOk = Result.ErrorCode == 0;
+
+ return {.Reason = std::move(Result.Reason), .Ok = Result.Success};
+ }
+ catch (std::exception& Err)
+ {
+ return {.Reason = Err.what(), .Ok = false};
+ }
+ }
+
+ virtual std::string_view DisplayName() const override { return m_DisplayName; }
+
+ virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) override
+ {
+ PostUpstreamApplyResult ApplyResult{};
+ ApplyResult.Timepoints.merge(ApplyRecord.Timepoints);
+
+ try
+ {
+ UpstreamData UpstreamData;
+ if (!ProcessApplyKey(ApplyRecord, UpstreamData))
+ {
+ return {.Error{.ErrorCode = -1, .Reason = "Failed to generate task data"}};
+ }
+
+ {
+ ApplyResult.Timepoints["zen-storage-build-ref"] = DateTime::NowTicks();
+
+ bool AlreadyQueued;
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ AlreadyQueued = m_PendingTasks.contains(UpstreamData.TaskId);
+ }
+ if (AlreadyQueued)
+ {
+ // Pending task is already queued, return success
+ ApplyResult.Success = true;
+ return ApplyResult;
+ }
+                    { std::scoped_lock Lock(m_TaskMutex); m_PendingTasks[UpstreamData.TaskId] = std::move(ApplyRecord); }
+ }
+
+ CloudCacheSession ComputeSession(m_Client);
+ CloudCacheSession StorageSession(m_StorageClient);
+
+ {
+ CloudCacheResult Result = BatchPutBlobsIfMissing(StorageSession, UpstreamData.Blobs, UpstreamData.CasIds);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-blobs"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload blobs"};
+ return ApplyResult;
+ }
+ UpstreamData.Blobs.clear();
+ UpstreamData.CasIds.clear();
+ }
+
+ {
+ CloudCacheResult Result = BatchPutCompressedBlobsIfMissing(StorageSession, UpstreamData.Cids);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-compressed-blobs"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {
+ .ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload compressed blobs"};
+ return ApplyResult;
+ }
+ UpstreamData.Cids.clear();
+ }
+
+ {
+ CloudCacheResult Result = BatchPutObjectsIfMissing(StorageSession, UpstreamData.Objects);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-objects"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload objects"};
+ return ApplyResult;
+ }
+ }
+
+ {
+ PutRefResult RefResult = StorageSession.PutRef(StorageSession.Client().DefaultBlobStoreNamespace(),
+ "requests"sv,
+ UpstreamData.TaskId,
+ UpstreamData.Objects[UpstreamData.TaskId].GetBuffer().AsIoBuffer(),
+ ZenContentType::kCbObject);
+ Log().debug("Put ref {} Need={} Bytes={} Duration={}s Result={}",
+ UpstreamData.TaskId,
+ RefResult.Needs.size(),
+ RefResult.Bytes,
+ RefResult.ElapsedSeconds,
+ RefResult.Success);
+ ApplyResult.Bytes += RefResult.Bytes;
+ ApplyResult.ElapsedSeconds += RefResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-put-ref"] = DateTime::NowTicks();
+
+ if (RefResult.Needs.size() > 0)
+ {
+ Log().error("Failed to add task ref {} due to {} missing blobs", UpstreamData.TaskId, RefResult.Needs.size());
+ for (const auto& Hash : RefResult.Needs)
+ {
+ Log().debug("Task ref {} missing blob {}", UpstreamData.TaskId, Hash);
+ }
+
+ ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode,
+ .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason)
+ : "Failed to add task ref due to missing blob"};
+ return ApplyResult;
+ }
+
+ if (!RefResult.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode,
+ .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) : "Failed to add task ref"};
+ return ApplyResult;
+ }
+ UpstreamData.Objects.clear();
+ }
+
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("c"sv, m_ChannelId);
+ Writer.AddObjectAttachment("r"sv, UpstreamData.RequirementsId);
+ Writer.BeginArray("t"sv);
+ Writer.AddObjectAttachment(UpstreamData.TaskId);
+ Writer.EndArray();
+ CbObject TasksObject = Writer.Save();
+ IoBuffer TasksData = TasksObject.GetBuffer().AsIoBuffer();
+
+ CloudCacheResult Result = ComputeSession.PostComputeTasks(TasksData);
+ Log().debug("Post compute task {} Bytes={} Duration={}s Result={}",
+ TasksObject.GetHash(),
+ Result.Bytes,
+ Result.ElapsedSeconds,
+ Result.Success);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-horde-post-task"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_PendingTasks.erase(UpstreamData.TaskId);
+ }
+
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to post compute task"};
+ return ApplyResult;
+ }
+ }
+
+ Log().info("Task posted {}", UpstreamData.TaskId);
+ ApplyResult.Success = true;
+ return ApplyResult;
+ }
+ catch (std::exception& Err)
+ {
+ m_HealthOk = false;
+ return {.Error{.ErrorCode = -1, .Reason = Err.what()}};
+ }
+ }
+
+ // Uploads any of the given blobs that the remote blob store does not already
+ // have. Payloads come either from the in-memory Blobs map or, for CasIds,
+ // are resolved lazily from the local CID store. Returns accumulated transfer
+ // stats; a failure at any step aborts the batch and reports the first error.
+ [[nodiscard]] CloudCacheResult BatchPutBlobsIfMissing(CloudCacheSession& Session,
+                                                       const std::map<IoHash, IoBuffer>& Blobs,
+                                                       const std::set<IoHash>& CasIds)
+ {
+     if (Blobs.size() == 0 && CasIds.size() == 0)
+     {
+         return {.Success = true};
+     }
+
+     int64_t Bytes{};
+     double ElapsedSeconds{};
+
+     // Batch check for missing blobs
+     std::set<IoHash> Keys;
+     std::transform(Blobs.begin(), Blobs.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; });
+     Keys.insert(CasIds.begin(), CasIds.end());
+
+     CloudCacheExistsResult ExistsResult = Session.BlobExists(Session.Client().DefaultBlobStoreNamespace(), Keys);
+     Log().debug("Queried {} missing blobs Need={} Duration={}s Result={}",
+                 Keys.size(),
+                 ExistsResult.Needs.size(),
+                 ExistsResult.ElapsedSeconds,
+                 ExistsResult.Success);
+     ElapsedSeconds += ExistsResult.ElapsedSeconds;
+     if (!ExistsResult.Success)
+     {
+         return {.Bytes = Bytes,
+                 .ElapsedSeconds = ElapsedSeconds,
+                 .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+                 .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if blobs exist"};
+     }
+
+     // Upload only the blobs the server reported as missing.
+     for (const auto& Hash : ExistsResult.Needs)
+     {
+         IoBuffer DataBuffer;
+         if (Blobs.contains(Hash))
+         {
+             DataBuffer = Blobs.at(Hash);
+         }
+         else
+         {
+             // Not supplied in-memory: the hash must identify a chunk in the local CID store.
+             DataBuffer = m_CidStore.FindChunkByCid(Hash);
+             if (!DataBuffer)
+             {
+                 Log().warn("Put blob FAILED, input chunk '{}' missing", Hash);
+                 return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put blobs"};
+             }
+         }
+
+         CloudCacheResult Result = Session.PutBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer);
+         Log().debug("Put blob {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success);
+         Bytes += Result.Bytes;
+         ElapsedSeconds += Result.ElapsedSeconds;
+         if (!Result.Success)
+         {
+             return {.Bytes = Bytes,
+                     .ElapsedSeconds = ElapsedSeconds,
+                     .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+                     .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put blobs"};
+         }
+     }
+
+     return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ // Uploads any compressed blobs (identified by CID) that the remote store is
+ // missing. Unlike BatchPutBlobsIfMissing, payloads are always resolved from
+ // the local CID store. Returns accumulated transfer stats.
+ [[nodiscard]] CloudCacheResult BatchPutCompressedBlobsIfMissing(CloudCacheSession& Session, const std::set<IoHash>& Cids)
+ {
+     if (Cids.size() == 0)
+     {
+         return {.Success = true};
+     }
+
+     int64_t Bytes{};
+     double ElapsedSeconds{};
+
+     // Batch check for missing compressed blobs
+     CloudCacheExistsResult ExistsResult = Session.CompressedBlobExists(Session.Client().DefaultBlobStoreNamespace(), Cids);
+     Log().debug("Queried {} missing compressed blobs Need={} Duration={}s Result={}",
+                 Cids.size(),
+                 ExistsResult.Needs.size(),
+                 ExistsResult.ElapsedSeconds,
+                 ExistsResult.Success);
+     ElapsedSeconds += ExistsResult.ElapsedSeconds;
+     if (!ExistsResult.Success)
+     {
+         return {
+             .Bytes = Bytes,
+             .ElapsedSeconds = ElapsedSeconds,
+             .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+             .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if compressed blobs exist"};
+     }
+
+     // Upload only what the server reported as missing.
+     for (const auto& Hash : ExistsResult.Needs)
+     {
+         IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Hash);
+         if (!DataBuffer)
+         {
+             Log().warn("Put compressed blob FAILED, input CID chunk '{}' missing", Hash);
+             return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put compressed blobs"};
+         }
+
+         CloudCacheResult Result = Session.PutCompressedBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer);
+         Log().debug("Put compressed blob {} Bytes={} Duration={}s Result={}",
+                     Hash,
+                     Result.Bytes,
+                     Result.ElapsedSeconds,
+                     Result.Success);
+         Bytes += Result.Bytes;
+         ElapsedSeconds += Result.ElapsedSeconds;
+         if (!Result.Success)
+         {
+             return {.Bytes = Bytes,
+                     .ElapsedSeconds = ElapsedSeconds,
+                     .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+                     .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put compressed blobs"};
+         }
+     }
+
+     return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ // Uploads any compact-binary objects the remote store is missing; payloads
+ // are the serialized buffers of the provided CbObjects. Returns accumulated
+ // transfer stats.
+ [[nodiscard]] CloudCacheResult BatchPutObjectsIfMissing(CloudCacheSession& Session, const std::map<IoHash, CbObject>& Objects)
+ {
+     if (Objects.size() == 0)
+     {
+         return {.Success = true};
+     }
+
+     int64_t Bytes{};
+     double ElapsedSeconds{};
+
+     // Batch check for missing objects
+     std::set<IoHash> Keys;
+     std::transform(Objects.begin(), Objects.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; });
+
+     CloudCacheExistsResult ExistsResult = Session.ObjectExists(Session.Client().DefaultBlobStoreNamespace(), Keys);
+     Log().debug("Queried {} missing objects Need={} Duration={}s Result={}",
+                 Keys.size(),
+                 ExistsResult.Needs.size(),
+                 ExistsResult.ElapsedSeconds,
+                 ExistsResult.Success);
+     ElapsedSeconds += ExistsResult.ElapsedSeconds;
+     if (!ExistsResult.Success)
+     {
+         return {.Bytes = Bytes,
+                 .ElapsedSeconds = ElapsedSeconds,
+                 .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+                 .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if objects exist"};
+     }
+
+     // Upload only what the server reported as missing.
+     for (const auto& Hash : ExistsResult.Needs)
+     {
+         CloudCacheResult Result =
+             Session.PutObject(Session.Client().DefaultBlobStoreNamespace(), Hash, Objects.at(Hash).GetBuffer().AsIoBuffer());
+         Log().debug("Put object {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success);
+         Bytes += Result.Bytes;
+         ElapsedSeconds += Result.ElapsedSeconds;
+         if (!Result.Success)
+         {
+             return {.Bytes = Bytes,
+                     .ElapsedSeconds = ElapsedSeconds,
+                     .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+                     .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put objects"};
+         }
+     }
+
+     return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ // Lifecycle states reported by the remote compute scheduler for a task.
+ // Values are cast directly from the status object's "s" field; do not renumber.
+ enum class ComputeTaskState : int32_t
+ {
+     Queued = 0,
+     Executing = 1,
+     Complete = 2,
+ };
+
+ // Final outcome codes for a completed compute task. Values are cast directly
+ // from the status object's "o" field; do not renumber.
+ enum class ComputeTaskOutcome : int32_t
+ {
+     Success = 0,
+     Failed = 1,
+     Cancelled = 2,
+     NoResult = 3,
+     // NOTE(review): 'Exipred' is a misspelling of 'Expired'. The identifier is
+     // kept as-is here; renaming it must be coordinated with every use (including
+     // code above this region) in one change.
+     Exipred = 4,
+     BlobNotFound = 5,
+     Exception = 6,
+ };
+
+ // Converts a ComputeTaskState to a human-readable string for logging.
+ // Unrecognized values map to "Unknown".
+ // (Fix: parameter was misleadingly named 'Outcome' although it is a state.)
+ [[nodiscard]] static std::string_view ComputeTaskStateToString(const ComputeTaskState State)
+ {
+     switch (State)
+     {
+         case ComputeTaskState::Queued:
+             return "Queued"sv;
+         case ComputeTaskState::Executing:
+             return "Executing"sv;
+         case ComputeTaskState::Complete:
+             return "Complete"sv;
+     }
+     return "Unknown"sv;
+ }
+
+ // Converts a ComputeTaskOutcome to a human-readable string for logging.
+ // Unrecognized values map to "Unknown".
+ // (Fix: the Exipred case previously returned the misspelled string "Exipred";
+ // the enum identifier itself is left unchanged until all uses can be renamed.)
+ [[nodiscard]] static std::string_view ComputeTaskOutcomeToString(const ComputeTaskOutcome Outcome)
+ {
+     switch (Outcome)
+     {
+         case ComputeTaskOutcome::Success:
+             return "Success"sv;
+         case ComputeTaskOutcome::Failed:
+             return "Failed"sv;
+         case ComputeTaskOutcome::Cancelled:
+             return "Cancelled"sv;
+         case ComputeTaskOutcome::NoResult:
+             return "NoResult"sv;
+         case ComputeTaskOutcome::Exipred:
+             return "Expired"sv;
+         case ComputeTaskOutcome::BlobNotFound:
+             return "BlobNotFound"sv;
+         case ComputeTaskOutcome::Exception:
+             return "Exception"sv;
+     }
+     return "Unknown"sv;
+ }
+
+ // Polls the remote compute channel for task status updates, schedules result
+ // processing for completed tasks on the worker thread pool, and returns any
+ // results that finished processing since the previous call.
+ // (Fix: removed a duplicated, unreachable `if (!UpdatesResult.Success)` check —
+ // the first check already returned on failure.)
+ virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) override
+ {
+     int64_t Bytes{};
+     double ElapsedSeconds{};
+
+     {
+         // Fast path: with nothing in flight there is nothing to poll, so just
+         // drain whatever completed results are waiting.
+         std::scoped_lock Lock(m_TaskMutex);
+         if (m_PendingTasks.empty())
+         {
+             if (m_CompletedTasks.empty())
+             {
+                 // Nothing to do.
+                 return {.Success = true};
+             }
+
+             UpstreamApplyCompleted CompletedTasks;
+             std::swap(CompletedTasks, m_CompletedTasks);
+             return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true};
+         }
+     }
+
+     try
+     {
+         CloudCacheSession ComputeSession(m_Client);
+
+         CloudCacheResult UpdatesResult = ComputeSession.GetComputeUpdates(m_ChannelId);
+         Log().debug("Get compute updates Bytes={} Duration={}s Result={}",
+                     UpdatesResult.Bytes,
+                     UpdatesResult.ElapsedSeconds,
+                     UpdatesResult.Success);
+         Bytes += UpdatesResult.Bytes;
+         ElapsedSeconds += UpdatesResult.ElapsedSeconds;
+         if (!UpdatesResult.Success)
+         {
+             return {.Error{.ErrorCode = UpdatesResult.ErrorCode, .Reason = std::move(UpdatesResult.Reason)},
+                     .Bytes = Bytes,
+                     .ElapsedSeconds = ElapsedSeconds};
+         }
+
+         CbObject TaskStatus = LoadCompactBinaryObject(std::move(UpdatesResult.Response));
+
+         // "u" is an array of per-task status objects.
+         for (auto& It : TaskStatus["u"sv])
+         {
+             CbObjectView Status = It.AsObjectView();
+             IoHash TaskId = Status["h"sv].AsHash();
+             const ComputeTaskState State = (ComputeTaskState)Status["s"sv].AsInt32();
+             const ComputeTaskOutcome Outcome = (ComputeTaskOutcome)Status["o"sv].AsInt32();
+
+             Log().info("Task {} State={}", TaskId, ComputeTaskStateToString(State));
+
+             // Only completed tasks need to be processed
+             if (State != ComputeTaskState::Complete)
+             {
+                 continue;
+             }
+
+             IoHash WorkerId{};
+             IoHash ActionId{};
+             UpstreamApplyType ApplyType{};
+
+             {
+                 // Claim the pending record; leaving WorkerId zero means the task
+                 // was not (or no longer) pending here.
+                 std::scoped_lock Lock(m_TaskMutex);
+                 auto TaskIt = m_PendingTasks.find(TaskId);
+                 if (TaskIt != m_PendingTasks.end())
+                 {
+                     WorkerId = TaskIt->second.WorkerDescriptor.GetHash();
+                     ActionId = TaskIt->second.Action.GetHash();
+                     ApplyType = TaskIt->second.Type;
+                     m_PendingTasks.erase(TaskIt);
+                 }
+             }
+
+             if (WorkerId == IoHash::Zero)
+             {
+                 Log().warn("Task {} missing from pending tasks", TaskId);
+                 continue;
+             }
+
+             std::map<std::string, uint64_t> Timepoints;
+             ProcessQueueTimings(Status["qs"sv].AsObjectView(), Timepoints);
+             ProcessExecuteTimings(Status["es"sv].AsObjectView(), Timepoints);
+
+             if (Outcome != ComputeTaskOutcome::Success)
+             {
+                 // Record the failure immediately; there is no result payload to fetch.
+                 const std::string_view Detail = Status["d"sv].AsString();
+                 {
+                     std::scoped_lock Lock(m_TaskMutex);
+                     m_CompletedTasks[WorkerId][ActionId] = {
+                         .Error{.ErrorCode = -1, .Reason = fmt::format("Task {} {}", ComputeTaskOutcomeToString(Outcome), Detail)},
+                         .Timepoints = std::move(Timepoints)};
+                 }
+                 continue;
+             }
+
+             // Successful task: download and unpack the result on a worker thread.
+             Timepoints["zen-complete-queue-added"] = DateTime::NowTicks();
+             ThreadPool.ScheduleWork([this,
+                                      ApplyType,
+                                      ResultHash = Status["r"sv].AsHash(),
+                                      Timepoints = std::move(Timepoints),
+                                      TaskId = std::move(TaskId),
+                                      WorkerId = std::move(WorkerId),
+                                      ActionId = std::move(ActionId)]() mutable {
+                 Timepoints["zen-complete-queue-dispatched"] = DateTime::NowTicks();
+                 GetUpstreamApplyResult Result = ProcessTaskStatus(ApplyType, ResultHash);
+                 Timepoints["zen-complete-queue-complete"] = DateTime::NowTicks();
+                 Result.Timepoints.merge(Timepoints);
+
+                 Log().debug("Task Processed {} Files={} Attachments={} ExitCode={}",
+                             TaskId,
+                             Result.OutputFiles.size(),
+                             Result.OutputPackage.GetAttachments().size(),
+                             Result.Error.ErrorCode);
+                 {
+                     std::scoped_lock Lock(m_TaskMutex);
+                     m_CompletedTasks[WorkerId][ActionId] = std::move(Result);
+                 }
+             });
+         }
+
+         {
+             // Hand back anything that has already finished processing.
+             std::scoped_lock Lock(m_TaskMutex);
+             if (m_CompletedTasks.empty())
+             {
+                 // Nothing to do.
+                 return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+             }
+             UpstreamApplyCompleted CompletedTasks;
+             std::swap(CompletedTasks, m_CompletedTasks);
+             return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true};
+         }
+     }
+     catch (std::exception& Err)
+     {
+         m_HealthOk = false;
+         return {
+             .Error{.ErrorCode = -1, .Reason = Err.what()},
+             .Bytes = Bytes,
+             .ElapsedSeconds = ElapsedSeconds,
+         };
+     }
+ }
+
+ // Accessor for the endpoint's mutable statistics block.
+ virtual UpstreamApplyEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+     spdlog::logger& Log() { return m_Log; }
+
+     spdlog::logger& m_Log;                    // Endpoint log channel.
+     CidStore& m_CidStore;                     // Local content-addressed chunk store; resolves upload payloads.
+     AuthMgr& m_AuthMgr;                       // Shared authentication manager.
+     std::string m_DisplayName;                // Human-readable endpoint name.
+     RefPtr<CloudCacheClient> m_Client;        // Client for the compute service.
+     RefPtr<CloudCacheClient> m_StorageClient; // Client for the blob storage service.
+     UpstreamApplyEndpointStats m_Stats;       // Counters exposed via Stats().
+     // Cleared whenever an exception escapes an operation; presumably set on a
+     // successful connect elsewhere -- TODO confirm (setter not visible here).
+     std::atomic_bool m_HealthOk{false};
+     std::string m_ChannelId;                  // Compute channel used when posting tasks and polling updates.
+
+     std::mutex m_TaskMutex;                   // Guards m_PendingTasks and m_CompletedTasks.
+     std::unordered_map<IoHash, UpstreamApplyRecord> m_PendingTasks;
+     UpstreamApplyCompleted m_CompletedTasks;
+
+     // Accumulates everything that must be uploaded for one task submission.
+     struct UpstreamData
+     {
+         std::map<IoHash, IoBuffer> Blobs;   // Raw blobs keyed by hash (uploaded via PutBlob).
+         std::map<IoHash, CbObject> Objects; // Compact-binary objects (sandbox dirs, requirements, task).
+         std::set<IoHash> CasIds;            // Hashes whose payloads come from the local CID store at upload time.
+         std::set<IoHash> Cids;              // Hashes uploaded as compressed blobs.
+         IoHash TaskId;                      // Hash of the task object (also used as the ref key).
+         IoHash RequirementsId;              // Hash of the requirements object.
+     };
+
+     // One node of the in-memory directory tree built from the flat input file list.
+     struct UpstreamDirectory
+     {
+         std::filesystem::path Path;                       // Full path of this directory relative to the sandbox root.
+         std::map<std::string, UpstreamDirectory> Directories; // Child directories keyed by name.
+         std::set<std::string> Files;                      // File names (leaf components) directly in this directory.
+     };
+
+ // Translates the scheduler's queue timing stats ("t" = absolute start ticks,
+ // "s" = array of millisecond deltas from the start) into named absolute-tick
+ // timepoints. A zero start tick means no queue stats were reported.
+ static void ProcessQueueTimings(CbObjectView QueueStats, std::map<std::string, uint64_t>& Timepoints)
+ {
+     uint64_t Ticks = QueueStats["t"sv].AsDateTimeTicks();
+     if (Ticks == 0)
+     {
+         return;
+     }
+
+     Timepoints["horde-queue-added"] = Ticks;
+
+     // Stage names in the order the deltas appear; extra entries are ignored.
+     static const char* const StageNames[] = {"horde-queue-dispatched", "horde-queue-complete"};
+     int Stage = 0;
+     for (auto& Delta : QueueStats["s"sv].AsArrayView())
+     {
+         Ticks += Delta.AsInt32() * TimeSpan::TicksPerMillisecond;
+         if (Stage < 2)
+         {
+             Timepoints[StageNames[Stage]] = Ticks;
+         }
+         ++Stage;
+     }
+ }
+
+ // Translates the agent's execution timing stats ("t" = absolute start ticks,
+ // "s" = array of millisecond deltas from the start) into named absolute-tick
+ // timepoints. A zero start tick means no execution stats were reported.
+ static void ProcessExecuteTimings(CbObjectView ExecutionStats, std::map<std::string, uint64_t>& Timepoints)
+ {
+     uint64_t Ticks = ExecutionStats["t"sv].AsDateTimeTicks();
+     if (Ticks == 0)
+     {
+         return;
+     }
+
+     Timepoints["horde-execution-start"] = Ticks;
+
+     // Stage names in the order the deltas appear; extra entries are ignored.
+     static const char* const StageNames[] = {"horde-execution-download-ref",
+                                              "horde-execution-download-input",
+                                              "horde-execution-execute",
+                                              "horde-execution-upload-log",
+                                              "horde-execution-upload-output",
+                                              "horde-execution-upload-ref"};
+     int Stage = 0;
+     for (auto& Delta : ExecutionStats["s"sv].AsArrayView())
+     {
+         Ticks += Delta.AsInt32() * TimeSpan::TicksPerMillisecond;
+         if (Stage < 6)
+         {
+             Timepoints[StageNames[Stage]] = Ticks;
+         }
+         ++Stage;
+     }
+ }
+
+ // Downloads and unpacks the results of a completed remote task identified by
+ // ResultHash: fetches the "responses" ref, resolves every referenced blob,
+ // then materializes the output either as loose files (Simple) or as a
+ // CbPackage of verified compressed attachments (Asset).
+ [[nodiscard]] GetUpstreamApplyResult ProcessTaskStatus(const UpstreamApplyType ApplyType, const IoHash& ResultHash)
+ {
+     try
+     {
+         CloudCacheSession Session(m_StorageClient);
+
+         GetUpstreamApplyResult ApplyResult{};
+
+         IoHash StdOutHash;
+         IoHash StdErrHash;
+         IoHash OutputHash;
+
+         // Every blob fetched from storage, keyed by hash.
+         std::map<IoHash, IoBuffer> BinaryData;
+
+         {
+             // The "responses" ref holds the task result object.
+             CloudCacheResult ObjectRefResult =
+                 Session.GetRef(Session.Client().DefaultBlobStoreNamespace(), "responses"sv, ResultHash, ZenContentType::kCbObject);
+             Log().debug("Get ref {} Bytes={} Duration={}s Result={}",
+                         ResultHash,
+                         ObjectRefResult.Bytes,
+                         ObjectRefResult.ElapsedSeconds,
+                         ObjectRefResult.Success);
+             ApplyResult.Bytes += ObjectRefResult.Bytes;
+             ApplyResult.ElapsedSeconds += ObjectRefResult.ElapsedSeconds;
+             ApplyResult.Timepoints["zen-storage-get-ref"] = DateTime::NowTicks();
+
+             if (!ObjectRefResult.Success)
+             {
+                 ApplyResult.Error.Reason = "Failed to get result object data";
+                 return ApplyResult;
+             }
+
+             // "e" = process exit code, "so"/"se" = stdout/stderr blobs, "o" = output tree.
+             CbObject ResultObject = LoadCompactBinaryObject(ObjectRefResult.Response);
+             ApplyResult.Error.ErrorCode = ResultObject["e"sv].AsInt32();
+             StdOutHash = ResultObject["so"sv].AsBinaryAttachment();
+             StdErrHash = ResultObject["se"sv].AsBinaryAttachment();
+             OutputHash = ResultObject["o"sv].AsObjectAttachment();
+         }
+
+         {
+             std::set<IoHash> NeededData;
+             if (OutputHash != IoHash::Zero)
+             {
+                 GetObjectReferencesResult ObjectReferenceResult =
+                     Session.GetObjectReferences(Session.Client().DefaultBlobStoreNamespace(), OutputHash);
+                 Log().debug("Get object references {} References={} Bytes={} Duration={}s Result={}",
+                             ResultHash,
+                             ObjectReferenceResult.References.size(),
+                             ObjectReferenceResult.Bytes,
+                             ObjectReferenceResult.ElapsedSeconds,
+                             ObjectReferenceResult.Success);
+                 ApplyResult.Bytes += ObjectReferenceResult.Bytes;
+                 ApplyResult.ElapsedSeconds += ObjectReferenceResult.ElapsedSeconds;
+                 ApplyResult.Timepoints["zen-storage-get-object-references"] = DateTime::NowTicks();
+
+                 if (!ObjectReferenceResult.Success)
+                 {
+                     ApplyResult.Error.Reason = "Failed to get result object references";
+                     return ApplyResult;
+                 }
+
+                 NeededData = std::move(ObjectReferenceResult.References);
+             }
+
+             // Zero hashes are tolerated here and skipped in the fetch loop below.
+             NeededData.insert(OutputHash);
+             NeededData.insert(StdOutHash);
+             NeededData.insert(StdErrHash);
+
+             for (const auto& Hash : NeededData)
+             {
+                 if (Hash == IoHash::Zero)
+                 {
+                     continue;
+                 }
+                 CloudCacheResult BlobResult = Session.GetBlob(Session.Client().DefaultBlobStoreNamespace(), Hash);
+                 Log().debug("Get blob {} Bytes={} Duration={}s Result={}",
+                             Hash,
+                             BlobResult.Bytes,
+                             BlobResult.ElapsedSeconds,
+                             BlobResult.Success);
+                 ApplyResult.Bytes += BlobResult.Bytes;
+                 ApplyResult.ElapsedSeconds += BlobResult.ElapsedSeconds;
+                 if (!BlobResult.Success)
+                 {
+                     ApplyResult.Error.Reason = "Failed to get blob";
+                     return ApplyResult;
+                 }
+                 BinaryData[Hash] = std::move(BlobResult.Response);
+             }
+             ApplyResult.Timepoints["zen-storage-get-blobs"] = DateTime::NowTicks();
+         }
+
+         ApplyResult.StdOut = StdOutHash != IoHash::Zero
+                                  ? std::string((const char*)BinaryData[StdOutHash].GetData(), BinaryData[StdOutHash].GetSize())
+                                  : "";
+         ApplyResult.StdErr = StdErrHash != IoHash::Zero
+                                  ? std::string((const char*)BinaryData[StdErrHash].GetData(), BinaryData[StdErrHash].GetSize())
+                                  : "";
+
+         if (OutputHash == IoHash::Zero)
+         {
+             ApplyResult.Error.Reason = "Task completed with no output object";
+             return ApplyResult;
+         }
+
+         CbObject OutputObject = LoadCompactBinaryObject(BinaryData[OutputHash]);
+
+         switch (ApplyType)
+         {
+             case UpstreamApplyType::Simple:
+             {
+                 // Flatten the merkle output tree into path -> hash, then hand every
+                 // downloaded buffer to the caller (buffers are moved out of BinaryData).
+                 ResolveMerkleTreeDirectory(""sv, OutputHash, BinaryData, ApplyResult.OutputFiles);
+                 for (const auto& Pair : BinaryData)
+                 {
+                     ApplyResult.FileData[Pair.first] = std::move(BinaryData.at(Pair.first));
+                 }
+
+                 ApplyResult.Success = ApplyResult.Error.ErrorCode == 0;
+                 return ApplyResult;
+             }
+             break;
+             case UpstreamApplyType::Asset:
+             {
+                 if (ApplyResult.Error.ErrorCode != 0)
+                 {
+                     ApplyResult.Error.Reason = "Task completed with errors";
+                     return ApplyResult;
+                 }
+
+                 // Get build.output
+                 IoHash BuildOutputId;
+                 IoBuffer BuildOutput;
+                 for (auto& It : OutputObject["f"sv])
+                 {
+                     const CbObjectView FileObject = It.AsObjectView();
+                     if (FileObject["n"sv].AsString() == "Build.output"sv)
+                     {
+                         BuildOutputId = FileObject["h"sv].AsBinaryAttachment();
+                         BuildOutput = BinaryData[BuildOutputId];
+                         break;
+                     }
+                 }
+
+                 if (BuildOutput.GetSize() == 0)
+                 {
+                     ApplyResult.Error.Reason = "Build.output file not found in task results";
+                     return ApplyResult;
+                 }
+
+                 // Get Output directory node
+                 IoBuffer OutputDirectoryTree;
+                 for (auto& It : OutputObject["d"sv])
+                 {
+                     const CbObjectView DirectoryObject = It.AsObjectView();
+                     if (DirectoryObject["n"sv].AsString() == "Outputs"sv)
+                     {
+                         OutputDirectoryTree = BinaryData[DirectoryObject["h"sv].AsObjectAttachment()];
+                         break;
+                     }
+                 }
+
+                 if (OutputDirectoryTree.GetSize() == 0)
+                 {
+                     ApplyResult.Error.Reason = "Outputs directory not found in task results";
+                     return ApplyResult;
+                 }
+
+                 // load build.output as CbObject
+
+                 // Move Outputs from Horde to CbPackage
+
+                 // Map: uncompressed content hash -> hash of the compressed blob as stored remotely.
+                 std::unordered_map<IoHash, IoHash> CidToCompressedId;
+                 CbPackage OutputPackage;
+                 CbObject OutputDirectoryTreeObject = LoadCompactBinaryObject(OutputDirectoryTree);
+
+                 for (auto& It : OutputDirectoryTreeObject["f"sv])
+                 {
+                     CbObjectView FileObject = It.AsObjectView();
+                     // Name is the uncompressed hash
+                     IoHash DecompressedId = IoHash::FromHexString(FileObject["n"sv].AsString());
+                     // Hash is the compressed data hash, and how it is stored in Horde
+                     IoHash CompressedId = FileObject["h"sv].AsBinaryAttachment();
+
+                     if (!BinaryData.contains(CompressedId))
+                     {
+                         Log().warn("Object attachment chunk not retrieved from Horde {}", CompressedId);
+                         ApplyResult.Error.Reason = "Object attachment chunk not retrieved from Horde";
+                         return ApplyResult;
+                     }
+                     CidToCompressedId[DecompressedId] = CompressedId;
+                 }
+
+                 // Iterate attachments, verify all chunks exist, and add to CbPackage
+                 bool AnyErrors = false;
+                 CbObject BuildOutputObject = LoadCompactBinaryObject(BuildOutput);
+                 BuildOutputObject.IterateAttachments([&](CbFieldView Field) {
+                     const IoHash DecompressedId = Field.AsHash();
+                     if (!CidToCompressedId.contains(DecompressedId))
+                     {
+                         Log().warn("Attachment not found {}", DecompressedId);
+                         AnyErrors = true;
+                         return;
+                     }
+                     const IoHash& CompressedId = CidToCompressedId.at(DecompressedId);
+
+                     if (!BinaryData.contains(CompressedId))
+                     {
+                         Log().warn("Missing output {} compressed {} uncompressed", CompressedId, DecompressedId);
+                         AnyErrors = true;
+                         return;
+                     }
+
+                     // Validate the CompressedBuffer framing and that its raw-content
+                     // hash matches the expected uncompressed id before packaging.
+                     IoHash RawHash;
+                     uint64_t RawSize;
+                     CompressedBuffer AttachmentBuffer =
+                         CompressedBuffer::FromCompressed(SharedBuffer(BinaryData[CompressedId]), RawHash, RawSize);
+
+                     if (!AttachmentBuffer || RawHash != DecompressedId)
+                     {
+                         Log().warn(
+                             "Invalid output encountered (not valid CompressedBuffer format) {} compressed {} uncompressed",
+                             CompressedId,
+                             DecompressedId);
+                         AnyErrors = true;
+                         return;
+                     }
+
+                     ApplyResult.TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize();
+                     ApplyResult.TotalRawAttachmentBytes += RawSize;
+
+                     CbAttachment Attachment(AttachmentBuffer, DecompressedId);
+                     OutputPackage.AddAttachment(Attachment);
+                 });
+
+                 if (AnyErrors)
+                 {
+                     ApplyResult.Error.Reason = "Failed to get result object attachment data";
+                     return ApplyResult;
+                 }
+
+                 OutputPackage.SetObject(BuildOutputObject);
+                 ApplyResult.OutputPackage = std::move(OutputPackage);
+
+                 ApplyResult.Success = ApplyResult.Error.ErrorCode == 0;
+                 return ApplyResult;
+             }
+             break;
+         }
+
+         ApplyResult.Error.Reason = "Unknown apply type";
+         return ApplyResult;
+     }
+     catch (std::exception& Err)
+     {
+         return {.Error{.ErrorCode = -1, .Reason = Err.what()}};
+     }
+ }
+
+ // Translates an apply record (worker descriptor + action object) into the set
+ // of blobs/objects that must be uploaded (accumulated into Data), including
+ // the sandbox merkle tree, the requirements object, and the task object
+ // itself (Data.TaskId). Returns false if the descriptor is malformed or any
+ // required input chunk is missing from the local store.
+ [[nodiscard]] bool ProcessApplyKey(const UpstreamApplyRecord& ApplyRecord, UpstreamData& Data)
+ {
+     std::string ExecutablePath;
+     std::string WorkingDirectory;
+     std::vector<std::string> Arguments;
+     std::map<std::string, std::string> Environment;
+     std::set<std::filesystem::path> InputFiles;
+     std::set<std::string> Outputs;
+     std::map<std::filesystem::path, IoHash> InputFileHashes;
+
+     ExecutablePath = ApplyRecord.WorkerDescriptor["path"sv].AsString();
+     if (ExecutablePath.empty())
+     {
+         Log().warn("process apply upstream FAILED, '{}', path missing from worker descriptor",
+                    ApplyRecord.WorkerDescriptor.GetHash());
+         return false;
+     }
+
+     WorkingDirectory = ApplyRecord.WorkerDescriptor["workdir"sv].AsString();
+
+     // Worker executables and data files are CAS chunks referenced by hash.
+     for (auto& It : ApplyRecord.WorkerDescriptor["executables"sv])
+     {
+         CbObjectView FileEntry = It.AsObjectView();
+         if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds))
+         {
+             return false;
+         }
+     }
+
+     for (auto& It : ApplyRecord.WorkerDescriptor["files"sv])
+     {
+         CbObjectView FileEntry = It.AsObjectView();
+         if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds))
+         {
+             return false;
+         }
+     }
+
+     // Empty directories cannot be expressed in the file-based tree, so plant a
+     // placeholder file in each requested directory (and the working directory).
+     for (auto& It : ApplyRecord.WorkerDescriptor["dirs"sv])
+     {
+         std::string_view Directory = It.AsString();
+         std::string DummyFile = fmt::format("{}/.zen_empty_file", Directory);
+         InputFiles.insert(DummyFile);
+         Data.Blobs[EmptyBufferId] = EmptyBuffer;
+         InputFileHashes[DummyFile] = EmptyBufferId;
+     }
+
+     if (!WorkingDirectory.empty())
+     {
+         std::string DummyFile = fmt::format("{}/.zen_empty_file", WorkingDirectory);
+         InputFiles.insert(DummyFile);
+         Data.Blobs[EmptyBufferId] = EmptyBuffer;
+         InputFileHashes[DummyFile] = EmptyBufferId;
+     }
+
+     // Environment entries arrive as "NAME=value" strings.
+     for (auto& It : ApplyRecord.WorkerDescriptor["environment"sv])
+     {
+         std::string_view Env = It.AsString();
+         auto Index = Env.find('=');
+         if (Index == std::string_view::npos)
+         {
+             Log().warn("process apply upstream FAILED, environment '{}' malformed", Env);
+             return false;
+         }
+
+         Environment[std::string(Env.substr(0, Index))] = Env.substr(Index + 1);
+     }
+
+     switch (ApplyRecord.Type)
+     {
+         case UpstreamApplyType::Simple:
+         {
+             // Simple tasks take arguments and the output list verbatim from the descriptor.
+             for (auto& It : ApplyRecord.WorkerDescriptor["arguments"sv])
+             {
+                 Arguments.push_back(std::string(It.AsString()));
+             }
+
+             for (auto& It : ApplyRecord.WorkerDescriptor["outputs"sv])
+             {
+                 Outputs.insert(std::string(It.AsString()));
+             }
+         }
+         break;
+         case UpstreamApplyType::Asset:
+         {
+             // Asset builds ship the action object as Build.action plus every
+             // attachment it references as Inputs/<cid>.
+             static const std::filesystem::path BuildActionPath = "Build.action"sv;
+             static const std::filesystem::path InputPath = "Inputs"sv;
+             const IoHash ActionId = ApplyRecord.Action.GetHash();
+
+             Arguments.push_back("-Build=build.action");
+             Outputs.insert("Build.output");
+             Outputs.insert("Outputs");
+
+             InputFiles.insert(BuildActionPath);
+             InputFileHashes[BuildActionPath] = ActionId;
+             Data.Blobs[ActionId] = IoBufferBuilder::MakeCloneFromMemory(ApplyRecord.Action.GetBuffer().GetData(),
+                                                                         ApplyRecord.Action.GetBuffer().GetSize());
+
+             bool AnyErrors = false;
+             ApplyRecord.Action.IterateAttachments([&](CbFieldView Field) {
+                 const IoHash Cid = Field.AsHash();
+                 const std::filesystem::path FilePath = {InputPath / Cid.ToHexString()};
+
+                 if (!m_CidStore.ContainsChunk(Cid))
+                 {
+                     Log().warn("process apply upstream FAILED, input CID chunk '{}' missing", Cid);
+                     AnyErrors = true;
+                     return;
+                 }
+
+                 // Duplicate attachments map to the same path; only record once.
+                 if (InputFiles.contains(FilePath))
+                 {
+                     return;
+                 }
+
+                 InputFiles.insert(FilePath);
+                 InputFileHashes[FilePath] = Cid;
+                 Data.Cids.insert(Cid);
+             });
+
+             if (AnyErrors)
+             {
+                 return false;
+             }
+         }
+         break;
+     }
+
+     // Build the sandbox merkle tree from the flat input file list.
+     const UpstreamDirectory RootDirectory = BuildDirectoryTree(InputFiles);
+
+     CbObject Sandbox = BuildMerkleTreeDirectory(RootDirectory, InputFileHashes, Data.Cids, Data.Objects);
+     const IoHash SandboxHash = Sandbox.GetHash();
+     Data.Objects[SandboxHash] = std::move(Sandbox);
+
+     {
+         // Derive agent requirements (platform condition + resource needs) from the descriptor.
+         std::string_view HostPlatform = ApplyRecord.WorkerDescriptor["host"sv].AsString();
+         if (HostPlatform.empty())
+         {
+             Log().warn("process apply upstream FAILED, 'host' platform not provided");
+             return false;
+         }
+
+         int32_t LogicalCores = ApplyRecord.WorkerDescriptor["cores"sv].AsInt32();
+         int64_t Memory = ApplyRecord.WorkerDescriptor["memory"sv].AsInt64();
+         bool Exclusive = ApplyRecord.WorkerDescriptor["exclusive"sv].AsBool();
+
+         std::string Condition = fmt::format("Platform == '{}'", HostPlatform);
+         if (HostPlatform == "Win64")
+         {
+             // TODO
+             // Condition += " && Pool == 'Win-RemoteExec'";
+         }
+
+         std::map<std::string_view, int64_t> Resources;
+         if (LogicalCores > 0)
+         {
+             Resources["LogicalCores"sv] = LogicalCores;
+         }
+         if (Memory > 0)
+         {
+             // Bytes -> whole GiB, with a minimum of 1.
+             Resources["RAM"sv] = std::max(Memory / 1024LL / 1024LL / 1024LL, 1LL);
+         }
+
+         CbObject Requirements = BuildRequirements(Condition, Resources, Exclusive);
+         const IoHash RequirementsId = Requirements.GetHash();
+         Data.Objects[RequirementsId] = std::move(Requirements);
+         Data.RequirementsId = RequirementsId;
+     }
+
+     CbObject Task = BuildTask(ExecutablePath, Arguments, Environment, WorkingDirectory, SandboxHash, Data.RequirementsId, Outputs);
+
+     const IoHash TaskId = Task.GetHash();
+     Data.Objects[TaskId] = std::move(Task);
+     Data.TaskId = TaskId;
+
+     return true;
+ }
+
+ // Registers one descriptor file entry ("name"/"hash"/"size") as a task input.
+ // Fails if the referenced CAS chunk is absent from the local store or the
+ // filename has already been claimed by another entry.
+ [[nodiscard]] bool ProcessFileEntry(const CbObjectView& FileEntry,
+                                     std::set<std::filesystem::path>& InputFiles,
+                                     std::map<std::filesystem::path, IoHash>& InputFileHashes,
+                                     std::set<IoHash>& CasIds)
+ {
+     const std::filesystem::path Path = FileEntry["name"sv].AsString();
+     const IoHash Hash = FileEntry["hash"sv].AsHash();
+     const uint64_t FileSize = FileEntry["size"sv].AsUInt64();
+
+     // The payload must exist locally so it can be uploaded on demand.
+     if (!m_CidStore.ContainsChunk(Hash))
+     {
+         Log().warn("process apply upstream FAILED, worker CAS chunk '{}' missing", Hash);
+         return false;
+     }
+
+     // A failed insert means another entry already claimed this filename.
+     if (!InputFiles.insert(Path).second)
+     {
+         Log().warn("process apply upstream FAILED, worker CAS chunk '{}' size: {} duplicate filename {}", Hash, FileSize, Path);
+         return false;
+     }
+
+     InputFileHashes[Path] = Hash;
+     CasIds.insert(Hash);
+     return true;
+ }
+
+ // Converts a flat set of file paths into a nested UpstreamDirectory tree.
+ // AllDirectories caches a pointer to each created node for quick parent
+ // lookup; the pointers stay valid because std::map never relocates its
+ // stored values on insert.
+ [[nodiscard]] UpstreamDirectory BuildDirectoryTree(const std::set<std::filesystem::path>& InputFiles)
+ {
+     static const std::filesystem::path RootPath;
+     std::map<std::filesystem::path, UpstreamDirectory*> AllDirectories;
+     UpstreamDirectory RootDirectory = {.Path = RootPath};
+
+     AllDirectories[RootPath] = &RootDirectory;
+
+     // Build tree from flat list
+     for (const auto& Path : InputFiles)
+     {
+         if (Path.has_parent_path())
+         {
+             if (!AllDirectories.contains(Path.parent_path()))
+             {
+                 // Split the parent path into components, root-most component on top.
+                 std::stack<std::string> PathSplit;
+                 {
+                     std::filesystem::path ParentPath = Path.parent_path();
+                     PathSplit.push(ParentPath.filename().string());
+                     while (ParentPath.has_parent_path())
+                     {
+                         ParentPath = ParentPath.parent_path();
+                         PathSplit.push(ParentPath.filename().string());
+                     }
+                 }
+                 // Walk down from the root, creating any missing intermediate directories.
+                 UpstreamDirectory* ParentPtr = &RootDirectory;
+                 while (!PathSplit.empty())
+                 {
+                     if (!ParentPtr->Directories.contains(PathSplit.top()))
+                     {
+                         std::filesystem::path NewParentPath = {ParentPtr->Path / PathSplit.top()};
+                         ParentPtr->Directories[PathSplit.top()] = {.Path = NewParentPath};
+                         AllDirectories[NewParentPath] = &ParentPtr->Directories[PathSplit.top()];
+                     }
+                     ParentPtr = &ParentPtr->Directories[PathSplit.top()];
+                     PathSplit.pop();
+                 }
+             }
+
+             AllDirectories[Path.parent_path()]->Files.insert(Path.filename().string());
+         }
+         else
+         {
+             RootDirectory.Files.insert(Path.filename().string());
+         }
+     }
+
+     return RootDirectory;
+ }
+
+ // Recursively serializes a directory node into its merkle-tree object form:
+ // "f" = file entries (name, hash attachment, compressed flag), "d" =
+ // subdirectory entries (name, object-attachment hash of the child node).
+ // Child objects are stored into Objects keyed by their hash; the returned
+ // object is the (not yet stored) node for RootDirectory itself.
+ [[nodiscard]] CbObject BuildMerkleTreeDirectory(const UpstreamDirectory& RootDirectory,
+                                                 const std::map<std::filesystem::path, IoHash>& InputFileHashes,
+                                                 const std::set<IoHash>& Cids,
+                                                 std::map<IoHash, CbObject>& Objects)
+ {
+     CbObjectWriter DirectoryTreeWriter;
+
+     if (!RootDirectory.Files.empty())
+     {
+         DirectoryTreeWriter.BeginArray("f"sv);
+         for (const auto& File : RootDirectory.Files)
+         {
+             const std::filesystem::path FilePath = {RootDirectory.Path / File};
+             const IoHash& FileHash = InputFileHashes.at(FilePath);
+             // Membership in Cids marks the payload as a compressed blob.
+             const bool Compressed = Cids.contains(FileHash);
+             DirectoryTreeWriter.BeginObject();
+             DirectoryTreeWriter.AddString("n"sv, File);
+             DirectoryTreeWriter.AddBinaryAttachment("h"sv, FileHash);
+             DirectoryTreeWriter.AddBool("c"sv, Compressed);
+             DirectoryTreeWriter.EndObject();
+         }
+         DirectoryTreeWriter.EndArray();
+     }
+
+     if (!RootDirectory.Directories.empty())
+     {
+         DirectoryTreeWriter.BeginArray("d"sv);
+         for (const auto& Item : RootDirectory.Directories)
+         {
+             CbObject Directory = BuildMerkleTreeDirectory(Item.second, InputFileHashes, Cids, Objects);
+             const IoHash DirectoryHash = Directory.GetHash();
+             Objects[DirectoryHash] = std::move(Directory);
+
+             DirectoryTreeWriter.BeginObject();
+             DirectoryTreeWriter.AddString("n"sv, Item.first);
+             DirectoryTreeWriter.AddObjectAttachment("h"sv, DirectoryHash);
+             DirectoryTreeWriter.EndObject();
+         }
+         DirectoryTreeWriter.EndArray();
+     }
+
+     return DirectoryTreeWriter.Save();
+ }
+
+ // Walks a merkle directory object ("f" = files, "d" = subdirectories),
+ // recording the binary-attachment hash of every file keyed by its full path.
+ void ResolveMerkleTreeDirectory(const std::filesystem::path& ParentDirectory,
+                                 const IoHash& DirectoryHash,
+                                 const std::map<IoHash, IoBuffer>& Objects,
+                                 std::map<std::filesystem::path, IoHash>& OutputFiles)
+ {
+     const CbObject DirectoryNode = LoadCompactBinaryObject(Objects.at(DirectoryHash));
+
+     // Files at this level: "n" is the name, "h" the content hash.
+     for (auto& Entry : DirectoryNode["f"sv])
+     {
+         const CbObjectView File = Entry.AsObjectView();
+         OutputFiles[ParentDirectory / File["n"sv].AsString()] = File["h"sv].AsBinaryAttachment();
+     }
+
+     // Recurse into subdirectories; "h" names the child node object.
+     for (auto& Entry : DirectoryNode["d"sv])
+     {
+         const CbObjectView SubDirectory = Entry.AsObjectView();
+         const std::filesystem::path ChildPath = ParentDirectory / SubDirectory["n"sv].AsString();
+         ResolveMerkleTreeDirectory(ChildPath, SubDirectory["h"sv].AsObjectAttachment(), Objects, OutputFiles);
+     }
+ }
+
+ // Builds the compute "requirements" object: "c" = agent condition expression,
+ // "r" = optional [name, amount] resource pairs, "e" = exclusive-agent flag.
+ [[nodiscard]] CbObject BuildRequirements(const std::string_view Condition,
+                                          const std::map<std::string_view, int64_t>& Resources,
+                                          const bool Exclusive)
+ {
+     CbObjectWriter Req;
+     Req.AddString("c", Condition);
+     if (!Resources.empty())
+     {
+         Req.BeginArray("r");
+         for (const auto& [Name, Amount] : Resources)
+         {
+             Req.BeginArray();
+             Req.AddString(Name);
+             Req.AddInteger(Amount);
+             Req.EndArray();
+         }
+         Req.EndArray();
+     }
+     Req.AddBool("e", Exclusive);
+     return Req.Save();
+ }
+
+ // Serializes a compute task description: executable ("e"), arguments ("a"),
+ // environment pairs ("v"), working directory ("w"), sandbox and requirements
+ // object attachments ("s"/"r"), and declared output paths ("o").
+ [[nodiscard]] CbObject BuildTask(const std::string_view Executable,
+                                  const std::vector<std::string>& Arguments,
+                                  const std::map<std::string, std::string>& Environment,
+                                  const std::string_view WorkingDirectory,
+                                  const IoHash& SandboxHash,
+                                  const IoHash& RequirementsId,
+                                  const std::set<std::string>& Outputs)
+ {
+     CbObjectWriter Task;
+     Task.AddString("e"sv, Executable);
+
+     if (!Arguments.empty())
+     {
+         Task.BeginArray("a"sv);
+         for (const std::string& Argument : Arguments)
+         {
+             Task.AddString(Argument);
+         }
+         Task.EndArray();
+     }
+
+     if (!Environment.empty())
+     {
+         Task.BeginArray("v"sv);
+         for (const auto& [Name, Value] : Environment)
+         {
+             Task.BeginArray();
+             Task.AddString(Name);
+             Task.AddString(Value);
+             Task.EndArray();
+         }
+         Task.EndArray();
+     }
+
+     if (!WorkingDirectory.empty())
+     {
+         Task.AddString("w"sv, WorkingDirectory);
+     }
+
+     Task.AddObjectAttachment("s"sv, SandboxHash);
+     Task.AddObjectAttachment("r"sv, RequirementsId);
+
+     // Outputs
+     if (!Outputs.empty())
+     {
+         Task.BeginArray("o"sv);
+         for (const std::string& Output : Outputs)
+         {
+             Task.AddString(Output);
+         }
+         Task.EndArray();
+     }
+
+     return Task.Save();
+ }
+ };
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+// Factory for the Horde-backed upstream apply endpoint. Compute and storage
+// may target different services, each with its own client options and auth
+// configuration.
+std::unique_ptr<UpstreamApplyEndpoint>
+UpstreamApplyEndpoint::CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions,
+                                           const UpstreamAuthConfig& ComputeAuthConfig,
+                                           const CloudCacheClientOptions& StorageOptions,
+                                           const UpstreamAuthConfig& StorageAuthConfig,
+                                           CidStore& CidStore,
+                                           AuthMgr& Mgr)
+{
+    return std::make_unique<detail::HordeUpstreamApplyEndpoint>(ComputeOptions,
+                                                                ComputeAuthConfig,
+                                                                StorageOptions,
+                                                                StorageAuthConfig,
+                                                                CidStore,
+                                                                Mgr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/jupiter.cpp b/src/zenserver/upstream/jupiter.cpp
new file mode 100644
index 000000000..dbb185bec
--- /dev/null
+++ b/src/zenserver/upstream/jupiter.cpp
@@ -0,0 +1,965 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "jupiter.h"
+
+#include "diag/formatters.h"
+#include "diag/logging.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "Crypt32.lib")
+# pragma comment(lib, "Wldap32.lib")
+#endif
+
+#include <json11.hpp>
+
+using namespace std::literals;
+
+namespace zen {
+
+namespace detail {
+ // Pooled per-session state: a reusable cpr::Session (keeps the HTTP
+ // connection alive between requests) plus the access token cached for it.
+ // Instances are recycled via CloudCacheClient::Alloc/FreeSessionState.
+ struct CloudCacheSessionState
+ {
+ CloudCacheSessionState(CloudCacheClient& Client) : m_Client(Client) {}
+
+ // Returns the cached token; when RefreshToken is true, re-acquires it
+ // from the owning client first.
+ const CloudCacheAccessToken& GetAccessToken(bool RefreshToken)
+ {
+ if (RefreshToken)
+ {
+ m_AccessToken = m_Client.AcquireAccessToken();
+ }
+
+ return m_AccessToken;
+ }
+
+ cpr::Session& GetSession() { return m_Session; }
+
+ // Clears any leftover body/headers from the previous use and re-applies
+ // the client's timeouts, so a recycled session starts from a known state.
+ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout)
+ {
+ m_Session.SetBody({});
+ m_Session.SetHeader({});
+ m_Session.SetConnectTimeout(ConnectTimeout);
+ m_Session.SetTimeout(Timeout);
+ }
+
+ private:
+ // CloudCacheClient pokes m_AccessToken directly in AllocSessionState.
+ friend class zen::CloudCacheClient;
+
+ CloudCacheClient& m_Client;
+ CloudCacheAccessToken m_AccessToken;
+ cpr::Session m_Session;
+ };
+
+} // namespace detail
+
+// Checks a session state out of the client's pool for the lifetime of this
+// wrapper object.
+CloudCacheSession::CloudCacheSession(CloudCacheClient* CacheClient) : m_Log(CacheClient->Logger()), m_CacheClient(CacheClient)
+{
+ m_SessionState = m_CacheClient->AllocSessionState();
+}
+
+// Returns the session state to the pool (connection and token are reused).
+CloudCacheSession::~CloudCacheSession()
+{
+ m_CacheClient->FreeSessionState(m_SessionState);
+}
+
+// Forces a token refresh and reports success iff the refreshed token is valid.
+CloudCacheResult
+CloudCacheSession::Authenticate()
+{
+ const bool RefreshToken = true;
+ const CloudCacheAccessToken& AccessToken = GetAccessToken(RefreshToken);
+
+ return {.Success = AccessToken.IsValid()};
+}
+
+// GET /api/v1/refs/{ns}/{bucket}/{key}: fetch a ref, negotiating the body
+// format via Accept (compact binary vs raw bytes). Transport errors and 401s
+// are reported as failures; any other non-200 yields Success=false with an
+// empty buffer.
+CloudCacheResult
+CloudCacheSession::GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType)
+{
+ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", ContentType}});
+ // Clear any body left over from a previous request on this pooled session.
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ // Clone the response text so the buffer outlives the cpr::Response.
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+// GET /api/v1/blobs/{ns}/{hash}: fetch a raw (uncompressed) blob. Only clones
+// the payload when the body is non-empty.
+CloudCacheResult
+CloudCacheSession::GetBlob(std::string_view Namespace, const IoHash& Key)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/octet-stream"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer =
+ Success && Response.text.size() > 0 ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+// GET /api/v1/compressed-blobs/{ns}/{hash}: fetch a compressed blob
+// (application/x-ue-comp).
+// NOTE(review): unlike GetBlob, this clones on any 200 without checking for
+// an empty body — confirm that is intentional.
+CloudCacheResult
+CloudCacheSession::GetCompressedBlob(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-comp"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+// GET a ref with Accept: application/x-jupiter-inline, asking the server to
+// inline the payload. The payload's hash, when present and well-formed, is
+// parsed out of the X-Jupiter-InlinePayloadHash response header into
+// OutPayloadHash (left untouched otherwise).
+CloudCacheResult
+CloudCacheSession::GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash)
+{
+ ZEN_TRACE_CPU("HordeClient::GetInlineBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-jupiter-inline"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+ if (auto It = Response.header.find("X-Jupiter-InlinePayloadHash"); It != Response.header.end())
+ {
+ const std::string& PayloadHashHeader = It->second;
+ // Only accept a header of exactly the expected hex length; anything
+ // else is silently ignored.
+ if (PayloadHashHeader.length() == IoHash::StringLength)
+ {
+ OutPayloadHash = IoHash::FromHexString(PayloadHashHeader);
+ }
+ }
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+// GET /api/v1/objects/{ns}/{hash}: fetch a compact-binary object
+// (application/x-ue-cb).
+CloudCacheResult
+CloudCacheSession::GetObject(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetObject");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+// PUT a ref. The ref payload's hash is sent in X-Jupiter-IoHash so the server
+// can validate the body. On 200/201 the JSON response's "needs" array lists
+// attachment hashes the server is still missing; callers upload those and
+// then FinalizeRef.
+PutRefResult
+CloudCacheSession::PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType)
+{
+ ZEN_TRACE_CPU("HordeClient::PutRef");
+
+ IoHash Hash = IoHash::HashBuffer(Ref.Data(), Ref.Size());
+
+ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(
+ cpr::Header{{"Authorization", AccessToken.Value}, {"X-Jupiter-IoHash", Hash.ToHexString()}, {"Content-Type", ContentType}});
+ Session.SetBody(cpr::Body{(const char*)Ref.Data(), Ref.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ PutRefResult Result;
+ Result.ErrorCode = static_cast<int32_t>(Response.error.code);
+ Result.Reason = std::move(Response.error.message);
+ return Result;
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ PutRefResult Result;
+ Result.ErrorCode = 401;
+ Result.Reason = "Invalid access token"sv;
+ return Result;
+ }
+
+ PutRefResult Result;
+ Result.Success = (Response.status_code == 200 || Response.status_code == 201);
+ Result.Bytes = Response.uploaded_bytes;
+ Result.ElapsedSeconds = Response.elapsed;
+
+ if (Result.Success)
+ {
+ // Parse the "needs" array; a malformed JSON body leaves Needs empty.
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ if (JsonError.empty())
+ {
+ json11::Json::array Needs = Json["needs"].array_items();
+ for (const auto& Need : Needs)
+ {
+ Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value()));
+ }
+ }
+ }
+
+ return Result;
+}
+
+// POST .../refs/{ns}/{bucket}/{key}/finalize/{hash}: finalize a previously
+// PUT ref once its attachments have been uploaded. As with PutRef, the
+// response's "needs" array reports hashes the server still lacks.
+FinalizeRefResult
+CloudCacheSession::FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHash)
+{
+ ZEN_TRACE_CPU("HordeClient::FinalizeRef");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString() << "/finalize/"
+ << RefHash.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value},
+ {"X-Jupiter-IoHash", RefHash.ToHexString()},
+ {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ FinalizeRefResult Result;
+ Result.ErrorCode = static_cast<int32_t>(Response.error.code);
+ Result.Reason = std::move(Response.error.message);
+ return Result;
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ FinalizeRefResult Result;
+ Result.ErrorCode = 401;
+ Result.Reason = "Invalid access token"sv;
+ return Result;
+ }
+
+ FinalizeRefResult Result;
+ Result.Success = (Response.status_code == 200 || Response.status_code == 201);
+ Result.Bytes = Response.uploaded_bytes;
+ Result.ElapsedSeconds = Response.elapsed;
+
+ if (Result.Success)
+ {
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ if (JsonError.empty())
+ {
+ json11::Json::array Needs = Json["needs"].array_items();
+ for (const auto& Need : Needs)
+ {
+ Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value()));
+ }
+ }
+ }
+
+ return Result;
+}
+
+// PUT /api/v1/blobs/{ns}/{hash}: upload a raw blob; 200 or 201 both count as
+// success.
+CloudCacheResult
+CloudCacheSession::PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob)
+{
+ ZEN_TRACE_CPU("HordeClient::PutBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/octet-stream"}});
+ Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+// PUT /api/v1/compressed-blobs/{ns}/{hash}: upload an already-compressed blob
+// (application/x-ue-comp) from a single contiguous buffer.
+CloudCacheResult
+CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob)
+{
+ ZEN_TRACE_CPU("HordeClient::PutCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}});
+ Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+// Streaming overload: uploads a CompositeBuffer without flattening it, feeding
+// cpr chunks through a read callback with a known total length.
+CloudCacheResult
+CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Payload)
+{
+ ZEN_TRACE_CPU("HordeClient::PutCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}});
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ // Callback fills cpr's buffer with up to `size` bytes per invocation.
+ // NOTE(review): assumes CopyTo advances BufferIt so successive calls read
+ // consecutive ranges — confirm against CompositeBuffer's contract. The
+ // callback must not outlive Payload/BufferIt (captured by reference).
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+// PUT /api/v1/objects/{ns}/{hash}: upload a compact-binary object.
+CloudCacheResult
+CloudCacheSession::PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object)
+{
+ ZEN_TRACE_CPU("HordeClient::PutObject");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{(const char*)Object.Data(), Object.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+// HEAD on a ref URI; Success mirrors a 200 status, no body is transferred.
+CloudCacheResult
+CloudCacheSession::RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::RefExists");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Head();
+ ZEN_DEBUG("HEAD {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+// GET .../objects/{ns}/{hash}/references: returns the set of hashes the
+// stored object references, decoded from a compact-binary "references" array.
+GetObjectReferencesResult
+CloudCacheSession::GetObjectReferences(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetObjectReferences");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString() << "/references";
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}};
+ }
+
+ GetObjectReferencesResult Result{
+ CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}};
+
+ if (Result.Success)
+ {
+ // Wrap (not copy) the response text; the CbObject is only read within
+ // this scope, while Response is still alive.
+ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ const CbObject ReferencesResponse = LoadCompactBinaryObject(Buffer);
+ for (auto& Item : ReferencesResponse["references"sv])
+ {
+ Result.References.insert(Item.AsHash());
+ }
+ }
+
+ return Result;
+}
+
+// Thin forwarders: single-key and batch existence checks for each cache type
+// all delegate to CacheTypeExists with the matching URL segment.
+CloudCacheResult
+CloudCacheSession::BlobExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "blobs"sv, Key);
+}
+
+CloudCacheResult
+CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "compressed-blobs"sv, Key);
+}
+
+CloudCacheResult
+CloudCacheSession::ObjectExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "objects"sv, Key);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "blobs"sv, Keys);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "compressed-blobs"sv, Keys);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "objects"sv, Keys);
+}
+
+// POST /api/v1/compute/{cluster}: submit a compact-binary batch of compute
+// tasks to the client's configured compute cluster.
+CloudCacheResult
+CloudCacheSession::PostComputeTasks(IoBuffer TasksData)
+{
+ ZEN_TRACE_CPU("HordeClient::PostComputeTasks");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{(const char*)TasksData.Data(), TasksData.Size()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+// POST .../compute/{cluster}/updates/{channel}?wait=N: long-poll for compute
+// task updates on a channel, blocking server-side up to WaitSeconds.
+CloudCacheResult
+CloudCacheSession::GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds)
+{
+ ZEN_TRACE_CPU("HordeClient::GetComputeUpdates");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster() << "/updates/" << ChannelId
+ << "?wait=" << WaitSeconds;
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+// Stub: builds a partial URI but performs no request and always returns an
+// empty set. BucketId/ChunkHashes are intentionally unused for now.
+// TODO: implement or remove before relying on this from callers.
+std::vector<IoHash>
+CloudCacheSession::Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl();
+ Uri << "/api/v1/s/" << Namespace;
+
+ ZEN_UNUSED(BucketId, ChunkHashes);
+
+ return {};
+}
+
+// Accessors forwarding to the pooled session state.
+cpr::Session&
+CloudCacheSession::GetSession()
+{
+ return m_SessionState->GetSession();
+}
+
+// NOTE(review): returns by value, so every request copies the token string;
+// the state's GetAccessToken already returns a const reference.
+CloudCacheAccessToken
+CloudCacheSession::GetAccessToken(bool RefreshToken)
+{
+ return m_SessionState->GetAccessToken(RefreshToken);
+}
+
+// Treats any status other than 401 as "token accepted"; there is no automatic
+// refresh-and-retry on 401 here — callers receive the failure.
+bool
+CloudCacheSession::VerifyAccessToken(long StatusCode)
+{
+ return StatusCode != 401;
+}
+
+// Shared HEAD-based existence check: TypeId is the URL segment
+// ("blobs", "compressed-blobs" or "objects").
+CloudCacheResult
+CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::CacheTypeExists");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Head();
+ ZEN_DEBUG("HEAD {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+// Batch existence check: POSTs a JSON array of hex hashes to
+// /api/v1/{type}/{ns}/exist and decodes the compact-binary "needs" array of
+// hashes the server does NOT have.
+CloudCacheExistsResult
+CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys)
+{
+ ZEN_TRACE_CPU("HordeClient::CacheTypeExists");
+
+ ExtendableStringBuilder<256> Body;
+ Body << "[";
+ for (const auto& Key : Keys)
+ {
+ // Size()==1 means only the opening '[' has been written, i.e. this is
+ // the first element and needs no leading comma.
+ Body << (Body.Size() != 1 ? ",\"" : "\"") << Key.ToHexString() << "\"";
+ }
+ Body << "]";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/exist";
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(
+ cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}, {"Content-Type", "application/json"}});
+ Session.SetOption(cpr::Body(Body.ToString()));
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}};
+ }
+
+ CloudCacheExistsResult Result{
+ CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}};
+
+ if (Result.Success)
+ {
+ // Wrap, not copy: the decoded view is consumed before Response dies.
+ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ const CbObject ExistsResponse = LoadCompactBinaryObject(Buffer);
+ for (auto& Item : ExistsResponse["needs"sv])
+ {
+ Result.Needs.insert(Item.AsHash());
+ }
+ }
+
+ return Result;
+}
+
+/**
+ * An access token provider that holds a token that will never change.
+ */
+class StaticTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ StaticTokenProvider(CloudCacheAccessToken Token) : m_Token(std::move(Token)) {}
+
+ virtual ~StaticTokenProvider() = default;
+
+ // Always hands back the same token (a copy of the stored one).
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Token; }
+
+private:
+ CloudCacheAccessToken m_Token;
+};
+
+// Factory: wraps a fixed token in a provider.
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromStaticToken(CloudCacheAccessToken Token)
+{
+ return std::make_unique<StaticTokenProvider>(std::move(Token));
+}
+
+// Token provider implementing the OAuth2 client-credentials grant: each
+// AcquireAccessToken() POSTs to the token endpoint and converts the JSON
+// response into a bearer token with an expiry.
+// NOTE(review): the client secret is held in memory as a plain std::string
+// and the scope is hard-coded to "cache_access".
+class OAuthClientCredentialsTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ OAuthClientCredentialsTokenProvider(const CloudCacheTokenProvider::OAuthClientCredentialsParams& Params)
+ {
+ m_Url = std::string(Params.Url);
+ m_ClientId = std::string(Params.ClientId);
+ m_ClientSecret = std::string(Params.ClientSecret);
+ }
+
+ virtual ~OAuthClientCredentialsTokenProvider() = default;
+
+ // Returns a default-constructed (invalid) token on any transport, HTTP or
+ // JSON-parse failure; callers detect this via CloudCacheAccessToken::IsValid.
+ virtual CloudCacheAccessToken AcquireAccessToken() final override
+ {
+ using namespace std::chrono;
+
+ std::string Body =
+ fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}", m_ClientId, m_ClientSecret);
+
+ cpr::Response Response =
+ cpr::Post(cpr::Url{m_Url}, cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}, cpr::Body{std::move(Body)});
+
+ if (Response.error || Response.status_code != 200)
+ {
+ return {};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {};
+ }
+
+ // Convert the relative "expires_in" into an absolute expiry timestamp.
+ std::string Token = Json["access_token"].string_value();
+ int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value());
+ CloudCacheAccessToken::TimePoint ExpireTime = CloudCacheAccessToken::Clock::now() + seconds(ExpiresInSeconds);
+
+ return {.Value = fmt::format("Bearer {}", Token), .ExpireTime = ExpireTime};
+ }
+
+private:
+ std::string m_Url;
+ std::string m_ClientId;
+ std::string m_ClientSecret;
+};
+
+// Factory: provider that fetches tokens via the client-credentials flow.
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params)
+{
+ return std::make_unique<OAuthClientCredentialsTokenProvider>(Params);
+}
+
+// Token provider that delegates acquisition to a user-supplied callback.
+// The callback must remain safe to invoke for the provider's lifetime.
+class CallbackTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {}
+
+ virtual ~CallbackTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); }
+
+private:
+ std::function<CloudCacheAccessToken()> m_Callback;
+};
+
+// Factory: wraps an arbitrary token-producing callback.
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback)
+{
+ return std::make_unique<CallbackTokenProvider>(std::move(Callback));
+}
+
+// Captures service endpoints, default namespaces, compute cluster and
+// timeouts from Options; takes ownership of the (required) token provider.
+CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider)
+: m_Log(zen::logging::Get("jupiter"))
+, m_ServiceUrl(Options.ServiceUrl)
+, m_DefaultDdcNamespace(Options.DdcNamespace)
+, m_DefaultBlobStoreNamespace(Options.BlobStoreNamespace)
+, m_ComputeCluster(Options.ComputeCluster)
+, m_ConnectTimeout(Options.ConnectTimeout)
+, m_Timeout(Options.Timeout)
+, m_TokenProvider(std::move(TokenProvider))
+{
+ ZEN_ASSERT(m_TokenProvider.get() != nullptr);
+}
+
+// Destroys all session states currently sitting in the pool.
+// NOTE(review): states checked out by live CloudCacheSession objects are not
+// in the cache and are not deleted here — presumably all sessions must be
+// destroyed before the client; confirm that lifetime contract.
+CloudCacheClient::~CloudCacheClient()
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+
+ for (auto State : m_SessionStateCache)
+ {
+ delete State;
+ }
+}
+
+// Delegates token acquisition to the configured provider (may block on HTTP
+// for the OAuth provider).
+CloudCacheAccessToken
+CloudCacheClient::AcquireAccessToken()
+{
+ ZEN_TRACE_CPU("HordeClient::AcquireAccessToken");
+
+ return m_TokenProvider->AcquireAccessToken();
+}
+
+// Checks a session state out of the LIFO pool (front pop / front push keeps
+// the most recently used connection warm), creating a fresh one if the pool
+// is empty. The state is reset to the client's timeouts, and its token is
+// refreshed when missing or expired — deliberately outside the lock, since
+// acquisition may perform a blocking HTTP request.
+detail::CloudCacheSessionState*
+CloudCacheClient::AllocSessionState()
+{
+ detail::CloudCacheSessionState* State = nullptr;
+
+ bool IsTokenValid = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+
+ if (m_SessionStateCache.empty() == false)
+ {
+ State = m_SessionStateCache.front();
+ IsTokenValid = State->m_AccessToken.IsValid();
+
+ m_SessionStateCache.pop_front();
+ }
+ }
+
+ if (State == nullptr)
+ {
+ State = new detail::CloudCacheSessionState(*this);
+ }
+
+ State->Reset(m_ConnectTimeout, m_Timeout);
+
+ if (IsTokenValid == false)
+ {
+ State->m_AccessToken = m_TokenProvider->AcquireAccessToken();
+ }
+
+ return State;
+}
+
+// Returns a state to the front of the pool for reuse; ownership transfers
+// back to the client (deleted in ~CloudCacheClient if never reused).
+void
+CloudCacheClient::FreeSessionState(detail::CloudCacheSessionState* State)
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+ m_SessionStateCache.push_front(State);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/jupiter.h b/src/zenserver/upstream/jupiter.h
new file mode 100644
index 000000000..99e5c530f
--- /dev/null
+++ b/src/zenserver/upstream/jupiter.h
@@ -0,0 +1,217 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
#include <zencore/iohash.h>
#include <zencore/logging.h>
#include <zencore/refcount.h>
#include <zencore/thread.h>
#include <zenhttp/httpserver.h>

#include <atomic>
#include <chrono>
#include <functional>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <string_view>
#include <vector>
+
+struct ZenCacheValue;
+
+namespace cpr {
+class Session;
+}
+
+namespace zen {
+namespace detail {
+ struct CloudCacheSessionState;
+}
+
+class CbObjectView;
+class CloudCacheClient;
+class IoBuffer;
+struct IoHash;
+
+/**
+ * Cached access token, for use with `Authorization:` header
+ */
struct CloudCacheAccessToken
{
    using Clock = std::chrono::system_clock;
    using TimePoint = Clock::time_point;

    // A token this close to expiry is treated as already invalid so that
    // in-flight requests do not race the server-side expiration.
    static constexpr int64_t ExpireMarginInSeconds = 30;

    std::string Value;
    TimePoint ExpireTime;

    // True when a token string is present and more than the safety margin
    // remains before ExpireTime.
    bool IsValid() const
    {
        if (Value.empty())
        {
            return false;
        }
        const int64_t RemainingSeconds =
            std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count();
        return RemainingSeconds > ExpireMarginInSeconds;
    }
};
+
// Common outcome of a single Jupiter HTTP operation: payload (if any),
// transfer statistics and error information.
struct CloudCacheResult
{
    IoBuffer Response;
    int64_t Bytes{};          // bytes transferred for this operation
    double ElapsedSeconds{};  // wall time spent on the operation
    int32_t ErrorCode{};      // 0 on success / no transport error
    std::string Reason;       // human-readable error description
    bool Success = false;
};

// Result of PutRef; `Needs` lists attachment hashes the server still wants.
struct PutRefResult : CloudCacheResult
{
    std::vector<IoHash> Needs;
};

// Result of FinalizeRef; `Needs` lists hashes still missing server-side.
struct FinalizeRefResult : CloudCacheResult
{
    std::vector<IoHash> Needs;
};

// Result of a batched existence query; `Needs` holds the keys NOT present.
struct CloudCacheExistsResult : CloudCacheResult
{
    std::set<IoHash> Needs;
};

// Result of GetObjectReferences; the set of hashes referenced by an object.
struct GetObjectReferencesResult : CloudCacheResult
{
    std::set<IoHash> References;
};
+
+/**
+ * Context for performing Jupiter operations
+ *
+ * Maintains an HTTP connection so that subsequent operations don't need to go
+ * through the whole connection setup process
+ *
+ */
+class CloudCacheSession
+{
+public:
+ CloudCacheSession(CloudCacheClient* CacheClient);
+ ~CloudCacheSession();
+
+ CloudCacheResult Authenticate();
+ CloudCacheResult GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType);
+ CloudCacheResult GetBlob(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetCompressedBlob(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetObject(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash);
+
+ PutRefResult PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType);
+ CloudCacheResult PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob);
+ CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob);
+ CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Blob);
+ CloudCacheResult PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object);
+
+ FinalizeRefResult FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHah);
+
+ CloudCacheResult RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key);
+
+ GetObjectReferencesResult GetObjectReferences(std::string_view Namespace, const IoHash& Key);
+
+ CloudCacheResult BlobExists(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult CompressedBlobExists(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult ObjectExists(std::string_view Namespace, const IoHash& Key);
+
+ CloudCacheExistsResult BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+ CloudCacheExistsResult CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+ CloudCacheExistsResult ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+
+ CloudCacheResult PostComputeTasks(IoBuffer TasksData);
+ CloudCacheResult GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds = 0);
+
+ std::vector<IoHash> Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes);
+
+ CloudCacheClient& Client() { return *m_CacheClient; };
+
+private:
+ inline spdlog::logger& Log() { return m_Log; }
+ cpr::Session& GetSession();
+ CloudCacheAccessToken GetAccessToken(bool RefreshToken = false);
+ bool VerifyAccessToken(long StatusCode);
+
+ CloudCacheResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key);
+
+ CloudCacheExistsResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys);
+
+ spdlog::logger& m_Log;
+ RefPtr<CloudCacheClient> m_CacheClient;
+ detail::CloudCacheSessionState* m_SessionState;
+};
+
+/**
+ * Access token provider interface
+ */
class CloudCacheTokenProvider
{
public:
    virtual ~CloudCacheTokenProvider() = default;

    // Returns a token for the `Authorization:` header; implementations
    // decide how it is obtained (static, OAuth, callback).
    virtual CloudCacheAccessToken AcquireAccessToken() = 0;

    // Provider that always returns the given token (no refresh).
    static std::unique_ptr<CloudCacheTokenProvider> CreateFromStaticToken(CloudCacheAccessToken Token);

    struct OAuthClientCredentialsParams
    {
        std::string_view Url;           // token endpoint URL
        std::string_view ClientId;
        std::string_view ClientSecret;
    };

    // Provider that performs the OAuth client-credentials flow on demand.
    static std::unique_ptr<CloudCacheTokenProvider> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params);

    // Provider that defers token acquisition to an arbitrary callback.
    static std::unique_ptr<CloudCacheTokenProvider> CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback);
};
+
// Configuration for a CloudCacheClient. Views must outlive construction of
// the client (the client copies them into owned strings).
struct CloudCacheClientOptions
{
    std::string_view Name;
    std::string_view ServiceUrl;
    std::string_view DdcNamespace;
    std::string_view BlobStoreNamespace;
    std::string_view ComputeCluster;
    std::chrono::milliseconds ConnectTimeout{5000};
    // NOTE(review): defaults to 0 — presumably "no explicit timeout";
    // confirm against the session's Reset() handling.
    std::chrono::milliseconds Timeout{};
};
+
+/**
+ * Jupiter upstream cache client
+ */
class CloudCacheClient : public RefCounted
{
public:
    CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider);
    ~CloudCacheClient();

    // Fetches a token from the configured provider (may block on I/O).
    CloudCacheAccessToken AcquireAccessToken();
    std::string_view DefaultDdcNamespace() const { return m_DefaultDdcNamespace; }
    std::string_view DefaultBlobStoreNamespace() const { return m_DefaultBlobStoreNamespace; }
    std::string_view ComputeCluster() const { return m_ComputeCluster; }
    std::string_view ServiceUrl() const { return m_ServiceUrl; }

    spdlog::logger& Logger() { return m_Log; }

private:
    spdlog::logger& m_Log;
    std::string m_ServiceUrl;
    std::string m_DefaultDdcNamespace;
    std::string m_DefaultBlobStoreNamespace;
    std::string m_ComputeCluster;
    std::chrono::milliseconds m_ConnectTimeout{};
    std::chrono::milliseconds m_Timeout{};
    std::unique_ptr<CloudCacheTokenProvider> m_TokenProvider;

    // Pool of reusable session states (raw owning pointers; deleted in the
    // destructor). Accessed only under m_SessionStateLock.
    RwLock m_SessionStateLock;
    std::list<detail::CloudCacheSessionState*> m_SessionStateCache;

    detail::CloudCacheSessionState* AllocSessionState();
    void FreeSessionState(detail::CloudCacheSessionState*);

    friend class CloudCacheSession;
};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstream.h b/src/zenserver/upstream/upstream.h
new file mode 100644
index 000000000..a57301206
--- /dev/null
+++ b/src/zenserver/upstream/upstream.h
@@ -0,0 +1,8 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <upstream/jupiter.h>
+#include <upstream/upstreamcache.h>
+#include <upstream/upstreamservice.h>
+#include <upstream/zen.h>
diff --git a/src/zenserver/upstream/upstreamapply.cpp b/src/zenserver/upstream/upstreamapply.cpp
new file mode 100644
index 000000000..c719b225d
--- /dev/null
+++ b/src/zenserver/upstream/upstreamapply.cpp
@@ -0,0 +1,459 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamapply.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/fmtutils.h>
+# include <zencore/stream.h>
+# include <zencore/timer.h>
+# include <zencore/workthreadpool.h>
+
+# include <zenstore/cidstore.h>
+
+# include "diag/logging.h"
+
# include <fmt/format.h>

# include <atomic>
# include <condition_variable>
# include <mutex>
# include <thread>
+
+namespace zen {
+
+using namespace std::literals;
+
// Aggregates per-endpoint counters from post/update results.
struct UpstreamApplyStats
{
    // NOTE(review): MaxSampleCount and m_Enabled are not referenced in this
    // translation unit's visible code — confirm whether Add() should be
    // gated on m_Enabled or whether these are used elsewhere.
    static constexpr uint64_t MaxSampleCount = 1000ull;

    UpstreamApplyStats(bool Enabled) : m_Enabled(Enabled) {}

    // Record the outcome of posting an apply request to an endpoint.
    void Add(UpstreamApplyEndpoint& Endpoint, const PostUpstreamApplyResult& Result)
    {
        UpstreamApplyEndpointStats& Stats = Endpoint.Stats();

        if (Result.Error)
        {
            Stats.ErrorCount.Increment(1);
        }
        else if (Result.Success)
        {
            Stats.PostCount.Increment(1);
            // Byte counters are kept in whole MiB (sub-MiB transfers add 0).
            Stats.UpBytes.Increment(Result.Bytes / 1024 / 1024);
        }
    }

    // Record the outcome of an updates poll, including completion counts.
    void Add(UpstreamApplyEndpoint& Endpoint, const GetUpstreamApplyUpdatesResult& Result)
    {
        UpstreamApplyEndpointStats& Stats = Endpoint.Stats();

        if (Result.Error)
        {
            Stats.ErrorCount.Increment(1);
        }
        else if (Result.Success)
        {
            Stats.UpdateCount.Increment(1);
            Stats.DownBytes.Increment(Result.Bytes / 1024 / 1024);
            if (!Result.Completed.empty())
            {
                // Completed is worker-id -> (action-id -> result); count all
                // finished actions across every worker.
                uint64_t Completed = 0;
                for (auto& It : Result.Completed)
                {
                    Completed += It.second.size();
                }
                Stats.CompleteCount.Increment(Completed);
            }
        }
    }

    bool m_Enabled;
};
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamApplyImpl final : public UpstreamApply
+{
+public:
    // Builds the worker pools (sized from Options) but does not start the
    // background threads; Initialize() does that.
    UpstreamApplyImpl(const UpstreamApplyOptions& Options, CidStore& CidStore)
    : m_Log(logging::Get("upstream-apply"))
    , m_Options(Options)
    , m_CidStore(CidStore)
    , m_Stats(Options.StatsEnabled)
    , m_UpstreamAsyncWorkPool(Options.UpstreamThreadCount)
    , m_DownstreamAsyncWorkPool(Options.DownstreamThreadCount)
    {
    }

    // Stops background threads and releases endpoints (idempotent).
    virtual ~UpstreamApplyImpl() { Shutdown(); }
+
    // Initializes every registered endpoint (failures are logged but do not
    // abort), then starts the update-poll and health-monitor threads when at
    // least one endpoint exists. Returns whether the service is running.
    virtual bool Initialize() override
    {
        for (auto& Endpoint : m_Endpoints)
        {
            const UpstreamEndpointHealth Health = Endpoint->Initialize();
            if (Health.Ok)
            {
                Log().info("initialize endpoint '{}' OK", Endpoint->DisplayName());
            }
            else
            {
                Log().warn("initialize endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason);
            }
        }

        // Running state depends only on endpoints being registered, not on
        // their individual health (the monitor thread retries them).
        m_RunState.IsRunning = !m_Endpoints.empty();

        if (m_RunState.IsRunning)
        {
            m_ShutdownEvent.Reset();

            m_UpstreamUpdatesThread = std::thread(&UpstreamApplyImpl::ProcessUpstreamUpdates, this);

            m_EndpointMonitorThread = std::thread(&UpstreamApplyImpl::MonitorEndpoints, this);
        }

        return m_RunState.IsRunning;
    }
+
+ virtual bool IsHealthy() const override
+ {
+ if (m_RunState.IsRunning)
+ {
+ for (const auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
    // Takes ownership of an endpoint. Expected to be called before
    // Initialize(); endpoints are not initialized here.
    virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) override
    {
        m_Endpoints.emplace_back(std::move(Endpoint));
    }
+
    // Queue an apply record for asynchronous upstream execution. Idempotent
    // per (worker, action): re-enqueueing an in-flight task just returns its
    // id. Returns a default (Success=false) result when not running.
    virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) override
    {
        if (m_RunState.IsRunning)
        {
            const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash();
            const IoHash ActionId = ApplyRecord.Action.GetHash();
            // Per-task expiry; the worker descriptor may override the
            // 300-second default. 0 means "never expire".
            const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300);

            {
                std::scoped_lock Lock(m_ApplyTasksMutex);
                if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
                {
                    // Already in progress
                    return {.ApplyId = ActionId, .Success = true};
                }

                std::chrono::steady_clock::time_point ExpireTime =
                    TimeoutSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds)
                                       : std::chrono::steady_clock::time_point::max();

                m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)};
            }

            // Dispatch outside the lock; the worker pool owns the record.
            ApplyRecord.Timepoints["zen-queue-added"] = DateTime::NowTicks();
            m_UpstreamAsyncWorkPool.ScheduleWork(
                [this, ApplyRecord = std::move(ApplyRecord)]() { ProcessApplyRecord(std::move(ApplyRecord)); });

            return {.ApplyId = ActionId, .Success = true};
        }

        return {};
    }
+
    // Look up the status of a queued/executing/completed task. Returns a
    // copy under the task lock; Success=false when unknown or not running.
    virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) override
    {
        if (m_RunState.IsRunning)
        {
            std::scoped_lock Lock(m_ApplyTasksMutex);
            if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
            {
                return {.Status = *Status, .Success = true};
            }
        }

        return {};
    }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "upstream_worker_threads" << m_Options.UpstreamThreadCount;
+ Status << "upstream_queue_count" << m_UpstreamAsyncWorkPool.PendingWork();
+ Status << "downstream_worker_threads" << m_Options.DownstreamThreadCount;
+ Status << "downstream_queue_count" << m_DownstreamAsyncWorkPool.PendingWork();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ Status.BeginObject();
+ Status << "name" << Ep->DisplayName();
+ Status << "health" << (Ep->IsHealthy() ? "ok"sv : "inactive"sv);
+
+ UpstreamApplyEndpointStats& Stats = Ep->Stats();
+ const uint64_t PostCount = Stats.PostCount.Value();
+ const uint64_t CompleteCount = Stats.CompleteCount.Value();
+ // const uint64_t UpdateCount = Stats.UpdateCount;
+ const double CompleteRate = CompleteCount > 0 ? (double(PostCount) / double(CompleteCount)) : 0.0;
+
+ Status << "post_count" << PostCount;
+ Status << "complete_count" << PostCount;
+ Status << "update_count" << Stats.UpdateCount.Value();
+
+ Status << "complete_ratio" << CompleteRate;
+ Status << "downloaded_mb" << Stats.DownBytes.Value();
+ Status << "uploaded_mb" << Stats.UpBytes.Value();
+ Status << "error_count" << Stats.ErrorCount.Value();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
    // The caller is responsible for locking if required
    // Returns a pointer into m_ApplyTasks (invalidated by any mutation of
    // the maps), or nullptr when the (worker, action) pair is unknown.
    UpstreamApplyStatus* FindStatus(const IoHash& WorkerId, const IoHash& ActionId)
    {
        if (auto It = m_ApplyTasks.find(WorkerId); It != m_ApplyTasks.end())
        {
            if (auto It2 = It->second.find(ActionId); It2 != It->second.end())
            {
                return &It2->second;
            }
        }
        return nullptr;
    }
+
    // Worker-pool entry: post the apply record to the first healthy endpoint
    // and record the resulting task state. Marks the task Complete with an
    // error when no endpoint is available or posting throws.
    void ProcessApplyRecord(UpstreamApplyRecord ApplyRecord)
    {
        // Ids are captured before ApplyRecord is moved into PostApply below.
        const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash();
        const IoHash ActionId = ApplyRecord.Action.GetHash();
        try
        {
            for (auto& Endpoint : m_Endpoints)
            {
                if (Endpoint->IsHealthy())
                {
                    ApplyRecord.Timepoints["zen-queue-dispatched"] = DateTime::NowTicks();
                    PostUpstreamApplyResult Result = Endpoint->PostApply(std::move(ApplyRecord));
                    {
                        std::scoped_lock Lock(m_ApplyTasksMutex);
                        if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
                        {
                            Status->Timepoints.merge(Result.Timepoints);

                            if (Result.Success)
                            {
                                Status->State = UpstreamApplyState::Executing;
                            }
                            else
                            {
                                Status->State = UpstreamApplyState::Complete;
                                // Moving Error only moves its Reason string;
                                // ErrorCode stays intact for m_Stats.Add below.
                                Status->Result = {.Error = std::move(Result.Error),
                                                  .Bytes = Result.Bytes,
                                                  .ElapsedSeconds = Result.ElapsedSeconds};
                            }
                        }
                    }
                    m_Stats.Add(*Endpoint, Result);
                    return;
                }
            }

            Log().warn("process upstream apply ({}/{}) FAILED 'No available endpoint'", WorkerId, ActionId);

            {
                std::scoped_lock Lock(m_ApplyTasksMutex);
                if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
                {
                    Status->State = UpstreamApplyState::Complete;
                    Status->Result = {.Error{.ErrorCode = -1, .Reason = "No available endpoint"}};
                }
            }
        }
        catch (std::exception& e)
        {
            std::scoped_lock Lock(m_ApplyTasksMutex);
            if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
            {
                Status->State = UpstreamApplyState::Complete;
                Status->Result = {.Error{.ErrorCode = -1, .Reason = e.what()}};
            }
            Log().warn("process upstream apply ({}/{}) FAILED '{}'", WorkerId, ActionId, e.what());
        }
    }
+
    // Poll every healthy endpoint for completed work and fold the results
    // into the matching task entries.
    void ProcessApplyUpdates()
    {
        for (auto& Endpoint : m_Endpoints)
        {
            if (Endpoint->IsHealthy())
            {
                GetUpstreamApplyUpdatesResult Result = Endpoint->GetUpdates(m_DownstreamAsyncWorkPool);
                m_Stats.Add(*Endpoint, Result);

                if (!Result.Success)
                {
                    Log().warn("process upstream apply updates FAILED '{}'", Result.Error.Reason);
                }

                if (!Result.Completed.empty())
                {
                    // Completed is worker-id -> (action-id -> result).
                    for (auto& It : Result.Completed)
                    {
                        for (auto& It2 : It.second)
                        {
                            // Lock is scoped per item so status queries are
                            // not blocked for the whole batch.
                            std::scoped_lock Lock(m_ApplyTasksMutex);
                            if (auto Status = FindStatus(It.first, It2.first); Status != nullptr)
                            {
                                Status->State = UpstreamApplyState::Complete;
                                Status->Result = std::move(It2.second);
                                // Accumulated dispatch timepoints move into the
                                // final result and the staging map is cleared.
                                Status->Result.Timepoints.merge(Status->Timepoints);
                                Status->Result.Timepoints["zen-queue-complete"] = DateTime::NowTicks();
                                Status->Timepoints.clear();
                            }
                        }
                    }
                }
            }
        }
    }
+
    // Background thread body: periodically poll for updates and prune
    // expired/empty task entries until the shutdown event fires.
    void ProcessUpstreamUpdates()
    {
        const auto& UpdateSleep = std::chrono::milliseconds(m_Options.UpdatesInterval);
        // Wait returns true when the shutdown event is signaled.
        while (!m_ShutdownEvent.Wait(uint32_t(UpdateSleep.count())))
        {
            if (!m_RunState.IsRunning)
            {
                break;
            }

            ProcessApplyUpdates();

            // Remove any expired tasks, regardless of state
            {
                std::scoped_lock Lock(m_ApplyTasksMutex);
                for (auto& WorkerIt : m_ApplyTasks)
                {
                    const auto Count = std::erase_if(WorkerIt.second, [](const auto& Item) {
                        return Item.second.ExpireTime < std::chrono::steady_clock::now();
                    });
                    if (Count > 0)
                    {
                        Log().debug("Removed '{}' expired tasks", Count);
                    }
                }
                // Drop workers whose per-action maps became empty above.
                const auto Count = std::erase_if(m_ApplyTasks, [](const auto& Item) { return Item.second.empty(); });
                if (Count > 0)
                {
                    Log().debug("Removed '{}' empty task lists", Count);
                }
            }
        }
    }
+
+ void MonitorEndpoints()
+ {
+ for (;;)
+ {
+ {
+ std::unique_lock Lock(m_RunState.Mutex);
+ if (m_RunState.ExitSignal.wait_for(Lock, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
+ {
+ break;
+ }
+ }
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (!Endpoint->IsHealthy())
+ {
+ if (const UpstreamEndpointHealth Health = Endpoint->CheckHealth(); Health.Ok)
+ {
+ Log().warn("health check endpoint '{}' OK", Endpoint->DisplayName(), Health.Reason);
+ }
+ else
+ {
+ Log().warn("health check endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason);
+ }
+ }
+ }
+ }
+ }
+
    // Stop background threads and release endpoints. Safe to call multiple
    // times: RunState::Stop() returns true only for the first caller.
    void Shutdown()
    {
        if (m_RunState.Stop())
        {
            // Wake ProcessUpstreamUpdates (which waits on the event) before
            // joining both threads; MonitorEndpoints is woken by Stop().
            m_ShutdownEvent.Set();
            m_EndpointMonitorThread.join();
            m_UpstreamUpdatesThread.join();
            m_Endpoints.clear();
        }
    }
+
    spdlog::logger& Log() { return m_Log; }

    // Running flag plus the mutex/condvar used to wake MonitorEndpoints.
    struct RunState
    {
        std::mutex Mutex;
        std::condition_variable ExitSignal;
        std::atomic_bool IsRunning{false};

        // Flip IsRunning to false and notify waiters. Returns true only for
        // the caller that actually performed the transition.
        bool Stop()
        {
            bool Stopped = false;
            {
                // Taken so the flag change is observed atomically with
                // respect to the condvar wait in MonitorEndpoints.
                std::scoped_lock Lock(Mutex);
                Stopped = IsRunning.exchange(false);
            }
            if (Stopped)
            {
                ExitSignal.notify_all();
            }
            return Stopped;
        }
    };

    spdlog::logger& m_Log;
    UpstreamApplyOptions m_Options;
    CidStore& m_CidStore;
    UpstreamApplyStats m_Stats;
    // worker-id -> (action-id -> status); guarded by m_ApplyTasksMutex.
    UpstreamApplyTasks m_ApplyTasks;
    std::mutex m_ApplyTasksMutex;
    std::vector<std::unique_ptr<UpstreamApplyEndpoint>> m_Endpoints;
    Event m_ShutdownEvent;
    WorkerThreadPool m_UpstreamAsyncWorkPool;
    WorkerThreadPool m_DownstreamAsyncWorkPool;
    std::thread m_UpstreamUpdatesThread;
    std::thread m_EndpointMonitorThread;
    RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
// NOTE(review): IsHealthy() is declared pure virtual in the header; this
// out-of-line definition provides a conservative default for any derived
// class that explicitly calls UpstreamApply::IsHealthy(). Confirm intent.
bool
UpstreamApply::IsHealthy() const
{
    return false;
}

// Factory for the concrete implementation defined in this file.
std::unique_ptr<UpstreamApply>
UpstreamApply::Create(const UpstreamApplyOptions& Options, CidStore& CidStore)
{
    return std::make_unique<UpstreamApplyImpl>(Options, CidStore);
}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/upstreamapply.h b/src/zenserver/upstream/upstreamapply.h
new file mode 100644
index 000000000..4a095be6c
--- /dev/null
+++ b/src/zenserver/upstream/upstreamapply.h
@@ -0,0 +1,192 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
# include <zencore/compactbinarypackage.h>
# include <zencore/iobuffer.h>
# include <zencore/iohash.h>
# include <zencore/stats.h>
# include <zencore/zencore.h>

# include <chrono>
# include <filesystem>
# include <map>
# include <memory>
# include <string>
# include <unordered_map>
# include <unordered_set>
+
+namespace zen {
+
+class AuthMgr;
+class CbObjectWriter;
+class CidStore;
+class CloudCacheTokenProvider;
+class WorkerThreadPool;
+class ZenCacheNamespace;
+struct CloudCacheClientOptions;
+struct UpstreamAuthConfig;
+
// Lifecycle of a queued upstream apply task.
enum class UpstreamApplyState : int32_t
{
    Queued = 0,
    Executing = 1,
    Complete = 2,
};

// Shape of the apply result payload (loose files vs. asset package).
enum class UpstreamApplyType
{
    Simple = 0,
    Asset = 1,
};

// One unit of remote work: worker descriptor + action, with timing marks
// ("zen-queue-*") accumulated as it flows through the queue.
struct UpstreamApplyRecord
{
    CbObject WorkerDescriptor;
    CbObject Action;
    UpstreamApplyType Type;
    std::map<std::string, uint64_t> Timepoints{};
};

struct UpstreamApplyOptions
{
    std::chrono::seconds HealthCheckInterval{5};
    std::chrono::seconds UpdatesInterval{5};
    uint32_t UpstreamThreadCount = 4;
    uint32_t DownstreamThreadCount = 4;
    bool StatsEnabled = false;
};

// Error info; truthy when ErrorCode is non-zero.
struct UpstreamApplyError
{
    int32_t ErrorCode{};
    std::string Reason{};

    explicit operator bool() const { return ErrorCode != 0; }
};

// Outcome of posting an apply record to an endpoint.
struct PostUpstreamApplyResult
{
    UpstreamApplyError Error{};
    int64_t Bytes{};
    double ElapsedSeconds{};
    std::map<std::string, uint64_t> Timepoints{};
    bool Success = false;
};

// Final result of a completed apply; which fields are populated depends on
// the UpstreamApplyType of the original record.
struct GetUpstreamApplyResult
{
    // UpstreamApplyType::Simple
    std::map<std::filesystem::path, IoHash> OutputFiles{};
    std::map<IoHash, IoBuffer> FileData{};

    // UpstreamApplyType::Asset
    CbPackage OutputPackage{};
    int64_t TotalAttachmentBytes{};
    int64_t TotalRawAttachmentBytes{};

    UpstreamApplyError Error{};
    int64_t Bytes{};
    double ElapsedSeconds{};
    std::string StdOut{};
    std::string StdErr{};
    std::string Agent{};
    std::string Detail{};
    std::map<std::string, uint64_t> Timepoints{};
    bool Success = false;
};

// worker-id -> (action-id -> completed result)
using UpstreamApplyCompleted = std::unordered_map<IoHash, std::unordered_map<IoHash, GetUpstreamApplyResult>>;

// Outcome of one updates poll against an endpoint.
struct GetUpstreamApplyUpdatesResult
{
    UpstreamApplyError Error{};
    int64_t Bytes{};
    double ElapsedSeconds{};
    UpstreamApplyCompleted Completed{};
    bool Success = false;
};

// Tracked state of a task while it moves Queued -> Executing -> Complete.
struct UpstreamApplyStatus
{
    UpstreamApplyState State{};
    GetUpstreamApplyResult Result{};
    std::chrono::steady_clock::time_point ExpireTime{};
    std::map<std::string, uint64_t> Timepoints{};
};

// worker-id -> (action-id -> status)
using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>;

struct UpstreamEndpointHealth
{
    std::string Reason;
    bool Ok = false;
};

// Monotonic per-endpoint counters (byte counters are in MiB).
struct UpstreamApplyEndpointStats
{
    metrics::Counter PostCount;
    metrics::Counter CompleteCount;
    metrics::Counter UpdateCount;
    metrics::Counter ErrorCount;
    metrics::Counter UpBytes;
    metrics::Counter DownBytes;
};
+
+/**
+ * The upstream apply endpoint is responsible for handling remote execution.
+ */
class UpstreamApplyEndpoint
{
public:
    virtual ~UpstreamApplyEndpoint() = default;

    // One-time setup; returns health so failures can be logged by the owner.
    virtual UpstreamEndpointHealth Initialize() = 0;
    virtual bool IsHealthy() const = 0;
    // Active probe, used by the monitor thread to recover unhealthy endpoints.
    virtual UpstreamEndpointHealth CheckHealth() = 0;
    virtual std::string_view DisplayName() const = 0;
    // Submit one record for remote execution (record ownership transfers).
    virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) = 0;
    // Poll for completed work; may schedule downloads on the given pool.
    virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) = 0;
    virtual UpstreamApplyEndpointStats& Stats() = 0;

    // Factory for the Horde-backed endpoint (separate compute and storage
    // service configurations/credentials).
    static std::unique_ptr<UpstreamApplyEndpoint> CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions,
                                                                      const UpstreamAuthConfig& ComputeAuthConfig,
                                                                      const CloudCacheClientOptions& StorageOptions,
                                                                      const UpstreamAuthConfig& StorageAuthConfig,
                                                                      CidStore& CidStore,
                                                                      AuthMgr& Mgr);
};
+
+/**
+ * Manages one or more upstream compute endpoints.
+ */
class UpstreamApply
{
public:
    virtual ~UpstreamApply() = default;

    // Initialize registered endpoints and start background processing.
    virtual bool Initialize() = 0;
    virtual bool IsHealthy() const = 0;
    // Register before Initialize(); ownership transfers to the service.
    virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0;

    struct EnqueueResult
    {
        IoHash ApplyId{};
        bool Success = false;
    };

    struct StatusResult
    {
        UpstreamApplyStatus Status{};
        bool Success = false;
    };

    // Queue a record for asynchronous execution; idempotent per task id.
    virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0;
    // Query one task by (worker, action) id pair.
    virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0;
    // Write an overall diagnostics snapshot.
    virtual void GetStatus(CbObjectWriter& CbO) = 0;

    static std::unique_ptr<UpstreamApply> Create(const UpstreamApplyOptions& Options, CidStore& CidStore);
};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp
new file mode 100644
index 000000000..e838b5fe2
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.cpp
@@ -0,0 +1,2112 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamcache.h"
+#include "jupiter.h"
+#include "zen.h"
+
+#include <zencore/blockingqueue.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/stats.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+
+#include <zenhttp/httpshared.h>
+
+#include <zenstore/cidstore.h>
+
+#include <auth/authmgr.h>
+#include "cache/structuredcache.h"
+#include "cache/structuredcachestore.h"
+#include "diag/logging.h"
+
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <atomic>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace detail {
+
    // Thread-safe holder for an endpoint's state enum plus an optional
    // error string. The state is a relaxed atomic; the string is guarded
    // by its own mutex.
    class UpstreamStatus
    {
    public:
        UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); }

        // NOTE(review): state and reason are read under separate
        // synchronization, so a concurrent Set() can pair a new state with
        // an old reason (benign for diagnostics, but worth confirming).
        UpstreamEndpointStatus EndpointStatus() const
        {
            const UpstreamEndpointState State = EndpointState();
            {
                std::unique_lock _(m_Mutex);
                return {.Reason = m_ErrorText, .State = State};
            }
        }

        // Set a non-error state; clears any previous error text.
        void Set(UpstreamEndpointState NewState)
        {
            m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
            {
                std::unique_lock _(m_Mutex);
                m_ErrorText.clear();
            }
        }

        // Set a state together with its explanatory error text.
        void Set(UpstreamEndpointState NewState, std::string ErrorText)
        {
            m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
            {
                std::unique_lock _(m_Mutex);
                m_ErrorText = std::move(ErrorText);
            }
        }

        // Map an HTTP-style error code to a state; 0 leaves state untouched,
        // 401 becomes kUnauthorized, anything else kError.
        void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText)
        {
            if (ErrorCode != 0)
            {
                Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText));
            }
        }

    private:
        mutable std::mutex m_Mutex;
        std::string m_ErrorText;
        // NOTE(review): relies on value-initialization (C++20 P0883) to
        // start at 0; confirm enumerator 0 is the intended initial state.
        std::atomic_uint32_t m_State;
    };
+
+ class JupiterUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ public:
        // Builds the cloud cache client with a token provider chosen from
        // the auth config: OAuth client credentials when OAuthUrl is set,
        // otherwise an OpenID-provider callback, otherwise a static token.
        JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
        : m_AuthMgr(Mgr)
        , m_Log(zen::logging::Get("upstream"))
        {
            ZEN_ASSERT(!Options.Name.empty());
            m_Info.Name = Options.Name;
            m_Info.Url = Options.ServiceUrl;

            std::unique_ptr<CloudCacheTokenProvider> TokenProvider;

            if (AuthConfig.OAuthUrl.empty() == false)
            {
                TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials(
                    {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret});
            }
            else if (AuthConfig.OpenIdProvider.empty() == false)
            {
                // Callback captures `this`; the endpoint must outlive the
                // client it creates below (it owns the client via m_Client).
                TokenProvider =
                    CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() {
                        AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
                        return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
                    });
            }
            else
            {
                // Static token never expires from the client's perspective.
                CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken),
                                                  .ExpireTime = CloudCacheAccessToken::TimePoint::max()};

                TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
            }

            m_Client = new CloudCacheClient(Options, std::move(TokenProvider));
        }
+
        virtual ~JupiterUpstreamEndpoint() = default;

        virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }

        // Authenticate against the service and record the resulting state.
        // Idempotent: returns immediately once the endpoint is already kOk.
        virtual UpstreamEndpointStatus Initialize() override
        {
            try
            {
                if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
                {
                    return {.State = UpstreamEndpointState::kOk};
                }

                CloudCacheSession Session(m_Client);
                const CloudCacheResult Result = Session.Authenticate();

                if (Result.Success)
                {
                    m_Status.Set(UpstreamEndpointState::kOk);
                }
                else if (Result.ErrorCode != 0)
                {
                    m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
                }
                else
                {
                    // Failed without a transport error code: treat as an
                    // authorization failure.
                    m_Status.Set(UpstreamEndpointState::kUnauthorized);
                }

                return m_Status.EndpointStatus();
            }
            catch (std::exception& Err)
            {
                m_Status.Set(UpstreamEndpointState::kError, Err.what());

                return {.Reason = Err.what(), .State = GetState()};
            }
        }
+
        // Map the sentinel DefaultNamespace to the client's configured DDC
        // namespace; any other namespace passes through unchanged.
        std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace)
        {
            if (Namespace == ZenCacheStore::DefaultNamespace)
            {
                return Session.Client().DefaultDdcNamespace();
            }
            return Namespace;
        }

        // Same mapping for the blob store namespace.
        std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace)
        {
            if (Namespace == ZenCacheStore::DefaultNamespace)
            {
                return Session.Client().DefaultBlobStoreNamespace();
            }
            return Namespace;
        }

        virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }

        virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+
+ if (Type == ZenContentType::kCompressedBinary)
+ {
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+
+ if (Result.Success)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+ IoBuffer ContentBuffer;
+ int NumAttachments = 0;
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize))
+ {
+ Result.Response = AttachmentResult.Response;
+ ++NumAttachments;
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+ if (NumAttachments != 1)
+ {
+ Result.Success = false;
+ }
+ }
+ }
+ }
+ else
+ {
+ const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type;
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType);
+
+ if (Result.Success && Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash()));
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+
+ Package.SetObject(CacheRecord);
+ }
+
+ if (Result.Success)
+ {
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+
+ Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ }
+ }
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
    /// Fetches a batch of cache records from the Horde upstream: one GetRef round
    /// trip per request plus one GetCompressedBlob per referenced attachment.
    /// OnComplete is invoked exactly once per request, in request order; once an
    /// error has been accumulated the remaining requests complete with empty results.
    virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
                                                   std::span<CacheKeyRequest*> Requests,
                                                   OnCacheRecordGetComplete&& OnComplete) override
    {
        ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords");

        CloudCacheSession Session(m_Client);
        GetUpstreamCacheResult Result;

        for (CacheKeyRequest* Request : Requests)
        {
            const CacheKey& CacheKey = Request->Key;
            CbPackage Package;
            CbObject Record;

            double ElapsedSeconds = 0.0;
            // Skip further upstream traffic once an error has been recorded.
            if (!Result.Error)
            {
                std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
                CloudCacheResult RefResult =
                    Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
                AppendResult(RefResult, Result);
                ElapsedSeconds = RefResult.ElapsedSeconds;

                m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);

                if (RefResult.ErrorCode == 0)
                {
                    // Only accept the record if the payload validates as compact binary.
                    const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All);
                    if (ValidationResult == CbValidateError::None)
                    {
                        Record = LoadCompactBinaryObject(RefResult.Response);
                        Record.IterateAttachments([&](CbFieldView AttachmentHash) {
                            CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
                            AppendResult(BlobResult, Result);

                            m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);

                            if (BlobResult.ErrorCode == 0)
                            {
                                IoHash RawHash;
                                uint64_t RawSize;
                                if (CompressedBuffer Chunk =
                                        CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize))
                                {
                                    // Attach only when the decoded raw hash matches the
                                    // referenced hash; mismatched blobs are silently dropped.
                                    if (RawHash == AttachmentHash.AsHash())
                                    {
                                        Package.AddAttachment(CbAttachment(Chunk, RawHash));
                                    }
                                }
                            }
                        });
                    }
                }
            }

            OnComplete(
                {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info});
        }

        return Result;
    }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey&,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
    /// Fetches a batch of cache chunks from the Horde upstream.
    /// A zero ChunkId selects the inline value for the key; otherwise the chunk is
    /// fetched as a compressed blob by content id. OnComplete fires once per request.
    virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
                                                  std::span<CacheChunkRequest*> CacheChunkRequests,
                                                  OnCacheChunksGetComplete&& OnComplete) override final
    {
        ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks");

        CloudCacheSession Session(m_Client);
        GetUpstreamCacheResult Result;

        for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
        {
            CacheChunkRequest& Request = *RequestPtr;
            IoBuffer Payload;
            IoHash RawHash = IoHash::Zero;
            uint64_t RawSize = 0;

            double ElapsedSeconds = 0.0;
            bool IsCompressed = false;
            // Stop issuing upstream requests once an error has been accumulated.
            if (!Result.Error)
            {
                std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
                // NOTE(review): on the inline branch Request.ChunkId (known to be zero
                // here) is passed as GetInlineBlob's last argument -- presumably an
                // out-parameter receiving the payload hash, as in GetCacheValues below;
                // confirm against the session API.
                const CloudCacheResult BlobResult =
                    Request.ChunkId == IoHash::Zero
                        ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId)
                        : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId);
                ElapsedSeconds = BlobResult.ElapsedSeconds;
                Payload = BlobResult.Response;

                AppendResult(BlobResult, Result);

                m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
                if (Payload && IsCompressedBinary(Payload.GetContentType()))
                {
                    // Header-only check: decodes RawHash/RawSize without decompressing.
                    IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize);
                }
            }

            if (IsCompressed)
            {
                OnComplete({.Request = Request,
                            .RawHash = RawHash,
                            .RawSize = RawSize,
                            .Value = Payload,
                            .ElapsedSeconds = ElapsedSeconds,
                            .Source = &m_Info});
            }
            else
            {
                // Miss/failure path: empty value, zero hash/size, no timing or source info.
                OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
            }
        }

        return Result;
    }
+
    /// Fetches a batch of inline cache values from the Horde upstream, one
    /// GetInlineBlob round trip per request. OnComplete fires once per request.
    virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
                                                  std::span<CacheValueRequest*> CacheValueRequests,
                                                  OnCacheValueGetComplete&& OnComplete) override final
    {
        ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues");

        CloudCacheSession Session(m_Client);
        GetUpstreamCacheResult Result;

        for (CacheValueRequest* RequestPtr : CacheValueRequests)
        {
            CacheValueRequest& Request = *RequestPtr;
            IoBuffer Payload;
            IoHash RawHash = IoHash::Zero;
            uint64_t RawSize = 0;

            double ElapsedSeconds = 0.0;
            bool IsCompressed = false;
            // Skip upstream traffic once an error has been accumulated.
            if (!Result.Error)
            {
                std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
                IoHash PayloadHash;
                const CloudCacheResult BlobResult =
                    Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash);
                ElapsedSeconds = BlobResult.ElapsedSeconds;
                Payload = BlobResult.Response;

                AppendResult(BlobResult, Result);

                m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
                if (Payload)
                {
                    if (IsCompressedBinary(Payload.GetContentType()))
                    {
                        // NOTE(review): this accepts the payload only when the decoded raw
                        // hash DIFFERS from PayloadHash, while the uncompressed branch
                        // below accepts on equality -- confirm the '!=' is intentional and
                        // not an inverted comparison.
                        IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash;
                    }
                    else
                    {
                        // Uncompressed payload: compress locally just to derive its raw hash
                        // and verify it against the hash reported by the upstream.
                        CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload));
                        RawHash = Compressed.DecodeRawHash();
                        if (RawHash == PayloadHash)
                        {
                            // NOTE(review): Payload stays uncompressed and RawSize stays 0 on
                            // this path even though it is reported via the success callback --
                            // confirm downstream handles raw payloads here.
                            IsCompressed = true;
                        }
                        else
                        {
                            ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header",
                                     Namespace,
                                     Request.Key.Bucket,
                                     Request.Key.Hash.ToHexString(),
                                     RawHash.ToHexString(),
                                     PayloadHash.ToHexString());
                        }
                    }
                }
            }

            if (IsCompressed)
            {
                OnComplete({.Request = Request,
                            .RawHash = RawHash,
                            .RawSize = RawSize,
                            .Value = Payload,
                            .ElapsedSeconds = ElapsedSeconds,
                            .Source = &m_Info});
            }
            else
            {
                // Miss/failure path: empty value, zero hash/size, no timing or source info.
                OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
            }
        }

        return Result;
    }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+
+ if (CacheRecord.Type == ZenContentType::kBinary)
+ {
+ CloudCacheResult Result;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace);
+ Result = Session.PutRef(BlobStoreNamespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ ZenContentType::kBinary);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = Result.Bytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Success = Result.Success};
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize))
+ {
+ return {.Reason = std::string("Invalid compressed value buffer"), .Success = false};
+ }
+
+ CbObjectWriter ReferencingObject;
+ ReferencingObject.AddBinaryAttachment("RawHash", RawHash);
+ ReferencingObject.AddInteger("RawSize", RawSize);
+
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ ReferencingObject.Save().GetBuffer().AsIoBuffer(),
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ if (ValueContentId != RawHash)
+ {
+ OutReason =
+ fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash);
+ return false;
+ }
+
+ OutBuffer = RecordValue;
+ return true;
+ });
+ }
+ else
+ {
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ RecordValue,
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ const auto It =
+ std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId);
+
+ if (It == std::end(CacheRecord.ValueContentIds))
+ {
+ OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId);
+ return false;
+ }
+
+ const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It);
+
+ OutBuffer = Values[Idx];
+ return true;
+ });
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out)
+ {
+ Out.Success &= Result.Success;
+ Out.Bytes += Result.Bytes;
+ Out.ElapsedSeconds += Result.ElapsedSeconds;
+
+ if (Result.ErrorCode)
+ {
+ Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)};
+ }
+ };
+
    /// Shared upload protocol for structured records:
    ///   1. PUT the ref object (retried up to MaxAttempts).
    ///   2. Upload every blob the server reports as needed, fetching each from
    ///      BlobFetchFn (also retried per blob).
    ///   3. Finalize the ref; if finalize still reports missing blobs, upload them
    ///      and finalize once more. A second round of missing blobs is a hard failure.
    /// Bytes and elapsed time are accumulated across all round trips.
    ///
    /// @param BlobFetchFn  Resolves a value content id to its buffer; returns false
    ///                     (with OutReason set) when the blob cannot be supplied.
    PutUpstreamCacheResult PerformStructuredPut(
        CloudCacheSession& Session,
        std::string_view Namespace,
        const CacheKey& Key,
        IoBuffer ObjectBuffer,
        const int32_t MaxAttempts,
        std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn)
    {
        int64_t TotalBytes = 0ull;
        double TotalElapsedSeconds = 0.0;

        std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
        // Uploads each requested blob, with per-blob retries; returns false on the
        // first blob that cannot be fetched locally or uploaded.
        const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool {
            for (const IoHash& ValueContentId : ValueContentIds)
            {
                IoBuffer BlobBuffer;
                if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason))
                {
                    return false;
                }

                CloudCacheResult BlobResult;
                for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++)
                {
                    BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer);
                }

                m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);

                if (!BlobResult.Success)
                {
                    OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason);
                    return false;
                }

                TotalBytes += BlobResult.Bytes;
                TotalElapsedSeconds += BlobResult.ElapsedSeconds;
            }

            return true;
        };

        // Step 1: upload the ref object itself.
        PutRefResult RefResult;
        for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++)
        {
            RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject);
        }

        m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);

        if (!RefResult.Success)
        {
            return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason),
                    .Success = false};
        }

        TotalBytes += RefResult.Bytes;
        TotalElapsedSeconds += RefResult.ElapsedSeconds;

        // Step 2: upload whatever the server said it needs.
        std::string Reason;
        if (!PutBlobs(RefResult.Needs, Reason))
        {
            return {.Reason = std::move(Reason), .Success = false};
        }

        // Step 3: finalize against the hash of the uploaded object.
        const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer);
        FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);

        m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);

        if (!FinalizeResult.Success)
        {
            return {
                .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
                .Success = false};
        }

        // Finalize may surface additional missing blobs; allow one repair round.
        if (!FinalizeResult.Needs.empty())
        {
            if (!PutBlobs(FinalizeResult.Needs, Reason))
            {
                return {.Reason = std::move(Reason), .Success = false};
            }

            FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);

            m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);

            if (!FinalizeResult.Success)
            {
                return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
                        .Success = false};
            }

            // Still missing after the repair round: report the outstanding hashes and fail.
            if (!FinalizeResult.Needs.empty())
            {
                ExtendableStringBuilder<256> Sb;
                for (const IoHash& MissingHash : FinalizeResult.Needs)
                {
                    Sb << MissingHash.ToHexString() << ",";
                }

                return {
                    .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()),
                    .Success = false};
            }
        }

        TotalBytes += FinalizeResult.Bytes;
        TotalElapsedSeconds += FinalizeResult.ElapsedSeconds;

        return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true};
    }
+
    // Logger accessor for this endpoint.
    spdlog::logger& Log() { return m_Log; }

    AuthMgr& m_AuthMgr;                 // authentication manager (used outside this excerpt -- see class setup)
    spdlog::logger& m_Log;              // endpoint logger
    UpstreamEndpointInfo m_Info;        // name/URL reported to callers via completion results
    UpstreamStatus m_Status;            // last known endpoint state and error reason
    UpstreamEndpointStats m_Stats;      // transfer statistics exposed via Stats()
    RefPtr<CloudCacheClient> m_Client;  // client handed to each CloudCacheSession
};
+
+ class ZenUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ struct ZenEndpoint
+ {
+ std::string Url;
+ std::string Reason;
+ double Latency{};
+ bool Ok = false;
+
+ bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; }
+ };
+
+ public:
    /// Captures the endpoint name, timeouts and candidate URLs from the options.
    /// The actual client connection is created later, in Initialize().
    ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options)
    : m_Log(zen::logging::Get("upstream"))
    , m_ConnectTimeout(Options.ConnectTimeout)
    , m_Timeout(Options.Timeout)
    {
        ZEN_ASSERT(!Options.Name.empty());
        m_Info.Name = Options.Name;

        // Each configured URL becomes a candidate endpoint, later probed in GetEndpoint().
        for (const auto& Url : Options.Urls)
        {
            m_Endpoints.push_back({.Url = Url});
        }
    }

    ~ZenUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
    /// Probes the candidate URLs, selects the best one, and (re)creates the client.
    /// Idempotent while healthy: returns kOk immediately if already initialized.
    virtual UpstreamEndpointStatus Initialize() override
    {
        try
        {
            if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
            {
                return {.State = UpstreamEndpointState::kOk};
            }

            // Health-probes all endpoints and returns the best candidate.
            const ZenEndpoint& Ep = GetEndpoint();

            if (m_Info.Url != Ep.Url)
            {
                ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url);
                m_Info.Url = Ep.Url;
            }

            if (Ep.Ok)
            {
                // Swap in a new client under the exclusive lock; readers go through
                // GetClientRef() which takes the shared side of the same lock.
                RwLock::ExclusiveLockScope _(m_ClientLock);
                m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout});
                m_Status.Set(UpstreamEndpointState::kOk);
            }
            else
            {
                m_Status.Set(UpstreamEndpointState::kError, Ep.Reason);
            }

            return m_Status.EndpointStatus();
        }
        catch (std::exception& Err)
        {
            m_Status.Set(UpstreamEndpointState::kError, Err.what());

            return {.Reason = Err.what(), .State = GetState()};
        }
    }
+
    // Current endpoint state (e.g. kOk / kError).
    virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }

    // Current endpoint status, including the failure reason if any.
    virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
    /// Fetches a single cache record from the Zen upstream in the requested
    /// content type. Errors and exceptions are folded into the returned status.
    virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
                                                        const CacheKey& CacheKey,
                                                        ZenContentType Type) override
    {
        ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord");

        try
        {
            ZenStructuredCacheSession Session(GetClientRef());
            const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type);

            m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

            if (Result.ErrorCode == 0)
            {
                return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
                        .Value = Result.Response,
                        .Source = &m_Info};
            }
            else
            {
                // NOTE(review): Result is const, so this std::move degrades to a copy
                // (same pattern as the Horde endpoint's getters).
                return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
            }
        }
        catch (std::exception& Err)
        {
            m_Status.Set(UpstreamEndpointState::kError, Err.what());

            return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
        }
    }
+
    /// Fetches a batch of cache records from the Zen upstream with a single
    /// "GetCacheRecords" RPC. The policy of the first request becomes the batch
    /// default; per-request policies are serialized only when they differ.
    /// OnComplete fires once per request (empty record/package on failure).
    virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
                                                   std::span<CacheKeyRequest*> Requests,
                                                   OnCacheRecordGetComplete&& OnComplete) override
    {
        ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords");
        ZEN_ASSERT(Requests.size() > 0);

        CbObjectWriter BatchRequest;
        BatchRequest << "Method"sv
                     << "GetCacheRecords"sv;
        BatchRequest << "Accept"sv << kCbPkgMagic;

        BatchRequest.BeginObject("Params"sv);
        {
            CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy();
            BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy);

            BatchRequest << "Namespace"sv << Namespace;

            BatchRequest.BeginArray("Requests"sv);
            for (CacheKeyRequest* Request : Requests)
            {
                BatchRequest.BeginObject();
                {
                    const CacheKey& Key = Request->Key;
                    BatchRequest.BeginObject("Key"sv);
                    {
                        BatchRequest << "Bucket"sv << Key.Bucket;
                        BatchRequest << "Hash"sv << Key.Hash;
                    }
                    BatchRequest.EndObject();
                    // Only serialize a per-request policy when it deviates from the default.
                    if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy)
                    {
                        BatchRequest.SetName("Policy"sv);
                        Request->Policy.Save(BatchRequest);
                    }
                }
                BatchRequest.EndObject();
            }
            BatchRequest.EndArray();
        }
        BatchRequest.EndObject();

        ZenCacheResult Result;

        // Scope the session so it is released before the response is processed.
        {
            ZenStructuredCacheSession Session(GetClientRef());
            Result = Session.InvokeRpc(BatchRequest.Save());
        }

        m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

        if (Result.Success)
        {
            CbPackage BatchResponse;
            if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
            {
                CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
                if (Results.Num() != Requests.size())
                {
                    ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream.");
                }
                else
                {
                    // Range-for with init-statement (C++20): Index tracks the matching request.
                    for (size_t Index = 0; CbFieldView Record : Results)
                    {
                        CacheKeyRequest* Request = Requests[Index++];
                        OnComplete({.Request = *Request,
                                    .Record = Record.AsObjectView(),
                                    .Package = BatchResponse,
                                    .ElapsedSeconds = Result.ElapsedSeconds,
                                    .Source = &m_Info});
                    }

                    return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
                }
            }
            else
            {
                ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream.");
            }
        }

        // Failure path: complete every request with an empty record and package.
        for (CacheKeyRequest* Request : Requests)
        {
            OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
        }

        return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
    }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
    /// Fetches a batch of cache values from the Zen upstream via a single
    /// "GetCacheValues" RPC. Results come back as an array parallel to the
    /// requests; each value is either an attached compressed buffer or, failing
    /// that, a RawSize-only description. OnComplete fires once per request.
    virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
                                                  std::span<CacheValueRequest*> CacheValueRequests,
                                                  OnCacheValueGetComplete&& OnComplete) override final
    {
        ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues");
        ZEN_ASSERT(!CacheValueRequests.empty());

        CbObjectWriter BatchRequest;
        BatchRequest << "Method"sv
                     << "GetCacheValues"sv;
        BatchRequest << "Accept"sv << kCbPkgMagic;

        BatchRequest.BeginObject("Params"sv);
        {
            // The first request's policy is the batch default; others are serialized
            // only when they differ.
            CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy;
            BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
            BatchRequest << "Namespace"sv << Namespace;

            BatchRequest.BeginArray("Requests"sv);
            {
                for (CacheValueRequest* RequestPtr : CacheValueRequests)
                {
                    const CacheValueRequest& Request = *RequestPtr;

                    BatchRequest.BeginObject();
                    {
                        BatchRequest.BeginObject("Key"sv);
                        BatchRequest << "Bucket"sv << Request.Key.Bucket;
                        BatchRequest << "Hash"sv << Request.Key.Hash;
                        BatchRequest.EndObject();
                        if (Request.Policy != DefaultPolicy)
                        {
                            BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
                        }
                    }
                    BatchRequest.EndObject();
                }
            }
            BatchRequest.EndArray();
        }
        BatchRequest.EndObject();

        ZenCacheResult Result;

        // Scope the session so it is released before the response is processed.
        {
            ZenStructuredCacheSession Session(GetClientRef());
            Result = Session.InvokeRpc(BatchRequest.Save());
        }

        m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

        if (Result.Success)
        {
            CbPackage BatchResponse;
            if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
            {
                CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
                if (CacheValueRequests.size() != Results.Num())
                {
                    ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream.");
                }
                else
                {
                    for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
                    {
                        CacheValueRequest& Request = *CacheValueRequests[RequestIndex++];
                        CbObjectView ChunkObject = ChunkField.AsObjectView();
                        IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
                        IoBuffer Payload;
                        uint64_t RawSize = 0;
                        if (RawHash != IoHash::Zero)
                        {
                            bool Success = false;
                            // Preferred path: the value travels as a compressed attachment.
                            const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
                            if (Attachment)
                            {
                                if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
                                {
                                    Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
                                    Payload.SetContentType(ZenContentType::kCompressedBinary);
                                    RawSize = Compressed.DecodeRawSize();
                                    Success = true;
                                }
                            }
                            // Fallback: no attachment -- accept a RawSize-only description.
                            if (!Success)
                            {
                                CbFieldView RawSizeField = ChunkObject["RawSize"sv];
                                RawSize = RawSizeField.AsUInt64();
                                Success = !RawSizeField.HasError();
                            }
                            // Neither worked: treat the entry as a miss.
                            if (!Success)
                            {
                                RawHash = IoHash::Zero;
                            }
                        }
                        OnComplete({.Request = Request,
                                    .RawHash = RawHash,
                                    .RawSize = RawSize,
                                    .Value = std::move(Payload),
                                    .ElapsedSeconds = Result.ElapsedSeconds,
                                    .Source = &m_Info});
                    }

                    return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
                }
            }
            else
            {
                ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream.");
            }
        }

        // Failure path: complete every request as a miss.
        for (CacheValueRequest* RequestPtr : CacheValueRequests)
        {
            OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
        }

        return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
    }
+
    /// Fetches a batch of cache chunks from the Zen upstream via a single
    /// "GetCacheChunks" RPC. Optional request fields (ValueId, ChunkId, RawOffset,
    /// RawSize, Policy) are only serialized when they deviate from their defaults.
    /// Response handling mirrors GetCacheValues: attachment first, RawSize-only
    /// fallback, otherwise a miss. OnComplete fires once per request.
    virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
                                                  std::span<CacheChunkRequest*> CacheChunkRequests,
                                                  OnCacheChunksGetComplete&& OnComplete) override final
    {
        ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks");
        ZEN_ASSERT(!CacheChunkRequests.empty());

        CbObjectWriter BatchRequest;
        BatchRequest << "Method"sv
                     << "GetCacheChunks"sv;
        BatchRequest << "Accept"sv << kCbPkgMagic;

        BatchRequest.BeginObject("Params"sv);
        {
            CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy;
            BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
            BatchRequest << "Namespace"sv << Namespace;

            BatchRequest.BeginArray("ChunkRequests"sv);
            {
                for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
                {
                    const CacheChunkRequest& Request = *RequestPtr;

                    BatchRequest.BeginObject();
                    {
                        BatchRequest.BeginObject("Key"sv);
                        BatchRequest << "Bucket"sv << Request.Key.Bucket;
                        BatchRequest << "Hash"sv << Request.Key.Hash;
                        BatchRequest.EndObject();
                        if (Request.ValueId)
                        {
                            BatchRequest.AddObjectId("ValueId"sv, Request.ValueId);
                        }
                        // Style note: accesses the static Zero through the instance;
                        // equivalent to IoHash::Zero as used elsewhere in this file.
                        if (Request.ChunkId != Request.ChunkId.Zero)
                        {
                            BatchRequest << "ChunkId"sv << Request.ChunkId;
                        }
                        if (Request.RawOffset != 0)
                        {
                            BatchRequest << "RawOffset"sv << Request.RawOffset;
                        }
                        if (Request.RawSize != UINT64_MAX)
                        {
                            BatchRequest << "RawSize"sv << Request.RawSize;
                        }
                        if (Request.Policy != DefaultPolicy)
                        {
                            BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
                        }
                    }
                    BatchRequest.EndObject();
                }
            }
            BatchRequest.EndArray();
        }
        BatchRequest.EndObject();

        ZenCacheResult Result;

        // Scope the session so it is released before the response is processed.
        {
            ZenStructuredCacheSession Session(GetClientRef());
            Result = Session.InvokeRpc(BatchRequest.Save());
        }

        m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

        if (Result.Success)
        {
            CbPackage BatchResponse;
            if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
            {
                CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
                if (CacheChunkRequests.size() != Results.Num())
                {
                    ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream.");
                }
                else
                {
                    for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
                    {
                        CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++];
                        CbObjectView ChunkObject = ChunkField.AsObjectView();
                        IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
                        IoBuffer Payload;
                        uint64_t RawSize = 0;
                        if (RawHash != IoHash::Zero)
                        {
                            bool Success = false;
                            // Preferred path: chunk travels as a compressed attachment.
                            const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
                            if (Attachment)
                            {
                                if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
                                {
                                    Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
                                    Payload.SetContentType(ZenContentType::kCompressedBinary);
                                    RawSize = Compressed.DecodeRawSize();
                                    Success = true;
                                }
                            }
                            // Fallback: RawSize-only description without a payload.
                            if (!Success)
                            {
                                CbFieldView RawSizeField = ChunkObject["RawSize"sv];
                                RawSize = RawSizeField.AsUInt64();
                                Success = !RawSizeField.HasError();
                            }
                            // Neither worked: treat the entry as a miss.
                            if (!Success)
                            {
                                RawHash = IoHash::Zero;
                            }
                        }
                        OnComplete({.Request = Request,
                                    .RawHash = RawHash,
                                    .RawSize = RawSize,
                                    .Value = std::move(Payload),
                                    .ElapsedSeconds = Result.ElapsedSeconds,
                                    .Source = &m_Info});
                    }

                    return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
                }
            }
            else
            {
                ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream.");
            }
        }

        // Failure path: complete every request as a miss.
        for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
        {
            OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
        }

        return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
    }
+
    /// Uploads a cache record (and its value blobs) to the Zen upstream.
    ///
    /// Three record shapes are handled:
    ///  - kCbPackage: the record and all values are bundled into one CbPackage and
    ///    uploaded with a single PutCacheRecord call (retried).
    ///  - kCompressedBinary: a "PutCacheValues" RPC carries the compressed value as
    ///    a package attachment referenced by its raw hash.
    ///  - anything else: each value is PUT individually, then the record itself.
    /// Each upload is retried up to MaxAttempts times.
    virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
                                                  IoBuffer RecordValue,
                                                  std::span<IoBuffer const> Values) override
    {
        ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord");

        ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
        const int32_t MaxAttempts = 3;

        try
        {
            ZenStructuredCacheSession Session(GetClientRef());
            ZenCacheResult Result;
            int64_t TotalBytes = 0ull;
            double TotalElapsedSeconds = 0.0;

            if (CacheRecord.Type == ZenContentType::kCbPackage)
            {
                CbPackage Package;
                Package.SetObject(CbObject(SharedBuffer(RecordValue)));

                // Every value must be a valid compressed buffer; it is attached under
                // its decoded raw hash.
                for (const IoBuffer& Value : Values)
                {
                    IoHash RawHash;
                    uint64_t RawSize;
                    if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize))
                    {
                        Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash));
                    }
                    else
                    {
                        return {.Reason = std::string("Invalid value buffer"), .Success = false};
                    }
                }

                BinaryWriter MemStream;
                Package.Save(MemStream);
                // Wrap (not Clone): PackagePayload aliases MemStream's memory, which
                // stays alive for the whole retry loop below.
                IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size());

                // NOTE(review): uint32_t Attempt vs int32_t MaxAttempts mixes
                // signedness (harmless here, but inconsistent with the Horde endpoint).
                for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
                {
                    Result = Session.PutCacheRecord(CacheRecord.Namespace,
                                                    CacheRecord.Key.Bucket,
                                                    CacheRecord.Key.Hash,
                                                    PackagePayload,
                                                    CacheRecord.Type);
                }

                m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

                TotalBytes = Result.Bytes;
                TotalElapsedSeconds = Result.ElapsedSeconds;
            }
            else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
            {
                IoHash RawHash;
                uint64_t RawSize;
                CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize);
                if (!Compressed)
                {
                    return {.Reason = std::string("Invalid value compressed buffer"), .Success = false};
                }

                // Build a single-entry "PutCacheValues" RPC; the compressed value
                // travels as the package attachment referenced by RawHash.
                CbPackage BatchPackage;
                CbObjectWriter BatchWriter;
                BatchWriter << "Method"sv
                            << "PutCacheValues"sv;
                BatchWriter << "Accept"sv << kCbPkgMagic;

                BatchWriter.BeginObject("Params"sv);
                {
                    // DefaultPolicy unspecified and expected to be Default

                    BatchWriter << "Namespace"sv << CacheRecord.Namespace;

                    BatchWriter.BeginArray("Requests"sv);
                    {
                        BatchWriter.BeginObject();
                        {
                            const CacheKey& Key = CacheRecord.Key;
                            BatchWriter.BeginObject("Key"sv);
                            {
                                BatchWriter << "Bucket"sv << Key.Bucket;
                                BatchWriter << "Hash"sv << Key.Hash;
                            }
                            BatchWriter.EndObject();
                            // Policy unspecified and expected to be Default
                            BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash);
                            BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash));
                        }
                        BatchWriter.EndObject();
                    }
                    BatchWriter.EndArray();
                }
                BatchWriter.EndObject();
                BatchPackage.SetObject(BatchWriter.Save());

                Result.Success = false;
                for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
                {
                    Result = Session.InvokeRpc(BatchPackage);
                }

                m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

                TotalBytes += Result.Bytes;
                TotalElapsedSeconds += Result.ElapsedSeconds;
            }
            else
            {
                // Generic path: upload each value individually, then the record.
                for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++)
                {
                    Result.Success = false;
                    for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
                    {
                        Result = Session.PutCacheValue(CacheRecord.Namespace,
                                                       CacheRecord.Key.Bucket,
                                                       CacheRecord.Key.Hash,
                                                       CacheRecord.ValueContentIds[Idx],
                                                       Values[Idx]);
                    }

                    m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

                    TotalBytes += Result.Bytes;
                    TotalElapsedSeconds += Result.ElapsedSeconds;

                    // Abort on the first value that fails after all retries.
                    if (!Result.Success)
                    {
                        return {.Reason = "Failed to upload value",
                                .Bytes = TotalBytes,
                                .ElapsedSeconds = TotalElapsedSeconds,
                                .Success = false};
                    }
                }

                Result.Success = false;
                for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
                {
                    Result = Session.PutCacheRecord(CacheRecord.Namespace,
                                                    CacheRecord.Key.Bucket,
                                                    CacheRecord.Key.Hash,
                                                    RecordValue,
                                                    CacheRecord.Type);
                }

                m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);

                TotalBytes += Result.Bytes;
                TotalElapsedSeconds += Result.ElapsedSeconds;
            }

            return {.Reason = std::move(Result.Reason),
                    .Bytes = TotalBytes,
                    .ElapsedSeconds = TotalElapsedSeconds,
                    .Success = Result.Success};
        }
        catch (std::exception& Err)
        {
            m_Status.Set(UpstreamEndpointState::kError, Err.what());

            return {.Reason = std::string(Err.what()), .Success = false};
        }
    }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
    /// Returns a strong reference to the current client.
    /// The refcount bump happens under the shared lock so a concurrent client swap
    /// in Initialize() cannot race the copy; the lock is released explicitly once
    /// the local Ref holds its own count.
    Ref<ZenStructuredCacheClient> GetClientRef()
    {
        // m_Client can be modified at any time by a different thread.
        // Make sure we safely bump the refcount inside a scope lock
        RwLock::SharedLockScope _(m_ClientLock);
        ZEN_ASSERT(m_Client);
        Ref<ZenStructuredCacheClient> ClientRef(m_Client);
        _.ReleaseNow();
        return ClientRef;
    }
+
    // Probe every configured endpoint and return the best one after ranking.
    // Each endpoint is health-pinged SampleCount times; Latency is the mean of
    // the samples, while Ok/Reason reflect the LAST sample only.
    const ZenEndpoint& GetEndpoint()
    {
        for (ZenEndpoint& Ep : m_Endpoints)
        {
            // Short connect timeout: this is a liveness probe, not a transfer.
            Ref<ZenStructuredCacheClient> Client(
                new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)}));
            ZenStructuredCacheSession Session(std::move(Client));
            const int32_t SampleCount = 2;

            Ep.Ok = false;
            Ep.Latency = {};

            for (int32_t Sample = 0; Sample < SampleCount; ++Sample)
            {
                ZenCacheResult Result = Session.CheckHealth();
                Ep.Ok = Result.Success;
                Ep.Reason = std::move(Result.Reason);
                Ep.Latency += Result.ElapsedSeconds;
            }
            Ep.Latency /= double(SampleCount);
        }

        // Presumably ZenEndpoint::operator< orders by health/latency -- the
        // comparator is declared elsewhere; confirm before relying on the order.
        std::sort(std::begin(m_Endpoints), std::end(m_Endpoints));

        for (const auto& Ep : m_Endpoints)
        {
            ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason);
        }

        // NOTE(review): front() on an empty m_Endpoints is UB -- callers must
        // guarantee at least one configured endpoint.
        return m_Endpoints.front();
    }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ std::vector<ZenEndpoint> m_Endpoints;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+ RwLock m_ClientLock;
+ RefPtr<ZenStructuredCacheClient> m_Client;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamCacheImpl final : public UpstreamCache
+{
+public:
    // Wires the upstream cache to the local cache store and CAS (content store).
    // No threads are started here; call Initialize() to begin processing.
    UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
    : m_Log(logging::Get("upstream"))
    , m_Options(Options)
    , m_CacheStore(CacheStore)
    , m_CidStore(CidStore)
    {
    }
+
+ virtual ~UpstreamCacheImpl() { Shutdown(); }
+
+ virtual void Initialize() override
+ {
+ for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++)
+ {
+ m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this);
+ }
+
+ m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this);
+ m_RunState.IsRunning = true;
+ }
+
    // Adds an endpoint to the search order. Initialize() is invoked once here;
    // an endpoint that fails to initialize is still registered because the
    // health-monitor thread re-probes failing endpoints periodically.
    virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override
    {
        const UpstreamEndpointStatus Status = Endpoint->Initialize();
        const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();

        if (Status.State == UpstreamEndpointState::kOk)
        {
            ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
        }
        else
        {
            ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
        }

        // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s)
        std::unique_lock<std::shared_mutex> _(m_EndpointsMutex);
        m_Endpoints.emplace_back(std::move(Endpoint));
    }
+
+ virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Ep : m_Endpoints)
+ {
+ if (!Fn(*Ep))
+ {
+ break;
+ }
+ }
+ }
+
    // Looks up a single cache record, trying healthy endpoints in registration
    // order; the first hit wins. Errors are logged and the next endpoint is
    // tried. A default-constructed result signals a miss across all endpoints.
    virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override
    {
        ZEN_TRACE_CPU("Upstream::GetCacheRecord");

        std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);

        if (m_Options.ReadUpstream)
        {
            for (auto& Endpoint : m_Endpoints)
            {
                if (Endpoint->GetState() != UpstreamEndpointState::kOk)
                {
                    continue;
                }

                UpstreamEndpointStats& Stats = Endpoint->Stats();
                metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
                GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type);
                Scope.Stop();

                Stats.CacheGetCount.Increment(1);
                Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);

                if (Result.Status.Success)
                {
                    Stats.CacheHitCount.Increment(1);

                    return Result;
                }

                if (Result.Status.Error)
                {
                    Stats.CacheErrorCount.Increment(1);

                    ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'",
                             Endpoint->GetEndpointInfo().Url,
                             Result.Status.Error.Reason,
                             Result.Status.Error.ErrorCode);
                }
            }
        }

        return {};
    }
+
+ virtual void GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecords");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheKeyRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) {
+ if (Params.Record)
+ {
+ OnComplete(std::forward<CacheRecordGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheKeyRequest* Request : RemainingKeys)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+ }
+
+ virtual void GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunks");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheChunkRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheChunkGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheChunkRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunk");
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheValues");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheValueRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheValueGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheValueRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override
+ {
+ if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0)
+ {
+ if (!m_UpstreamThreads.empty())
+ {
+ m_UpstreamQueue.Enqueue(std::move(CacheRecord));
+ }
+ else
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ }
+ }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "reading" << m_Options.ReadUpstream;
+ Status << "writing" << m_Options.WriteUpstream;
+ Status << "worker_threads" << m_Options.ThreadCount;
+ Status << "queue_count" << m_UpstreamQueue.Size();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo();
+ const UpstreamEndpointStatus EpStatus = Ep->GetStatus();
+ UpstreamEndpointStats& EpStats = Ep->Stats();
+
+ Status.BeginObject();
+ Status << "name" << EpInfo.Name;
+ Status << "url" << EpInfo.Url;
+ Status << "state" << ToString(EpStatus.State);
+ Status << "reason" << EpStatus.Reason;
+
+ Status.BeginObject("cache"sv);
+ {
+ const int64_t GetCount = EpStats.CacheGetCount.Value();
+ const int64_t HitCount = EpStats.CacheHitCount.Value();
+ const int64_t ErrorCount = EpStats.CacheErrorCount.Value();
+ const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0;
+ const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0;
+
+ metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status);
+ Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value();
+ Status << "get_count" << GetCount;
+ Status << "hit_count" << HitCount;
+ Status << "hit_ratio" << HitRatio;
+ Status << "error_count" << ErrorCount;
+ Status << "error_ratio" << ErrorRatio;
+ metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status);
+ Status << "put_bytes" << EpStats.CachePutTotalBytes.Value();
+ }
+ Status.EndObject();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
    // Uploads one cache record (and all of its payload chunks) to every healthy
    // endpoint. Bails out early -- without contacting any endpoint -- if the
    // record or any referenced chunk is no longer present locally.
    void ProcessCacheRecord(UpstreamCacheRecord CacheRecord)
    {
        ZEN_TRACE_CPU("Upstream::ProcessCacheRecord");

        ZenCacheValue CacheValue;
        std::vector<IoBuffer> Payloads;

        if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue))
        {
            ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist",
                     CacheRecord.Namespace,
                     CacheRecord.Key.Bucket,
                     CacheRecord.Key.Hash);
            return;
        }

        // Resolve every referenced chunk from the CAS up front; the upload is
        // all-or-nothing per record.
        for (const IoHash& ValueContentId : CacheRecord.ValueContentIds)
        {
            if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId))
            {
                Payloads.push_back(Payload);
            }
            else
            {
                ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS",
                         CacheRecord.Namespace,
                         CacheRecord.Key.Bucket,
                         CacheRecord.Key.Hash,
                         ValueContentId);
                return;
            }
        }

        std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);

        // Unlike the read paths, writes go to ALL healthy endpoints.
        for (auto& Endpoint : m_Endpoints)
        {
            if (Endpoint->GetState() != UpstreamEndpointState::kOk)
            {
                continue;
            }

            UpstreamEndpointStats& Stats = Endpoint->Stats();
            PutUpstreamCacheResult Result;
            {
                metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming);
                Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads));
            }

            Stats.CachePutTotalBytes.Increment(Result.Bytes);

            if (!Result.Success)
            {
                ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'",
                         CacheRecord.Namespace,
                         CacheRecord.Key.Bucket,
                         CacheRecord.Key.Hash,
                         Endpoint->GetEndpointInfo().Url,
                         Result.Reason);
            }
        }
    }
+
+ void ProcessUpstreamQueue()
+ {
+ for (;;)
+ {
+ UpstreamCacheRecord CacheRecord;
+ if (m_UpstreamQueue.WaitAndDequeue(CacheRecord))
+ {
+ try
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Err.what());
+ }
+ }
+
+ if (!m_RunState.IsRunning)
+ {
+ break;
+ }
+ }
+ }
+
    // Health-monitor loop: every HealthCheckInterval, re-initializes endpoints
    // that are in an error or unauthorized state. Wakes early (and exits) when
    // RunState::Stop() signals shutdown.
    void MonitorEndpoints()
    {
        for (;;)
        {
            {
                std::unique_lock lk(m_RunState.Mutex);
                if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
                {
                    break;
                }
            }

            try
            {
                // Collect the failing endpoints under the reader lock, then probe
                // them after releasing it (Initialize() can block on network I/O).
                std::vector<UpstreamEndpoint*> Endpoints;

                {
                    std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);

                    for (auto& Endpoint : m_Endpoints)
                    {
                        UpstreamEndpointState State = Endpoint->GetState();
                        if (State == UpstreamEndpointState::kError)
                        {
                            Endpoints.push_back(Endpoint.get());
                            ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'",
                                     Endpoint->GetEndpointInfo().Name,
                                     Endpoint->GetEndpointInfo().Url,
                                     Endpoint->GetStatus().Reason);
                        }
                        if (State == UpstreamEndpointState::kUnauthorized)
                        {
                            Endpoints.push_back(Endpoint.get());
                        }
                    }
                }

                for (auto& Endpoint : Endpoints)
                {
                    const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
                    const UpstreamEndpointStatus Status = Endpoint->Initialize();

                    if (Status.State == UpstreamEndpointState::kOk)
                    {
                        ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url);
                    }
                    else
                    {
                        const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason);
                        ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason);
                    }
                }
            }
            catch (std::exception& Err)
            {
                ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what());
            }
        }
    }
+
    // Stops workers exactly once (RunState::Stop() only returns true for the
    // transition), then drains and joins all upload threads and the monitor.
    void Shutdown()
    {
        if (m_RunState.Stop())
        {
            // CompleteAdding wakes blocked WaitAndDequeue() calls so workers can exit.
            m_UpstreamQueue.CompleteAdding();
            for (std::thread& Thread : m_UpstreamThreads)
            {
                Thread.join();
            }

            m_EndpointMonitorThread.join();
            m_UpstreamThreads.clear();
            m_Endpoints.clear();
        }
    }
+
    spdlog::logger& Log() { return m_Log; }

    using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>;

    // Start/stop handshake shared by the worker and health-monitor threads.
    struct RunState
    {
        std::mutex Mutex;
        std::condition_variable ExitSignal;
        std::atomic_bool IsRunning{false};

        // Flips IsRunning to false and wakes waiters. Returns true only for the
        // call that performed the running->stopped transition, which makes
        // shutdown idempotent. The flag is flipped while holding Mutex so the
        // monitor thread's wait_for predicate cannot miss the notification.
        bool Stop()
        {
            bool Stopped = false;
            {
                std::lock_guard _(Mutex);
                Stopped = IsRunning.exchange(false);
            }
            if (Stopped)
            {
                ExitSignal.notify_all();
            }
            return Stopped;
        }
    };
+
+ spdlog::logger& m_Log;
+ UpstreamCacheOptions m_Options;
+ ZenCacheStore& m_CacheStore;
+ CidStore& m_CidStore;
+ UpstreamQueue m_UpstreamQueue;
+ std::shared_mutex m_EndpointsMutex;
+ std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints;
+ std::vector<std::thread> m_UpstreamThreads;
+ std::thread m_EndpointMonitorThread;
+ RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
// Factory: native Zen server upstream endpoint.
std::unique_ptr<UpstreamEndpoint>
UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options)
{
    return std::make_unique<detail::ZenUpstreamEndpoint>(Options);
}

// Factory: Jupiter (cloud DDC) upstream endpoint with OAuth/OIDC auth support.
std::unique_ptr<UpstreamEndpoint>
UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
{
    return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr);
}

// Factory: the upstream cache manager implementation.
std::unique_ptr<UpstreamCache>
UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
{
    return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore);
}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamcache.h b/src/zenserver/upstream/upstreamcache.h
new file mode 100644
index 000000000..695c06b32
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.h
@@ -0,0 +1,252 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/stats.h>
+#include <zencore/zencore.h>
+#include <zenutil/cache/cache.h>
+
+#include <atomic>
+#include <chrono>
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace zen {
+
+class CbObjectView;
+class AuthMgr;
+class CbObjectView;
+class CbPackage;
+class CbObjectWriter;
+class CidStore;
+class ZenCacheStore;
+struct CloudCacheClientOptions;
+class CloudCacheTokenProvider;
+struct ZenStructuredCacheClientOptions;
+
// Identifies a locally-written cache record queued for upstream upload, plus the
// content ids of the payload chunks that must be uploaded alongside it.
struct UpstreamCacheRecord
{
    ZenContentType Type = ZenContentType::kBinary;
    std::string Namespace;
    CacheKey Key;
    std::vector<IoHash> ValueContentIds;
};

// Tunables for the upstream cache manager.
struct UpstreamCacheOptions
{
    std::chrono::seconds HealthCheckInterval{5};    // how often failing endpoints are re-probed
    uint32_t ThreadCount = 4;                       // upload worker threads (0 = process inline)
    bool ReadUpstream = true;                       // serve misses from upstream endpoints
    bool WriteUpstream = true;                      // replicate local writes upstream
};
+
// Transport/protocol error from an upstream request; a zero ErrorCode means "no error".
struct UpstreamError
{
    int32_t ErrorCode{};
    std::string Reason{};

    explicit operator bool() const { return ErrorCode != 0; }
};

// Display name and URL identifying a registered endpoint.
struct UpstreamEndpointInfo
{
    std::string Name;
    std::string Url;
};
+
// Outcome of a (possibly batched) upstream read. Success=false with no Error
// set denotes a plain cache miss.
struct GetUpstreamCacheResult
{
    UpstreamError Error{};
    int64_t Bytes{};
    double ElapsedSeconds{};
    bool Success = false;
};

// Outcome of a single-item upstream read, carrying the fetched bytes on success.
struct GetUpstreamCacheSingleResult
{
    GetUpstreamCacheResult Status;
    IoBuffer Value;
    const UpstreamEndpointInfo* Source = nullptr;   // endpoint that served the hit, if any
};

// Outcome of an upstream write.
struct PutUpstreamCacheResult
{
    std::string Reason;
    int64_t Bytes{};
    double ElapsedSeconds{};
    bool Success = false;
};
+
// Per-request completion payload for batched record lookups; a false-y Record
// denotes a miss for that request.
struct CacheRecordGetCompleteParams
{
    CacheKeyRequest& Request;
    const CbObjectView& Record;
    const CbPackage& Package;
    double ElapsedSeconds{};
    const UpstreamEndpointInfo* Source = nullptr;
};

using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>;

// Per-request completion payload for batched value lookups; RawHash == IoHash::Zero
// denotes a miss.
struct CacheValueGetCompleteParams
{
    CacheValueRequest& Request;
    IoHash RawHash;
    uint64_t RawSize;
    IoBuffer Value;
    double ElapsedSeconds{};
    const UpstreamEndpointInfo* Source = nullptr;
};

using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>;

// Per-request completion payload for batched chunk lookups; RawHash == IoHash::Zero
// denotes a miss.
struct CacheChunkGetCompleteParams
{
    CacheChunkRequest& Request;
    IoHash RawHash;
    uint64_t RawSize;
    IoBuffer Value;
    double ElapsedSeconds{};
    const UpstreamEndpointInfo* Source = nullptr;
};

using OnCacheChunksGetComplete = std::function<void(CacheChunkGetCompleteParams&&)>;

// Per-endpoint request/transfer counters and timings.
struct UpstreamEndpointStats
{
    metrics::OperationTiming CacheGetRequestTiming;
    metrics::OperationTiming CachePutRequestTiming;
    metrics::Counter CacheGetTotalBytes;
    metrics::Counter CachePutTotalBytes;
    metrics::Counter CacheGetCount;
    metrics::Counter CacheHitCount;
    metrics::Counter CacheErrorCount;
};
+
// Health/availability state of an upstream endpoint; only kOk endpoints are used
// for reads and writes.
enum class UpstreamEndpointState : uint32_t
{
    kDisabled,
    kUnauthorized,
    kError,
    kOk
};

// Human-readable name for a state; unknown values map to "Unknown".
inline std::string_view
ToString(UpstreamEndpointState State)
{
    using namespace std::literals;

    // Table-driven mapping; the enumerators are contiguous starting at zero.
    static constexpr std::string_view kStateNames[] = {"Disabled"sv, "Unauthorized"sv, "Error"sv, "Ok"sv};

    const uint32_t Index = static_cast<uint32_t>(State);
    return Index < sizeof(kStateNames) / sizeof(kStateNames[0]) ? kStateNames[Index] : "Unknown"sv;
}
+
// OAuth/OIDC settings for authenticating against a cloud (Jupiter) upstream.
// All fields are non-owning views; presumably they must outlive the endpoint --
// verify against the caller.
struct UpstreamAuthConfig
{
    std::string_view OAuthUrl;
    std::string_view OAuthClientId;
    std::string_view OAuthClientSecret;
    std::string_view OpenIdProvider;
    std::string_view AccessToken;
};

// Current endpoint state plus a human-readable failure reason (empty when Ok).
struct UpstreamEndpointStatus
{
    std::string Reason;
    UpstreamEndpointState State;
};
+
/**
 * The upstream endpoint is responsible for handling upload/downloading of cache records.
 * Concrete implementations exist for the native Zen protocol and the Jupiter
 * cloud DDC; see the factory functions below.
 */
class UpstreamEndpoint
{
public:
    virtual ~UpstreamEndpoint() = default;

    // Connects/authenticates. Safe to call again later to re-probe a failing endpoint.
    virtual UpstreamEndpointStatus Initialize() = 0;

    virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0;

    virtual UpstreamEndpointState GetState() = 0;
    virtual UpstreamEndpointStatus GetStatus() = 0;

    // Single and batched reads; batched variants report each result through the
    // completion callback and return aggregate transfer status.
    virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0;
    virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
                                                   std::span<CacheKeyRequest*> Requests,
                                                   OnCacheRecordGetComplete&& OnComplete) = 0;

    virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
                                                  std::span<CacheValueRequest*> CacheValueRequests,
                                                  OnCacheValueGetComplete&& OnComplete) = 0;

    virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, const CacheKey& CacheKey, const IoHash& PayloadId) = 0;
    virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
                                                  std::span<CacheChunkRequest*> CacheChunkRequests,
                                                  OnCacheChunksGetComplete&& OnComplete) = 0;

    // Uploads a record together with all of its payload chunks.
    virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
                                                  IoBuffer RecordValue,
                                                  std::span<IoBuffer const> Payloads) = 0;

    // Mutable per-endpoint counters/timings updated by the owning cache.
    virtual UpstreamEndpointStats& Stats() = 0;

    static std::unique_ptr<UpstreamEndpoint> CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options);

    static std::unique_ptr<UpstreamEndpoint> CreateJupiterEndpoint(const CloudCacheClientOptions& Options,
                                                                   const UpstreamAuthConfig& AuthConfig,
                                                                   AuthMgr& Mgr);
};
+
/**
 * Manages one or more upstream cache endpoints: serves local misses from them
 * (first healthy endpoint wins) and replicates local writes to all of them via
 * a background upload queue.
 */
class UpstreamCache
{
public:
    virtual ~UpstreamCache() = default;

    // Starts worker and health-monitor threads.
    virtual void Initialize() = 0;

    virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0;
    // Visits endpoints in order; return false from the callback to stop early.
    virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0;

    virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0;
    virtual void GetCacheRecords(std::string_view Namespace,
                                 std::span<CacheKeyRequest*> Requests,
                                 OnCacheRecordGetComplete&& OnComplete) = 0;

    virtual void GetCacheValues(std::string_view Namespace,
                                std::span<CacheValueRequest*> CacheValueRequests,
                                OnCacheValueGetComplete&& OnComplete) = 0;

    virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
                                                       const CacheKey& CacheKey,
                                                       const IoHash& ValueContentId) = 0;
    virtual void GetCacheChunks(std::string_view Namespace,
                                std::span<CacheChunkRequest*> CacheChunkRequests,
                                OnCacheChunksGetComplete&& OnComplete) = 0;

    // Schedules a locally-written record for background upload.
    virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0;

    // Emits configuration and per-endpoint stats as a compact-binary object.
    virtual void GetStatus(CbObjectWriter& CbO) = 0;

    static std::unique_ptr<UpstreamCache> Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore);
};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamservice.cpp b/src/zenserver/upstream/upstreamservice.cpp
new file mode 100644
index 000000000..6db1357c5
--- /dev/null
+++ b/src/zenserver/upstream/upstreamservice.cpp
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+#include <upstream/upstreamservice.h>
+
+#include <auth/authmgr.h>
+#include <upstream/upstreamcache.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+
+namespace zen {
+
+using namespace std::literals;
+
// Registers the HTTP routes for the upstream diagnostics service. Currently a
// single route: GET <BaseUri>endpoints, returning a compact-binary array of the
// registered endpoints and their states.
HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr)
{
    m_Router.RegisterRoute(
        "endpoints",
        [this](HttpRouterRequest& Req) {
            CbObjectWriter Writer;
            Writer.BeginArray("Endpoints"sv);
            m_Upstream.IterateEndpoints([&Writer](UpstreamEndpoint& Ep) {
                UpstreamEndpointInfo Info = Ep.GetEndpointInfo();
                UpstreamEndpointStatus Status = Ep.GetStatus();

                Writer.BeginObject();
                Writer << "Name"sv << Info.Name;
                Writer << "Url"sv << Info.Url;
                Writer << "State"sv << ToString(Status.State);
                Writer << "Reason"sv << Status.Reason;
                Writer.EndObject();

                // Keep iterating -- we always report every endpoint.
                return true;
            });
            Writer.EndArray();
            Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Writer.Save());
        },
        HttpVerb::kGet);
}
+
HttpUpstreamService::~HttpUpstreamService()
{
}

// URI prefix this service is mounted under.
const char*
HttpUpstreamService::BaseUri() const
{
    return "/upstream/";
}

// Delegates all requests under BaseUri() to the route table built in the ctor.
void
HttpUpstreamService::HandleRequest(zen::HttpServerRequest& Request)
{
    m_Router.HandleRequest(Request);
}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamservice.h b/src/zenserver/upstream/upstreamservice.h
new file mode 100644
index 000000000..f1da03c8c
--- /dev/null
+++ b/src/zenserver/upstream/upstreamservice.h
@@ -0,0 +1,27 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class AuthMgr;
+class UpstreamCache;
+
// HTTP diagnostics service exposing the upstream cache state under /upstream/.
class HttpUpstreamService final : public zen::HttpService
{
public:
    HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr);
    virtual ~HttpUpstreamService();

    virtual const char* BaseUri() const override;
    virtual void HandleRequest(zen::HttpServerRequest& Request) override;

private:
    UpstreamCache& m_Upstream;
    AuthMgr& m_AuthMgr;     // NOTE(review): currently unused by the routes -- confirm intent
    HttpRequestRouter m_Router;
};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/zen.cpp b/src/zenserver/upstream/zen.cpp
new file mode 100644
index 000000000..9e1212834
--- /dev/null
+++ b/src/zenserver/upstream/zen.cpp
@@ -0,0 +1,326 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zen.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zenhttp/httpcommon.h>
+#include <zenhttp/httpshared.h>
+
+#include "cache/structuredcachestore.h"
+#include "diag/formatters.h"
+#include "diag/logging.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <xxhash.h>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+namespace detail {
+ struct ZenCacheSessionState
+ {
+ ZenCacheSessionState(ZenStructuredCacheClient& Client) : OwnerClient(Client) {}
+ ~ZenCacheSessionState() {}
+
+ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout)
+ {
+ Session.SetBody({});
+ Session.SetHeader({});
+ Session.SetConnectTimeout(ConnectTimeout);
+ Session.SetTimeout(Timeout);
+ }
+
+ cpr::Session& GetSession() { return Session; }
+
+ private:
+ ZenStructuredCacheClient& OwnerClient;
+ cpr::Session Session;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
// Client for a native Zen structured-cache server; holds the base URL, timeouts
// and a pool of reusable session states.
ZenStructuredCacheClient::ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options)
: m_Log(logging::Get(std::string_view("zenclient")))
, m_ServiceUrl(Options.Url)
, m_ConnectTimeout(Options.ConnectTimeout)
, m_Timeout(Options.Timeout)
{
}

ZenStructuredCacheClient::~ZenStructuredCacheClient()
{
}
+
// Pops a pooled session state (creating one when the pool is empty) and resets
// its timeouts/body/headers before handing it out. Thread-safe via
// m_SessionStateLock.
detail::ZenCacheSessionState*
ZenStructuredCacheClient::AllocSessionState()
{
    detail::ZenCacheSessionState* State = nullptr;

    // Lock only around the pool access; allocation and Reset() happen outside.
    if (RwLock::ExclusiveLockScope _(m_SessionStateLock); !m_SessionStateCache.empty())
    {
        State = m_SessionStateCache.front();
        m_SessionStateCache.pop_front();
    }

    if (State == nullptr)
    {
        State = new detail::ZenCacheSessionState(*this);
    }

    State->Reset(m_ConnectTimeout, m_Timeout);

    return State;
}

// Returns a session state to the pool for reuse.
void
ZenStructuredCacheClient::FreeSessionState(detail::ZenCacheSessionState* State)
{
    RwLock::ExclusiveLockScope _(m_SessionStateLock);
    m_SessionStateCache.push_front(State);
}
+
+//////////////////////////////////////////////////////////////////////////
+
+using namespace std::literals;
+
// RAII handle over a pooled session state: checked out of the owning client for
// the lifetime of this object and returned on destruction.
ZenStructuredCacheSession::ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient)
: m_Log(OuterClient->Log())
, m_Client(std::move(OuterClient))
{
    m_SessionState = m_Client->AllocSessionState();
}

ZenStructuredCacheSession::~ZenStructuredCacheSession()
{
    m_Client->FreeSessionState(m_SessionState);
}
+
// Liveness probe: GET /health/check. Success means HTTP 200; a transport error
// (no HTTP status at all) is reported via ErrorCode/Reason.
ZenCacheResult
ZenStructuredCacheSession::CheckHealth()
{
    ExtendableStringBuilder<256> Uri;
    Uri << m_Client->ServiceUrl() << "/health/check";

    cpr::Session& Session = m_SessionState->GetSession();
    Session.SetOption(cpr::Url{Uri.c_str()});
    cpr::Response Response = Session.Get();

    if (Response.error)
    {
        return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
    }

    return {.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
}
+
+ZenCacheResult
+ZenStructuredCacheSession::GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Accept", std::string{MapContentTypeToString(Type)}}});
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
// Fetches a single payload chunk of a record:
// GET /z$/[<ns>/]<bucket>/<key-hex>/<cid-hex>, requesting the compressed
// ("application/x-ue-comp") representation.
ZenCacheResult
ZenStructuredCacheSession::GetCacheChunk(std::string_view Namespace,
                                         std::string_view BucketId,
                                         const IoHash& Key,
                                         const IoHash& ValueContentId)
{
    ExtendableStringBuilder<256> Uri;
    Uri << m_Client->ServiceUrl() << "/z$/";
    if (Namespace != ZenCacheStore::DefaultNamespace)
    {
        // The default namespace is elided from the URI.
        Uri << Namespace << "/";
    }
    Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString();

    cpr::Session& Session = m_SessionState->GetSession();

    Session.SetOption(cpr::Url{Uri.c_str()});
    Session.SetHeader(cpr::Header{{"Accept", "application/x-ue-comp"}});

    cpr::Response Response = Session.Get();
    ZEN_DEBUG("GET {}", Response);

    if (Response.error)
    {
        // Transport-level failure (no HTTP status available).
        return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
    }

    const bool Success = Response.status_code == 200;
    const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();

    return {.Response = Buffer,
            .Bytes = Response.downloaded_bytes,
            .ElapsedSeconds = Response.elapsed,
            .Reason = Response.reason,
            .Success = Success};
}
+
+ZenCacheResult
+ZenStructuredCacheSession::PutCacheRecord(std::string_view Namespace,
+                                          std::string_view BucketId,
+                                          const IoHash& Key,
+                                          IoBuffer Value,
+                                          ZenContentType Type)
+{
+    // Store a record: PUT <service>/z$/[<Namespace>/]<BucketId>/<Key>
+    ExtendableStringBuilder<256> TargetUri;
+    TargetUri << m_Client->ServiceUrl() << "/z$/";
+    if (Namespace != ZenCacheStore::DefaultNamespace)
+    {
+        TargetUri << Namespace << "/";
+    }
+    TargetUri << BucketId << "/" << Key.ToHexString();
+
+    // Select the MIME type matching the payload encoding
+    const char* ContentTypeString = "application/octet-stream";
+    if (Type == ZenContentType::kCbPackage)
+    {
+        ContentTypeString = "application/x-ue-cbpkg";
+    }
+    else if (Type == ZenContentType::kCbObject)
+    {
+        ContentTypeString = "application/x-ue-cb";
+    }
+
+    cpr::Session& HttpSession = m_SessionState->GetSession();
+    HttpSession.SetOption(cpr::Url{TargetUri.c_str()});
+    HttpSession.SetHeader(cpr::Header{{"Content-Type", ContentTypeString}});
+    HttpSession.SetBody(cpr::Body{static_cast<const char*>(Value.Data()), Value.Size()});
+
+    cpr::Response Reply = HttpSession.Put();
+    ZEN_DEBUG("PUT {}", Reply);
+
+    if (Reply.error)
+    {
+        return {.ErrorCode = static_cast<int32_t>(Reply.error.code), .Reason = std::move(Reply.error.message)};
+    }
+
+    // HTTP 200 and 201 both count as a successful store
+    const bool OkStatus = (Reply.status_code == 200) || (Reply.status_code == 201);
+    return {.Bytes = Reply.uploaded_bytes, .ElapsedSeconds = Reply.elapsed, .Reason = Reply.reason, .Success = OkStatus};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::PutCacheValue(std::string_view Namespace,
+                                         std::string_view BucketId,
+                                         const IoHash& Key,
+                                         const IoHash& ValueContentId,
+                                         IoBuffer Payload)
+{
+    // Store one value payload:
+    // PUT <service>/z$/[<Namespace>/]<BucketId>/<Key>/<ValueContentId>
+    ExtendableStringBuilder<256> TargetUri;
+    TargetUri << m_Client->ServiceUrl() << "/z$/";
+    if (Namespace != ZenCacheStore::DefaultNamespace)
+    {
+        TargetUri << Namespace << "/";
+    }
+    TargetUri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString();
+
+    cpr::Session& HttpSession = m_SessionState->GetSession();
+    HttpSession.SetOption(cpr::Url{TargetUri.c_str()});
+    HttpSession.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-comp"}});
+    HttpSession.SetBody(cpr::Body{static_cast<const char*>(Payload.Data()), Payload.Size()});
+
+    cpr::Response Reply = HttpSession.Put();
+    ZEN_DEBUG("PUT {}", Reply);
+
+    if (Reply.error)
+    {
+        return {.ErrorCode = static_cast<int32_t>(Reply.error.code), .Reason = std::move(Reply.error.message)};
+    }
+
+    // HTTP 200 and 201 both count as a successful store
+    const bool OkStatus = (Reply.status_code == 200) || (Reply.status_code == 201);
+    return {.Bytes = Reply.uploaded_bytes, .ElapsedSeconds = Reply.elapsed, .Reason = Reply.reason, .Success = OkStatus};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request)
+{
+    // RPC endpoint: POST <service>/z$/$rpc
+    ExtendableStringBuilder<256> RpcUri;
+    RpcUri << m_Client->ServiceUrl() << "/z$/$rpc";
+
+    // Serialize the compact-binary request object into a flat byte stream
+    BinaryWriter RequestBody;
+    Request.CopyTo(RequestBody);
+
+    cpr::Session& HttpSession = m_SessionState->GetSession();
+    HttpSession.SetOption(cpr::Url{RpcUri.c_str()});
+    HttpSession.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}});
+    HttpSession.SetBody(cpr::Body{reinterpret_cast<const char*>(RequestBody.GetData()), RequestBody.GetSize()});
+
+    cpr::Response Reply = HttpSession.Post();
+    ZEN_DEBUG("POST {}", Reply);
+
+    if (Reply.error)
+    {
+        return {.ErrorCode = static_cast<int32_t>(Reply.error.code), .Reason = std::move(Reply.error.message)};
+    }
+
+    // Clone the response body only on a 200 reply
+    const bool OkStatus = (Reply.status_code == 200);
+    IoBuffer   ReplyBuffer;
+    if (OkStatus)
+    {
+        ReplyBuffer = IoBufferBuilder::MakeCloneFromMemory(Reply.text.data(), Reply.text.size());
+    }
+
+    // NOTE: Bytes reports the *uploaded* request size here, not the reply size
+    return {.Response = std::move(ReplyBuffer),
+            .Bytes = Reply.uploaded_bytes,
+            .ElapsedSeconds = Reply.elapsed,
+            .Reason = Reply.reason,
+            .Success = OkStatus};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::InvokeRpc(const CbPackage& Request)
+{
+    // RPC endpoint: POST <service>/z$/$rpc (package-message form)
+    ExtendableStringBuilder<256> RpcUri;
+    RpcUri << m_Client->ServiceUrl() << "/z$/$rpc";
+
+    // Flatten the package into a single contiguous message buffer
+    SharedBuffer PackedMessage = FormatPackageMessageBuffer(Request).Flatten();
+
+    cpr::Session& HttpSession = m_SessionState->GetSession();
+    HttpSession.SetOption(cpr::Url{RpcUri.c_str()});
+    HttpSession.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}});
+    HttpSession.SetBody(cpr::Body{reinterpret_cast<const char*>(PackedMessage.GetData()), PackedMessage.GetSize()});
+
+    cpr::Response Reply = HttpSession.Post();
+    ZEN_DEBUG("POST {}", Reply);
+
+    if (Reply.error)
+    {
+        return {.ErrorCode = static_cast<int32_t>(Reply.error.code), .Reason = std::move(Reply.error.message)};
+    }
+
+    // Clone the response body only on a 200 reply
+    const bool OkStatus = (Reply.status_code == 200);
+    IoBuffer   ReplyBuffer;
+    if (OkStatus)
+    {
+        ReplyBuffer = IoBufferBuilder::MakeCloneFromMemory(Reply.text.data(), Reply.text.size());
+    }
+
+    // NOTE: Bytes reports the *uploaded* request size here, not the reply size
+    return {.Response = std::move(ReplyBuffer),
+            .Bytes = Reply.uploaded_bytes,
+            .ElapsedSeconds = Reply.elapsed,
+            .Reason = Reply.reason,
+            .Success = OkStatus};
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/zen.h b/src/zenserver/upstream/zen.h
new file mode 100644
index 000000000..bfba8fa98
--- /dev/null
+++ b/src/zenserver/upstream/zen.h
@@ -0,0 +1,125 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+#include <zencore/zencore.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <list>
+
+struct ZenCacheValue;
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+class CbObjectWriter;
+class CbObjectView;
+class CbPackage;
+class ZenStructuredCacheClient;
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+ struct ZenCacheSessionState;
+}
+
+/** Result of a single structured-cache HTTP operation or RPC round trip. */
+struct ZenCacheResult
+{
+    IoBuffer Response;          // Response payload; empty unless the request succeeded
+    int64_t Bytes = {};         // Bytes transferred (downloaded for GETs, uploaded for PUTs/RPCs)
+    double ElapsedSeconds = {}; // Wall time reported by the HTTP layer
+    int32_t ErrorCode = {};     // Transport-level (cpr) error code; 0 when a response was received
+    std::string Reason;         // Transport error message or HTTP reason phrase
+    bool Success = false;       // True on an accepted HTTP status (200; puts also accept 201)
+};
+
+/** Construction options for ZenStructuredCacheClient. */
+struct ZenStructuredCacheClientOptions
+{
+    std::string_view Name;                      // NOTE(review): presumably a diagnostic endpoint name -- confirm usage
+    std::string_view Url;                       // Endpoint base URL
+    std::span<std::string const> Urls;          // NOTE(review): presumably alternate endpoint URLs -- confirm usage
+    std::chrono::milliseconds ConnectTimeout{}; // Connection timeout (stored by the client)
+    std::chrono::milliseconds Timeout{};        // Request timeout (stored by the client)
+};
+
+/** Zen Structured Cache session
+ *
+ * This provides a context in which cache queries can be performed
+ *
+ * These are currently all synchronous. Will need to be made asynchronous
+ */
+class ZenStructuredCacheSession
+{
+public:
+    // Takes shared ownership of the client whose endpoint this session talks to.
+    ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient);
+    ~ZenStructuredCacheSession();
+
+    // Health-check request against the endpoint.
+    ZenCacheResult CheckHealth();
+    // GET a cache record (bucket/key) with content negotiation via Type.
+    ZenCacheResult GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type);
+    // GET a single value chunk addressed by its content id.
+    ZenCacheResult GetCacheChunk(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId);
+    // PUT a full cache record payload under bucket/key.
+    ZenCacheResult PutCacheRecord(std::string_view Namespace,
+                                  std::string_view BucketId,
+                                  const IoHash& Key,
+                                  IoBuffer Value,
+                                  ZenContentType Type);
+    // PUT a single value payload addressed by bucket/key/content id.
+    ZenCacheResult PutCacheValue(std::string_view Namespace,
+                                 std::string_view BucketId,
+                                 const IoHash& Key,
+                                 const IoHash& ValueContentId,
+                                 IoBuffer Payload);
+    // POST a compact-binary RPC request to the $rpc endpoint.
+    ZenCacheResult InvokeRpc(const CbObjectView& Request);
+    ZenCacheResult InvokeRpc(const CbPackage& Package);
+
+private:
+    inline spdlog::logger& Log() { return m_Log; }
+
+    spdlog::logger& m_Log;                        // Session logger
+    Ref<ZenStructuredCacheClient> m_Client;       // Owning reference to the endpoint client
+    detail::ZenCacheSessionState* m_SessionState; // Per-session HTTP state; presumably obtained from the
+                                                  // client's pool (Alloc/FreeSessionState) -- TODO confirm
+};
+
+/** Zen Structured Cache client
+ *
+ * This represents an endpoint to query -- actual queries should be done via
+ * ZenStructuredCacheSession
+ */
+class ZenStructuredCacheClient : public RefCounted
+{
+public:
+    ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options);
+    ~ZenStructuredCacheClient();
+
+    // Base URL of the endpoint this client targets.
+    std::string_view ServiceUrl() const { return m_ServiceUrl; }
+
+    inline spdlog::logger& Log() { return m_Log; }
+
+private:
+    spdlog::logger& m_Log;                   // Client logger
+    std::string m_ServiceUrl;                // Endpoint base URL returned by ServiceUrl()
+    std::chrono::milliseconds m_ConnectTimeout; // From ZenStructuredCacheClientOptions
+    std::chrono::milliseconds m_Timeout;        // From ZenStructuredCacheClientOptions
+
+    // Cache of reusable per-session state, guarded by the lock.
+    RwLock m_SessionStateLock;
+    std::list<detail::ZenCacheSessionState*> m_SessionStateCache;
+
+    // Acquire/release cached session state (used by ZenStructuredCacheSession).
+    detail::ZenCacheSessionState* AllocSessionState();
+    void FreeSessionState(detail::ZenCacheSessionState*);
+
+    friend class ZenStructuredCacheSession;
+};
+
+} // namespace zen
diff --git a/src/zenserver/windows/service.cpp b/src/zenserver/windows/service.cpp
new file mode 100644
index 000000000..89bacab0b
--- /dev/null
+++ b/src/zenserver/windows/service.cpp
@@ -0,0 +1,646 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "service.h"
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+
+# include <zencore/except.h>
+# include <zencore/zencore.h>
+
+# include <stdio.h>
+# include <tchar.h>
+# include <zencore/windows.h>
+
+// Service name registered with (and displayed by) the SCM.
+# define SVCNAME L"Zen Store"
+
+// Current status reported to the Service Control Manager.
+SERVICE_STATUS gSvcStatus;
+// Handle returned by RegisterServiceCtrlHandler; used by SetServiceStatus.
+SERVICE_STATUS_HANDLE gSvcStatusHandle;
+// Manual-reset event signalled by the control handler on SERVICE_CONTROL_STOP.
+HANDLE ghSvcStopEvent = NULL;
+
+// NOTE(review): SvcInstall is declared but not defined in this file --
+// installation goes through WindowsService::Install. Confirm it is unused.
+void SvcInstall(void);
+
+void ReportSvcStatus(DWORD, DWORD, DWORD);
+void SvcReportEvent(LPTSTR);
+
+// Trivial constructor/destructor: all service state lives in the file-scope
+// globals above, so there is nothing to initialize or tear down here.
+WindowsService::WindowsService()
+{
+}
+
+WindowsService::~WindowsService()
+{
+}
+
+//
+// Purpose:
+// Installs a service in the SCM database
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+VOID
+WindowsService::Install()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+ TCHAR szPath[MAX_PATH];
+
+ if (!GetModuleFileName(NULL, szPath, MAX_PATH))
+ {
+ printf("Cannot install service (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Create the service
+
+ schService = CreateService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SVCNAME, // service name to display
+ SERVICE_ALL_ACCESS, // desired access
+ SERVICE_WIN32_OWN_PROCESS, // service type
+ SERVICE_DEMAND_START, // start type
+ SERVICE_ERROR_NORMAL, // error control type
+ szPath, // path to service's binary
+ NULL, // no load ordering group
+ NULL, // no tag identifier
+ NULL, // no dependencies
+ NULL, // LocalSystem account
+ NULL); // no password
+
+ if (schService == NULL)
+ {
+ printf("CreateService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+ else
+ printf("Service installed successfully\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+// Removes the service from the SCM database.
+void
+WindowsService::Delete()
+{
+    SC_HANDLE schSCManager;
+    SC_HANDLE schService;
+
+    // Get a handle to the SCM database.
+
+    schSCManager = OpenSCManager(NULL,                    // local computer
+                                 NULL,                    // ServicesActive database
+                                 SC_MANAGER_ALL_ACCESS);  // full access rights
+
+    if (NULL == schSCManager)
+    {
+        // GetLastError() returns a DWORD -> %lu, not %d
+        printf("OpenSCManager failed (%lu)\n", GetLastError());
+        return;
+    }
+
+    // Get a handle to the service.
+
+    schService = OpenService(schSCManager, // SCM database
+                             SVCNAME,      // name of service
+                             DELETE);      // need delete access
+
+    if (schService == NULL)
+    {
+        printf("OpenService failed (%lu)\n", GetLastError());
+        CloseServiceHandle(schSCManager);
+        return;
+    }
+
+    // Delete the service.
+
+    if (!DeleteService(schService))
+    {
+        printf("DeleteService failed (%lu)\n", GetLastError());
+    }
+    else
+        printf("Service deleted successfully\n");
+
+    CloseServiceHandle(schService);
+    CloseServiceHandle(schSCManager);
+}
+
+// Singleton pointer used to route the SCM's C-style callback back into the
+// WindowsService instance currently running ServiceMain.
+WindowsService* gSvc;
+
+// Service entry point invoked by the SCM dispatcher; forwards to SvcMain.
+void WINAPI
+CallMain(DWORD, LPSTR*)
+{
+    gSvc->SvcMain();
+}
+
+// Hands the process over to the SCM dispatcher. If the process was started
+// interactively (not by the SCM), falls back to running Run() directly.
+// Returns the application exit code.
+int
+WindowsService::ServiceMain()
+{
+    // Publish this instance for the C-style dispatcher callback (CallMain).
+    gSvc = this;
+
+    SERVICE_TABLE_ENTRY DispatchTable[] = {{(LPWSTR)SVCNAME, (LPSERVICE_MAIN_FUNCTION)&CallMain}, {NULL, NULL}};
+
+    // This call returns when the service has stopped.
+    // The process should simply terminate when the call returns.
+
+    if (!StartServiceCtrlDispatcher(DispatchTable))
+    {
+        const DWORD dwError = zen::GetLastError();
+
+        if (dwError == ERROR_FAILED_SERVICE_CONTROLLER_CONNECT)
+        {
+            // Not actually running as a service
+            gSvc = nullptr;
+
+            zen::SetIsInteractiveSession(true);
+
+            // Run directly in the console instead.
+            return Run();
+        }
+        else
+        {
+            zen::ThrowSystemError(dwError, "StartServiceCtrlDispatcher failed");
+        }
+    }
+
+    zen::SetIsInteractiveSession(false);
+
+    return 0;
+}
+
+// Body of the service: registers the control handler, walks the status
+// through START_PENDING -> RUNNING, runs the application, then reports
+// STOPPED. Returns the application exit code (1 on setup failure).
+int
+WindowsService::SvcMain()
+{
+    // Register the handler function for the service
+
+    gSvcStatusHandle = RegisterServiceCtrlHandler(SVCNAME, SvcCtrlHandler);
+
+    if (!gSvcStatusHandle)
+    {
+        SvcReportEvent((LPTSTR)TEXT("RegisterServiceCtrlHandler"));
+
+        return 1;
+    }
+
+    // These SERVICE_STATUS members remain as set here
+
+    gSvcStatus.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
+    gSvcStatus.dwServiceSpecificExitCode = 0;
+
+    // Report initial status to the SCM
+
+    ReportSvcStatus(SERVICE_START_PENDING, NO_ERROR, 3000);
+
+    // Create an event. The control handler function, SvcCtrlHandler,
+    // signals this event when it receives the stop control code.
+    // NOTE(review): nothing in this file waits on the event; shutdown
+    // appears to flow through zen::RequestApplicationExit -- confirm.
+
+    ghSvcStopEvent = CreateEvent(NULL,   // default security attributes
+                                 TRUE,   // manual reset event
+                                 FALSE,  // not signaled
+                                 NULL);  // no name
+
+    if (ghSvcStopEvent == NULL)
+    {
+        ReportSvcStatus(SERVICE_STOPPED, GetLastError(), 0);
+
+        return 1;
+    }
+
+    // Report running status when initialization is complete.
+
+    ReportSvcStatus(SERVICE_RUNNING, NO_ERROR, 0);
+
+    // Run the actual application; blocks until it exits.
+    int ReturnCode = Run();
+
+    ReportSvcStatus(SERVICE_STOPPED, NO_ERROR, 0);
+
+    return ReturnCode;
+}
+
+//
+// Purpose:
+// Retrieves and displays the current service configuration.
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+void
+DoQuerySvc()
+{
+ SC_HANDLE schSCManager{};
+ SC_HANDLE schService{};
+ LPQUERY_SERVICE_CONFIG lpsc{};
+ LPSERVICE_DESCRIPTION lpsd{};
+ DWORD dwBytesNeeded{}, cbBufSize{}, dwError{};
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_QUERY_CONFIG); // need query config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Get the configuration information.
+
+ if (!QueryServiceConfig(schService, NULL, 0, &dwBytesNeeded))
+ {
+ dwError = GetLastError();
+ if (ERROR_INSUFFICIENT_BUFFER == dwError)
+ {
+ cbBufSize = dwBytesNeeded;
+ lpsc = (LPQUERY_SERVICE_CONFIG)LocalAlloc(LMEM_FIXED, cbBufSize);
+ }
+ else
+ {
+ printf("QueryServiceConfig failed (%d)", dwError);
+ goto cleanup;
+ }
+ }
+
+ if (!QueryServiceConfig(schService, lpsc, cbBufSize, &dwBytesNeeded))
+ {
+ printf("QueryServiceConfig failed (%d)", GetLastError());
+ goto cleanup;
+ }
+
+ if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, NULL, 0, &dwBytesNeeded))
+ {
+ dwError = GetLastError();
+ if (ERROR_INSUFFICIENT_BUFFER == dwError)
+ {
+ cbBufSize = dwBytesNeeded;
+ lpsd = (LPSERVICE_DESCRIPTION)LocalAlloc(LMEM_FIXED, cbBufSize);
+ }
+ else
+ {
+ printf("QueryServiceConfig2 failed (%d)", dwError);
+ goto cleanup;
+ }
+ }
+
+ if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, (LPBYTE)lpsd, cbBufSize, &dwBytesNeeded))
+ {
+ printf("QueryServiceConfig2 failed (%d)", GetLastError());
+ goto cleanup;
+ }
+
+ // Print the configuration information.
+
+ _tprintf(TEXT("%s configuration: \n"), SVCNAME);
+ _tprintf(TEXT(" Type: 0x%x\n"), lpsc->dwServiceType);
+ _tprintf(TEXT(" Start Type: 0x%x\n"), lpsc->dwStartType);
+ _tprintf(TEXT(" Error Control: 0x%x\n"), lpsc->dwErrorControl);
+ _tprintf(TEXT(" Binary path: %s\n"), lpsc->lpBinaryPathName);
+ _tprintf(TEXT(" Account: %s\n"), lpsc->lpServiceStartName);
+
+ if (lpsd->lpDescription != NULL && lstrcmp(lpsd->lpDescription, TEXT("")) != 0)
+ _tprintf(TEXT(" Description: %s\n"), lpsd->lpDescription);
+ if (lpsc->lpLoadOrderGroup != NULL && lstrcmp(lpsc->lpLoadOrderGroup, TEXT("")) != 0)
+ _tprintf(TEXT(" Load order group: %s\n"), lpsc->lpLoadOrderGroup);
+ if (lpsc->dwTagId != 0)
+ _tprintf(TEXT(" Tag ID: %d\n"), lpsc->dwTagId);
+ if (lpsc->lpDependencies != NULL && lstrcmp(lpsc->lpDependencies, TEXT("")) != 0)
+ _tprintf(TEXT(" Dependencies: %s\n"), lpsc->lpDependencies);
+
+ LocalFree(lpsc);
+ LocalFree(lpsd);
+
+cleanup:
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+//
+// Purpose:
+// Disables the service.
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+void
+DoDisableSvc()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_CHANGE_CONFIG); // need change config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Change the service start type.
+
+ if (!ChangeServiceConfig(schService, // handle of service
+ SERVICE_NO_CHANGE, // service type: no change
+ SERVICE_DISABLED, // service start type
+ SERVICE_NO_CHANGE, // error control: no change
+ NULL, // binary path: no change
+ NULL, // load order group: no change
+ NULL, // tag ID: no change
+ NULL, // dependencies: no change
+ NULL, // account name: no change
+ NULL, // password: no change
+ NULL)) // display name: no change
+ {
+ printf("ChangeServiceConfig failed (%d)\n", GetLastError());
+ }
+ else
+ printf("Service disabled successfully.\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+//
+// Purpose:
+// Enables the service.
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+VOID __stdcall DoEnableSvc()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_CHANGE_CONFIG); // need change config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Change the service start type.
+
+ if (!ChangeServiceConfig(schService, // handle of service
+ SERVICE_NO_CHANGE, // service type: no change
+ SERVICE_DEMAND_START, // service start type
+ SERVICE_NO_CHANGE, // error control: no change
+ NULL, // binary path: no change
+ NULL, // load order group: no change
+ NULL, // tag ID: no change
+ NULL, // dependencies: no change
+ NULL, // account name: no change
+ NULL, // password: no change
+ NULL)) // display name: no change
+ {
+ printf("ChangeServiceConfig failed (%d)\n", GetLastError());
+ }
+ else
+ printf("Service enabled successfully.\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+//
+// Purpose:
+// Updates the service description to "This is a test description".
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+void
+DoUpdateSvcDesc()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+ SERVICE_DESCRIPTION sd;
+ TCHAR szDesc[] = TEXT("This is a test description");
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_CHANGE_CONFIG); // need change config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Change the service description.
+
+ sd.lpDescription = szDesc;
+
+ if (!ChangeServiceConfig2(schService, // handle to service
+ SERVICE_CONFIG_DESCRIPTION, // change: description
+ &sd)) // new description
+ {
+ printf("ChangeServiceConfig2 failed\n");
+ }
+ else
+ printf("Service description updated successfully.\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+//
+// Purpose:
+// Sets the current service status and reports it to the SCM.
+//
+// Parameters:
+// dwCurrentState - The current state (see SERVICE_STATUS)
+// dwWin32ExitCode - The system error code
+// dwWaitHint - Estimated time for pending operation,
+// in milliseconds
+//
+// Return value:
+// None
+//
+VOID
+ReportSvcStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode, DWORD dwWaitHint)
+{
+ static DWORD dwCheckPoint = 1;
+
+ // Fill in the SERVICE_STATUS structure.
+
+ gSvcStatus.dwCurrentState = dwCurrentState;
+ gSvcStatus.dwWin32ExitCode = dwWin32ExitCode;
+ gSvcStatus.dwWaitHint = dwWaitHint;
+
+ if (dwCurrentState == SERVICE_START_PENDING)
+ gSvcStatus.dwControlsAccepted = 0;
+ else
+ gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP;
+
+ if ((dwCurrentState == SERVICE_RUNNING) || (dwCurrentState == SERVICE_STOPPED))
+ gSvcStatus.dwCheckPoint = 0;
+ else
+ gSvcStatus.dwCheckPoint = dwCheckPoint++;
+
+ // Report the status of the service to the SCM.
+ SetServiceStatus(gSvcStatusHandle, &gSvcStatus);
+}
+
+// SCM control-code callback: on STOP, signals the stop event and asks the
+// zen application layer to exit, then echoes the current state back.
+void
+WindowsService::SvcCtrlHandler(DWORD dwCtrl)
+{
+    // Handle the requested control code.
+    //
+    // Called by SCM whenever a control code is sent to the service
+    // using the ControlService function.
+
+    switch (dwCtrl)
+    {
+        case SERVICE_CONTROL_STOP:
+            ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 0);
+
+            // Signal the service to stop.
+
+            SetEvent(ghSvcStopEvent);
+            // Drive the actual shutdown through the application exit path.
+            zen::RequestApplicationExit(0);
+
+            ReportSvcStatus(gSvcStatus.dwCurrentState, NO_ERROR, 0);
+            return;
+
+        case SERVICE_CONTROL_INTERROGATE:
+            // Status is re-reported automatically; nothing to do.
+            break;
+
+        default:
+            break;
+    }
+}
+
+//
+// Purpose:
+// Logs messages to the event log
+//
+// Parameters:
+// szFunction - name of function that failed
+//
+// Return value:
+// None
+//
+// Remarks:
+// The service must have an entry in the Application event log.
+//
+VOID
+SvcReportEvent(LPTSTR szFunction)
+{
+ ZEN_UNUSED(szFunction);
+
+ // HANDLE hEventSource;
+ // LPCTSTR lpszStrings[2];
+ // TCHAR Buffer[80];
+
+ // hEventSource = RegisterEventSource(NULL, SVCNAME);
+
+ // if (NULL != hEventSource)
+ //{
+ // StringCchPrintf(Buffer, 80, TEXT("%s failed with %d"), szFunction, GetLastError());
+
+ // lpszStrings[0] = SVCNAME;
+ // lpszStrings[1] = Buffer;
+
+ // ReportEvent(hEventSource, // event log handle
+ // EVENTLOG_ERROR_TYPE, // event type
+ // 0, // event category
+ // SVC_ERROR, // event identifier
+ // NULL, // no security identifier
+ // 2, // size of lpszStrings array
+ // 0, // no binary data
+ // lpszStrings, // array of strings
+ // NULL); // no binary data
+
+ // DeregisterEventSource(hEventSource);
+ //}
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
diff --git a/src/zenserver/windows/service.h b/src/zenserver/windows/service.h
new file mode 100644
index 000000000..7c9610983
--- /dev/null
+++ b/src/zenserver/windows/service.h
@@ -0,0 +1,20 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// Base class for running the application either interactively or as a
+// Windows service; concrete services implement Run().
+class WindowsService
+{
+public:
+    WindowsService();
+    // Virtual: the class is polymorphic (pure-virtual Run), so destroying a
+    // derived instance through a WindowsService* must reach the derived
+    // destructor -- a non-virtual destructor here is undefined behavior.
+    virtual ~WindowsService();
+
+    // Application entry point supplied by the concrete service.
+    virtual int Run() = 0;
+
+    // Dispatches to the SCM, or falls back to an interactive Run().
+    int ServiceMain();
+
+    // Install/remove the service in the SCM database.
+    static void Install();
+    static void Delete();
+
+    // SCM-facing service main and control handler.
+    int SvcMain();
+    static void __stdcall SvcCtrlHandler(unsigned long);
+};
diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua
new file mode 100644
index 000000000..23bfb9535
--- /dev/null
+++ b/src/zenserver/xmake.lua
@@ -0,0 +1,60 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+-- zenserver executable target: links the core zen libraries and produces
+-- the server binary for each supported platform.
+target("zenserver")
+    set_kind("binary")
+    add_deps("zencore", "zenhttp", "zenstore", "zenutil")
+    add_headerfiles("**.h")
+    add_files("**.cpp")
+    -- zenserver.cpp is listed again with unity building disabled so the
+    -- program entry point compiles as its own translation unit.
+    add_files("zenserver.cpp", {unity_ignored = true })
+    add_includedirs(".")
+    set_symbols("debug")
+
+    if is_mode("release") then
+        set_optimize("fastest")
+    end
+
+    if is_plat("windows") then
+        -- Console subsystem, minimum OS version 5.02 (Server 2003 x64 era)
+        add_ldflags("/subsystem:console,5.02")
+        add_ldflags("/MANIFEST:EMBED")
+        add_ldflags("/LTCG")
+        add_files("zenserver.rc")
+        -- Needed because the unity/template-heavy build exceeds the default
+        -- COFF section limit
+        add_cxxflags("/bigobj")
+    else
+        -- The Windows service wrapper only builds on Windows
+        remove_files("windows/**")
+    end
+
+    if is_plat("macosx") then
+        add_ldflags("-framework CoreFoundation")
+        add_ldflags("-framework CoreGraphics")
+        add_ldflags("-framework CoreText")
+        add_ldflags("-framework Foundation")
+        add_ldflags("-framework Security")
+        add_ldflags("-framework SystemConfiguration")
+        add_syslinks("bsm")
+    end
+
+    add_options("compute")
+    add_options("exec")
+
+    add_packages(
+        "vcpkg::asio",
+        "vcpkg::cxxopts",
+        "vcpkg::http-parser",
+        "vcpkg::json11",
+        "vcpkg::lua",
+        "vcpkg::mimalloc",
+        "vcpkg::rocksdb",
+        "vcpkg::sentry-native",
+        "vcpkg::sol2"
+    )
+
+    -- Only applicable to later versions of sentry-native
+    --[[
+    if is_plat("linux") then
+        -- As sentry_native uses symbols from breakpad_client, the latter must
+        -- be specified after the former with GCC-like toolchains. xmake however
+        -- is unaware of this and simply globs files from vcpkg's output. The
+        -- line below forces breakpad_client to be to the right of sentry_native
+        add_syslinks("breakpad_client")
+    end
+    ]]--
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
new file mode 100644
index 000000000..635fd04e0
--- /dev/null
+++ b/src/zenserver/zenserver.cpp
@@ -0,0 +1,1261 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/config.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/refcount.h>
+#include <zencore/scopeguard.h>
+#include <zencore/session.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/scrubcontext.h>
+#include <zenutil/basicfile.h>
+#include <zenutil/zenserverprocess.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+#if ZEN_USE_MIMALLOC
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <mimalloc-new-delete.h>
+# include <mimalloc.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <asio.hpp>
+#include <lua.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <exception>
+#include <list>
+#include <optional>
+#include <regex>
+#include <set>
+#include <unordered_map>
+
+//////////////////////////////////////////////////////////////////////////
+// We don't have any doctest code in this file but this is needed to bring
+// in some shared code into the executable
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include <zencore/testing.h>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+#include "config.h"
+#include "diag/logging.h"
+
+#if ZEN_PLATFORM_WINDOWS
+# include "windows/service.h"
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Sentry
+//
+
+#if !defined(ZEN_USE_SENTRY)
+# if ZEN_PLATFORM_MAC && ZEN_ARCH_ARM64
+// vcpkg's sentry-native port does not support Arm on Mac.
+# define ZEN_USE_SENTRY 0
+# else
+# define ZEN_USE_SENTRY 1
+# endif
+#endif
+
+#if ZEN_USE_SENTRY
+# define SENTRY_BUILD_STATIC 1
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <sentry.h>
+# include <spdlog/sinks/base_sink.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+// Sentry currently does not automatically add all required Windows
+// libraries to the linker when consumed via vcpkg
+
+# if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "sentry.lib")
+# pragma comment(lib, "dbghelp.lib")
+# pragma comment(lib, "winhttp.lib")
+# pragma comment(lib, "version.lib")
+# endif
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Services
+//
+
+#include "admin/admin.h"
+#include "auth/authmgr.h"
+#include "auth/authservice.h"
+#include "cache/structuredcache.h"
+#include "cache/structuredcachestore.h"
+#include "cidstore.h"
+#include "compute/function.h"
+#include "diag/diagsvcs.h"
+#include "frontend/frontend.h"
+#include "monitoring/httpstats.h"
+#include "monitoring/httpstatus.h"
+#include "objectstore/objectstore.h"
+#include "projectstore/projectstore.h"
+#include "testing/httptest.h"
+#include "upstream/upstream.h"
+#include "zenstore/gc.h"
+
+#define ZEN_APP_NAME "Zen store"
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace utils {
+#if ZEN_USE_SENTRY
+    // spdlog sink that forwards each log record to sentry as a message event
+    // (with a stack trace attached). Uses a null mutex: callers are expected
+    // to serialize, or accept unsynchronized capture.
+    class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex>
+    {
+    public:
+        sentry_sink() {}
+
+    protected:
+        // Maps spdlog severities (trace..off) onto the closest sentry level.
+        static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG,
+                                                                                                SENTRY_LEVEL_DEBUG,
+                                                                                                SENTRY_LEVEL_INFO,
+                                                                                                SENTRY_LEVEL_WARNING,
+                                                                                                SENTRY_LEVEL_ERROR,
+                                                                                                SENTRY_LEVEL_FATAL,
+                                                                                                SENTRY_LEVEL_DEBUG};
+
+        void sink_it_(const spdlog::details::log_msg& msg) override
+        {
+            // Include source location in the message body since sentry's
+            // message events carry no structured location of their own here.
+            std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname);
+            sentry_value_t event = sentry_value_new_message_event(
+                /* level */ MapToSentryLevel[msg.level],
+                /* logger */ nullptr,
+                /* message */ Message.c_str());
+            sentry_event_value_add_stacktrace(event, NULL, 0);
+            sentry_capture_event(event);
+        }
+        void flush_() override {}
+    };
+#endif
+
+    // Resolves Host (optionally "host:port"; DefaultPort is used when no port
+    // suffix is present) with a blocking TCP resolve and appends one
+    // "http://address:port" URL per resolved endpoint to OutEndpoints.
+    // Returns the resolver's error code (OutEndpoints untouched on failure).
+    asio::error_code ResolveHostname(asio::io_context& Ctx,
+                                     std::string_view Host,
+                                     std::string_view DefaultPort,
+                                     std::vector<std::string>& OutEndpoints)
+    {
+        std::string_view Port = DefaultPort;
+
+        // Split an explicit "host:port" form
+        if (const size_t Idx = Host.find(":"); Idx != std::string_view::npos)
+        {
+            Port = Host.substr(Idx + 1);
+            Host = Host.substr(0, Idx);
+        }
+
+        asio::ip::tcp::resolver Resolver(Ctx);
+
+        asio::error_code ErrorCode;
+        asio::ip::tcp::resolver::results_type Endpoints = Resolver.resolve(Host, Port, ErrorCode);
+
+        if (!ErrorCode)
+        {
+            for (const asio::ip::tcp::endpoint Ep : Endpoints)
+            {
+                OutEndpoints.push_back(fmt::format("http://{}:{}", Ep.address().to_string(), Ep.port()));
+            }
+        }
+
+        return ErrorCode;
+    }
+} // namespace utils
+
+// Top-level server object: owns the HTTP server, the storage subsystems and
+// every HTTP service, and drives startup (Initialize), the main loop (Run)
+// and shutdown (Cleanup). Also answers /status requests (IHttpStatusProvider).
+class ZenServer : public IHttpStatusProvider
+{
+public:
+    // Wires up all subsystems and registers all services. Returns the
+    // effective base port, which may differ from ServerOptions.BasePort if
+    // the HTTP server relocated. Throws on mutex conflict (instance already
+    // running on the same base port).
+    int Initialize(const ZenServerOptions& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry)
+    {
+        m_UseSentry = ServerOptions.NoSentry == false;
+        m_ServerEntry = ServerEntry;
+        m_DebugOptionForcedCrash = ServerOptions.ShouldCrash;
+        const int ParentPid = ServerOptions.OwnerPid;
+
+        if (ParentPid)
+        {
+            zen::ProcessHandle OwnerProcess;
+            OwnerProcess.Initialize(ParentPid);
+
+            if (!OwnerProcess.IsValid())
+            {
+                ZEN_WARN("Unable to initialize process handle for specified parent pid #{}", ParentPid);
+
+                // If the pid is not reachable should we just shut down immediately? the intended owner process
+                // could have been killed or somehow crashed already
+            }
+            else
+            {
+                ZEN_INFO("Using parent pid #{} to control process lifetime", ParentPid);
+            }
+
+            m_ProcessMonitor.AddPid(ParentPid);
+        }
+
+        // Initialize/check mutex based on base port
+
+        std::string MutexName = fmt::format("zen_{}", ServerOptions.BasePort);
+
+        if (zen::NamedMutex::Exists(MutexName) || ((m_ServerMutex.Create(MutexName) == false)))
+        {
+            throw std::runtime_error(fmt::format("Failed to create mutex '{}' - is another instance already running?", MutexName).c_str());
+        }
+
+        InitializeState(ServerOptions);
+
+        m_HealthService.SetHealthInfo({.DataRoot = m_DataRoot,
+                                       .AbsLogPath = ServerOptions.AbsLogFile,
+                                       .HttpServerClass = std::string(ServerOptions.HttpServerClass),
+                                       .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)});
+
+        // Ok so now we're configured, let's kick things off
+
+        m_Http = zen::CreateHttpServer(ServerOptions.HttpServerClass);
+        int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort);
+
+        if (ServerOptions.WebSocketPort != 0)
+        {
+            const uint32 ThreadCount =
+                ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency();
+
+            m_WebSocket = zen::WebSocketServer::Create(
+                {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))});
+        }
+
+        // Setup authentication manager
+        {
+            std::string EncryptionKey = ServerOptions.EncryptionKey;
+
+            if (EncryptionKey.empty())
+            {
+                // NOTE(review): well-known fallback key/IV below mean tokens
+                // encrypted with them are not secret; confirm this is only
+                // acceptable for local/loopback deployments.
+                EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456";
+
+                ZEN_WARN("using default encryption key");
+            }
+
+            std::string EncryptionIV = ServerOptions.EncryptionIV;
+
+            if (EncryptionIV.empty())
+            {
+                EncryptionIV = "0123456789abcdef";
+
+                ZEN_WARN("using default encryption initialization vector");
+            }
+
+            m_AuthMgr = AuthMgr::Create({.RootDirectory = m_DataRoot / "auth",
+                                         .EncryptionKey = AesKey256Bit::FromString(EncryptionKey),
+                                         .EncryptionIV = AesIV128Bit::FromString(EncryptionIV)});
+
+            for (const ZenOpenIdProviderConfig& OpenIdProvider : ServerOptions.AuthConfig.OpenIdProviders)
+            {
+                m_AuthMgr->AddOpenIdProvider({.Name = OpenIdProvider.Name, .Url = OpenIdProvider.Url, .ClientId = OpenIdProvider.ClientId});
+            }
+        }
+
+        m_AuthService = std::make_unique<zen::HttpAuthService>(*m_AuthMgr);
+        m_Http->RegisterService(*m_AuthService);
+
+        m_Http->RegisterService(m_HealthService);
+        m_Http->RegisterService(m_StatsService);
+        m_Http->RegisterService(m_StatusService);
+        m_StatusService.RegisterHandler("status", *this);
+
+        // Initialize storage and services
+
+        ZEN_INFO("initializing storage");
+
+        zen::CidStoreConfiguration Config;
+        Config.RootDirectory = m_DataRoot / "cas";
+
+        m_CidStore = std::make_unique<zen::CidStore>(m_GcManager);
+        m_CidStore->Initialize(Config);
+        m_CidService.reset(new zen::HttpCidService{*m_CidStore});
+
+        ZEN_INFO("instantiating project service");
+
+        m_ProjectStore = new zen::ProjectStore(*m_CidStore, m_DataRoot / "projects", m_GcManager);
+        m_HttpProjectService.reset(new zen::HttpProjectService{*m_CidStore, m_ProjectStore, m_StatsService, *m_AuthMgr});
+
+#if ZEN_WITH_COMPUTE_SERVICES
+        if (ServerOptions.ComputeServiceEnabled)
+        {
+            InitializeCompute(ServerOptions);
+        }
+        else
+        {
+            ZEN_INFO("NOT instantiating compute services");
+        }
+#endif // ZEN_WITH_COMPUTE_SERVICES
+
+        if (ServerOptions.StructuredCacheEnabled)
+        {
+            InitializeStructuredCache(ServerOptions);
+        }
+        else
+        {
+            ZEN_INFO("NOT instantiating structured cache service");
+        }
+
+        m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics
+        m_Http->RegisterService(m_TestingService);
+        m_Http->RegisterService(m_AdminService);
+
+        if (m_WebSocket)
+        {
+            m_WebSocket->RegisterService(m_TestingService);
+        }
+
+        if (m_HttpProjectService)
+        {
+            m_Http->RegisterService(*m_HttpProjectService);
+        }
+
+        m_Http->RegisterService(*m_CidService);
+
+#if ZEN_WITH_COMPUTE_SERVICES
+        if (ServerOptions.ComputeServiceEnabled)
+        {
+            if (m_HttpFunctionService != nullptr)
+            {
+                m_Http->RegisterService(*m_HttpFunctionService);
+            }
+        }
+#endif // ZEN_WITH_COMPUTE_SERVICES
+
+        m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot);
+
+        if (m_FrontendService)
+        {
+            m_Http->RegisterService(*m_FrontendService);
+        }
+
+        if (ServerOptions.ObjectStoreEnabled)
+        {
+            ObjectStoreConfig ObjCfg;
+            ObjCfg.RootDirectory = m_DataRoot / "obj";
+            ObjCfg.ServerPort = static_cast<uint16_t>(EffectiveBasePort);
+
+            for (const auto& Bucket : ServerOptions.ObjectStoreConfig.Buckets)
+            {
+                // Buckets without an explicit directory live under the object
+                // store root, named after the bucket
+                ObjectStoreConfig::BucketConfig NewBucket{.Name = Bucket.Name};
+                NewBucket.Directory = Bucket.Directory.empty() ? (ObjCfg.RootDirectory / Bucket.Name) : Bucket.Directory;
+                ObjCfg.Buckets.push_back(std::move(NewBucket));
+            }
+
+            m_ObjStoreService = std::make_unique<HttpObjectStoreService>(std::move(ObjCfg));
+            m_Http->RegisterService(*m_ObjStoreService);
+        }
+
+        ZEN_INFO("initializing GC, enabled '{}', interval {}s", ServerOptions.GcConfig.Enabled, ServerOptions.GcConfig.IntervalSeconds);
+        zen::GcSchedulerConfig GcConfig{.RootDirectory = m_DataRoot / "gc",
+                                        .MonitorInterval = std::chrono::seconds(ServerOptions.GcConfig.MonitorIntervalSeconds),
+                                        .Interval = std::chrono::seconds(ServerOptions.GcConfig.IntervalSeconds),
+                                        .MaxCacheDuration = std::chrono::seconds(ServerOptions.GcConfig.Cache.MaxDurationSeconds),
+                                        .CollectSmallObjects = ServerOptions.GcConfig.CollectSmallObjects,
+                                        .Enabled = ServerOptions.GcConfig.Enabled,
+                                        .DiskReserveSize = ServerOptions.GcConfig.DiskReserveSize,
+                                        .DiskSizeSoftLimit = ServerOptions.GcConfig.Cache.DiskSizeSoftLimit};
+        m_GcScheduler.Initialize(GcConfig);
+
+        return EffectiveBasePort;
+    }
+
+    void InitializeState(const ZenServerOptions& ServerOptions);
+    void InitializeStructuredCache(const ZenServerOptions& ServerOptions);
+    void InitializeCompute(const ZenServerOptions& ServerOptions);
+
+    // Blocks until the HTTP server exits (shutdown request or console exit),
+    // then stops the io_context and flushes all stores.
+    void Run()
+    {
+        // This is disabled for now, awaiting better scheduling
+        //
+        // Scrub();
+
+        if (m_ProcessMonitor.IsActive())
+        {
+            EnqueueTimer();
+        }
+
+        if (!m_TestMode)
+        {
+            ZEN_INFO("__________ _________ __ ");
+            ZEN_INFO("\\____ /____ ____ / _____// |_ ___________ ____ ");
+            ZEN_INFO(" / // __ \\ / \\ \\_____ \\\\ __\\/ _ \\_ __ \\_/ __ \\ ");
+            ZEN_INFO(" / /\\ ___/| | \\ / \\| | ( <_> ) | \\/\\ ___/ ");
+            ZEN_INFO("/_______ \\___ >___| / /_______ /|__| \\____/|__| \\___ >");
+            ZEN_INFO(" \\/ \\/ \\/ \\/ \\/ ");
+        }
+
+        ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", zen::GetCurrentProcessId());
+
+#if ZEN_USE_SENTRY
+        ZEN_INFO("sentry crash handler {}", m_UseSentry ? "ENABLED" : "DISABLED");
+        if (m_UseSentry)
+        {
+            sentry_clear_modulecache();
+        }
+#endif
+
+        if (m_DebugOptionForcedCrash)
+        {
+            ZEN_DEBUG_BREAK();
+        }
+
+        const bool IsInteractiveMode = zen::IsInteractiveSession() && !m_TestMode;
+
+        SetNewState(kRunning);
+
+        OnReady();
+
+        if (m_WebSocket)
+        {
+            m_WebSocket->Run();
+        }
+
+        m_Http->Run(IsInteractiveMode);
+
+        SetNewState(kShuttingDown);
+
+        ZEN_INFO(ZEN_APP_NAME " exiting");
+
+        m_IoContext.stop();
+        if (m_IoRunner.joinable())
+        {
+            m_IoRunner.join();
+        }
+
+        Flush();
+    }
+
+    void RequestExit(int ExitCode)
+    {
+        RequestApplicationExit(ExitCode);
+        m_Http->RequestExit();
+    }
+
+    void Cleanup()
+    {
+        ZEN_INFO(ZEN_APP_NAME " cleaning up");
+        m_GcScheduler.Shutdown();
+    }
+
+    void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; }
+    void SetTestMode(bool State) { m_TestMode = State; }
+    void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; }
+    void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; }
+
+    // Optional callback invoked once the server is ready to accept requests
+    std::function<void()> m_IsReadyFunc;
+    void SetIsReadyFunc(std::function<void()>&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); }
+    void OnReady();
+
+    // Lazily starts the single io_context runner thread
+    void EnsureIoRunner()
+    {
+        if (!m_IoRunner.joinable())
+        {
+            m_IoRunner = std::thread{[this] { m_IoContext.run(); }};
+        }
+    }
+
+    // Re-arms the 1s sponsor-pid poll timer
+    void EnqueueTimer()
+    {
+        m_PidCheckTimer.expires_after(std::chrono::seconds(1));
+        m_PidCheckTimer.async_wait([this](const asio::error_code&) { CheckOwnerPid(); });
+
+        EnsureIoRunner();
+    }
+
+    // Timer callback: claims newly-published sponsor pids from shared state
+    // (compare_exchange to zero claims a slot exactly once) and requests exit
+    // when no sponsor process remains alive.
+    void CheckOwnerPid()
+    {
+        // Pick up any new "owner" processes
+
+        std::set<uint32_t> AddedPids;
+
+        for (auto& PidEntry : m_ServerEntry->SponsorPids)
+        {
+            if (uint32_t ThisPid = PidEntry.load(std::memory_order_relaxed))
+            {
+                if (PidEntry.compare_exchange_strong(ThisPid, 0))
+                {
+                    if (AddedPids.insert(ThisPid).second)
+                    {
+                        m_ProcessMonitor.AddPid(ThisPid);
+
+                        ZEN_INFO("added process with pid #{} as a sponsor process", ThisPid);
+                    }
+                }
+            }
+        }
+
+        if (m_ProcessMonitor.IsRunning())
+        {
+            EnqueueTimer();
+        }
+        else
+        {
+            ZEN_INFO(ZEN_APP_NAME " exiting since sponsor processes are all gone");
+
+            RequestExit(0);
+        }
+    }
+
+    // Full storage validation pass across all stores (currently not scheduled;
+    // see the disabled call in Run())
+    void Scrub()
+    {
+        Stopwatch Timer;
+        ZEN_INFO("Storage validation STARTING");
+
+        ScrubContext Ctx;
+        m_CidStore->Scrub(Ctx);
+        m_ProjectStore->Scrub(Ctx);
+        m_StructuredCacheService->Scrub(Ctx);
+
+        const uint64_t ElapsedTimeMs = Timer.GetElapsedTimeMs();
+
+        ZEN_INFO("Storage validation DONE in {}, ({} in {} chunks - {})",
+                 NiceTimeSpanMs(ElapsedTimeMs),
+                 NiceBytes(Ctx.ScrubbedBytes()),
+                 Ctx.ScrubbedChunks(),
+                 NiceByteRate(Ctx.ScrubbedBytes(), ElapsedTimeMs));
+    }
+
+    void Flush()
+    {
+        if (m_CidStore)
+            m_CidStore->Flush();
+
+        if (m_StructuredCacheService)
+            m_StructuredCacheService->Flush();
+
+        if (m_ProjectStore)
+            m_ProjectStore->Flush();
+    }
+
+    // IHttpStatusProvider: reports overall server state on /status
+    virtual void HandleStatusRequest(HttpServerRequest& Request) override
+    {
+        CbObjectWriter Cbo;
+        Cbo << "ok" << true;
+        Cbo << "state" << ToString(m_CurrentState);
+        Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+    }
+
+private:
+    ZenServerState::ZenServerEntry* m_ServerEntry = nullptr;
+    bool m_IsDedicatedMode = false;
+    bool m_TestMode = false;
+    CbObject m_RootManifest;
+    std::filesystem::path m_DataRoot;
+    std::filesystem::path m_ContentRoot;
+    std::thread m_IoRunner;
+    asio::io_context m_IoContext;
+    asio::steady_timer m_PidCheckTimer{m_IoContext};
+    zen::ProcessMonitor m_ProcessMonitor;
+    zen::NamedMutex m_ServerMutex;
+
+    enum ServerState
+    {
+        kInitializing,
+        kRunning,
+        kShuttingDown
+    } m_CurrentState = kInitializing;
+
+    inline void SetNewState(ServerState NewState) { m_CurrentState = NewState; }
+
+    std::string_view ToString(ServerState Value)
+    {
+        switch (Value)
+        {
+            case kInitializing:
+                return "initializing"sv;
+            case kRunning:
+                return "running"sv;
+            case kShuttingDown:
+                return "shutdown"sv;
+            default:
+                return "unknown"sv;
+        }
+    }
+
+    zen::Ref<zen::HttpServer> m_Http;
+    std::unique_ptr<zen::WebSocketServer> m_WebSocket;
+    std::unique_ptr<zen::AuthMgr> m_AuthMgr;
+    std::unique_ptr<zen::HttpAuthService> m_AuthService;
+    zen::HttpStatusService m_StatusService;
+    zen::HttpStatsService m_StatsService;
+    zen::GcManager m_GcManager;
+    zen::GcScheduler m_GcScheduler{m_GcManager};
+    std::unique_ptr<zen::CidStore> m_CidStore;
+    std::unique_ptr<zen::ZenCacheStore> m_CacheStore;
+    zen::HttpTestService m_TestService;
+    zen::HttpTestingService m_TestingService;
+    std::unique_ptr<zen::HttpCidService> m_CidService;
+    zen::RefPtr<zen::ProjectStore> m_ProjectStore;
+    std::unique_ptr<zen::HttpProjectService> m_HttpProjectService;
+    std::unique_ptr<zen::UpstreamCache> m_UpstreamCache;
+    std::unique_ptr<zen::HttpUpstreamService> m_UpstreamService;
+    std::unique_ptr<zen::HttpStructuredCacheService> m_StructuredCacheService;
+    zen::HttpAdminService m_AdminService{m_GcScheduler};
+    zen::HttpHealthService m_HealthService;
+#if ZEN_WITH_COMPUTE_SERVICES
+    std::unique_ptr<zen::HttpFunctionService> m_HttpFunctionService;
+#endif // ZEN_WITH_COMPUTE_SERVICES
+    std::unique_ptr<zen::HttpFrontendService> m_FrontendService;
+    std::unique_ptr<zen::HttpObjectStoreService> m_ObjStoreService;
+
+    bool m_DebugOptionForcedCrash = false;
+    bool m_UseSentry = false;
+};
+
+// Signals readiness to the shared server-state entry (so other processes can
+// see this instance is up) and then invokes the optional ready callback.
+void
+ZenServer::OnReady()
+{
+    m_ServerEntry->SignalReady();
+
+    if (m_IsReadyFunc)
+    {
+        m_IsReadyFunc();
+    }
+}
+
+// Validates the on-disk root manifest against the current schema version.
+// A missing manifest (on anything but a first run), a manifest that fails
+// compact-binary validation, or a schema mismatch causes the data root to be
+// wiped (the "logs" directory is preserved) and a fresh manifest written.
+void
+ZenServer::InitializeState(const ZenServerOptions& ServerOptions)
+{
+    // Check root manifest to deal with schema versioning
+
+    bool WipeState = false;
+    std::string WipeReason = "Unspecified";
+
+    bool UpdateManifest = false;
+    std::filesystem::path ManifestPath = m_DataRoot / "root_manifest";
+    FileContents ManifestData = zen::ReadFile(ManifestPath);
+
+    if (ManifestData.ErrorCode)
+    {
+        if (ServerOptions.IsFirstRun)
+        {
+            ZEN_INFO("Initializing state at '{}'", m_DataRoot);
+
+            UpdateManifest = true;
+        }
+        else
+        {
+            WipeState = true;
+            WipeReason = fmt::format("No manifest present at '{}'", ManifestPath);
+        }
+    }
+    else
+    {
+        IoBuffer Manifest = ManifestData.Flatten();
+
+        if (CbValidateError ValidationResult = ValidateCompactBinary(Manifest, CbValidateMode::All);
+            ValidationResult != CbValidateError::None)
+        {
+            ZEN_WARN("Manifest validation failed: {}, state will be wiped", uint32_t(ValidationResult));
+
+            WipeState = true;
+            WipeReason = fmt::format("Validation of manifest at '{}' failed: {}", ManifestPath, uint32_t(ValidationResult));
+        }
+        else
+        {
+            m_RootManifest = LoadCompactBinaryObject(Manifest);
+
+            const int32_t ManifestVersion = m_RootManifest["schema_version"].AsInt32(0);
+
+            if (ManifestVersion != ZEN_CFG_SCHEMA_VERSION)
+            {
+                WipeState = true;
+                WipeReason = fmt::format("Manifest schema version: {}, differs from required: {}", ManifestVersion, ZEN_CFG_SCHEMA_VERSION);
+            }
+        }
+    }
+
+    // Release any open handles so we can overwrite the manifest
+    ManifestData = {};
+
+    // Handle any state wipe
+
+    if (WipeState)
+    {
+        ZEN_WARN("Wiping state at '{}' - reason: '{}'", m_DataRoot, WipeReason);
+
+        std::error_code Ec;
+        for (const std::filesystem::directory_entry& DirEntry : std::filesystem::directory_iterator{m_DataRoot, Ec})
+        {
+            // Keep the log directory so the wipe itself remains diagnosable
+            if (DirEntry.is_directory() && (DirEntry.path().filename() != "logs"))
+            {
+                ZEN_INFO("Deleting '{}'", DirEntry.path());
+
+                std::filesystem::remove_all(DirEntry.path(), Ec);
+
+                if (Ec)
+                {
+                    ZEN_WARN("Delete of '{}' returned error: '{}'", DirEntry.path(), Ec.message());
+                }
+            }
+        }
+
+        ZEN_INFO("Wiped all directories in data root");
+
+        UpdateManifest = true;
+    }
+
+    if (UpdateManifest)
+    {
+        // Write new manifest
+
+        const DateTime Now = DateTime::Now();
+
+        CbObjectWriter Cbo;
+        Cbo << "schema_version" << ZEN_CFG_SCHEMA_VERSION << "created" << Now << "updated" << Now << "state_id" << Oid::NewOid();
+
+        m_RootManifest = Cbo.Save();
+
+        WriteFile(ManifestPath, m_RootManifest.GetBuffer().AsIoBuffer());
+    }
+}
+
+// Creates the local structured cache store, the upstream cache layer (with
+// optional Zen and Jupiter upstream endpoints resolved from configuration),
+// and registers the structured-cache and upstream HTTP services.
+void
+ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions)
+{
+    using namespace std::literals;
+
+    ZEN_INFO("instantiating structured cache service");
+    m_CacheStore = std::make_unique<ZenCacheStore>(
+        m_GcManager,
+        ZenCacheStore::Configuration{.BasePath = m_DataRoot / "cache", .AllowAutomaticCreationOfNamespaces = true});
+
+    const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig;
+
+    // CachePolicy is a bit mask; Read and Write can be enabled independently
+    zen::UpstreamCacheOptions UpstreamOptions;
+    UpstreamOptions.ReadUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Read)) != 0;
+    UpstreamOptions.WriteUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Write)) != 0;
+
+    // NOTE(review): configured thread counts >= 32 are ignored and fall back
+    // to the UpstreamCacheOptions default - confirm this clamp is intended
+    if (UpstreamConfig.UpstreamThreadCount < 32)
+    {
+        UpstreamOptions.ThreadCount = static_cast<uint32_t>(UpstreamConfig.UpstreamThreadCount);
+    }
+
+    m_UpstreamCache = zen::UpstreamCache::Create(UpstreamOptions, *m_CacheStore, *m_CidStore);
+    m_UpstreamService = std::make_unique<HttpUpstreamService>(*m_UpstreamCache, *m_AuthMgr);
+    m_UpstreamCache->Initialize();
+
+    if (ServerOptions.UpstreamCacheConfig.CachePolicy != UpstreamCachePolicy::Disabled)
+    {
+        // Zen upstream
+        {
+            std::vector<std::string> ZenUrls = UpstreamConfig.ZenConfig.Urls;
+            if (!UpstreamConfig.ZenConfig.Dns.empty())
+            {
+                // DNS entries are resolved to concrete URLs (default port 1337)
+                for (const std::string& Dns : UpstreamConfig.ZenConfig.Dns)
+                {
+                    if (!Dns.empty())
+                    {
+                        const asio::error_code Err = zen::utils::ResolveHostname(m_IoContext, Dns, "1337"sv, ZenUrls);
+                        if (Err)
+                        {
+                            ZEN_ERROR("resolve FAILED, reason '{}'", Err.message());
+                        }
+                    }
+                }
+            }
+
+            std::erase_if(ZenUrls, [](const auto& Url) { return Url.empty(); });
+
+            if (!ZenUrls.empty())
+            {
+                const auto ZenEndpointName = UpstreamConfig.ZenConfig.Name.empty() ? "Zen"sv : UpstreamConfig.ZenConfig.Name;
+
+                std::unique_ptr<zen::UpstreamEndpoint> ZenEndpoint = zen::UpstreamEndpoint::CreateZenEndpoint(
+                    {.Name = ZenEndpointName,
+                     .Urls = ZenUrls,
+                     .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+                     .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)});
+
+                m_UpstreamCache->RegisterEndpoint(std::move(ZenEndpoint));
+            }
+        }
+
+        // Jupiter upstream
+        if (UpstreamConfig.JupiterConfig.Url.empty() == false)
+        {
+            std::string_view EndpointName = UpstreamConfig.JupiterConfig.Name.empty() ? "Jupiter"sv : UpstreamConfig.JupiterConfig.Name;
+
+            auto Options =
+                zen::CloudCacheClientOptions{.Name = EndpointName,
+                                             .ServiceUrl = UpstreamConfig.JupiterConfig.Url,
+                                             .DdcNamespace = UpstreamConfig.JupiterConfig.DdcNamespace,
+                                             .BlobStoreNamespace = UpstreamConfig.JupiterConfig.Namespace,
+                                             .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+                                             .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)};
+
+            auto AuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.JupiterConfig.OAuthUrl,
+                                                      .OAuthClientId = UpstreamConfig.JupiterConfig.OAuthClientId,
+                                                      .OAuthClientSecret = UpstreamConfig.JupiterConfig.OAuthClientSecret,
+                                                      .OpenIdProvider = UpstreamConfig.JupiterConfig.OpenIdProvider,
+                                                      .AccessToken = UpstreamConfig.JupiterConfig.AccessToken};
+
+            std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint =
+                zen::UpstreamEndpoint::CreateJupiterEndpoint(Options, AuthConfig, *m_AuthMgr);
+
+            m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint));
+        }
+    }
+
+    m_StructuredCacheService =
+        std::make_unique<HttpStructuredCacheService>(*m_CacheStore, *m_CidStore, m_StatsService, m_StatusService, *m_UpstreamCache);
+
+    m_Http->RegisterService(*m_StructuredCacheService);
+    m_Http->RegisterService(*m_UpstreamService);
+}
+
+#if ZEN_WITH_COMPUTE_SERVICES
+// Instantiates the Horde compute/function service when both the compute
+// endpoint (Url) and the storage endpoint (StorageUrl) are configured;
+// otherwise logs that compute stays disabled.
+void
+ZenServer::InitializeCompute(const ZenServerOptions& ServerOptions)
+{
+    // (removed a stray "ServerOptions;" no-op - the parameter is used below,
+    // so the unused-parameter suppression statement was dead code)
+    const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig;
+
+    // Horde compute upstream
+    if (UpstreamConfig.HordeConfig.Url.empty() == false && UpstreamConfig.HordeConfig.StorageUrl.empty() == false)
+    {
+        ZEN_INFO("instantiating compute service");
+
+        std::string_view EndpointName = UpstreamConfig.HordeConfig.Name.empty() ? "Horde"sv : UpstreamConfig.HordeConfig.Name;
+
+        // Compute endpoint (job submission) options + credentials
+        auto ComputeOptions =
+            zen::CloudCacheClientOptions{.Name = EndpointName,
+                                         .ServiceUrl = UpstreamConfig.HordeConfig.Url,
+                                         .ComputeCluster = UpstreamConfig.HordeConfig.Cluster,
+                                         .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+                                         .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)};
+
+        auto ComputeAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.OAuthUrl,
+                                                         .OAuthClientId = UpstreamConfig.HordeConfig.OAuthClientId,
+                                                         .OAuthClientSecret = UpstreamConfig.HordeConfig.OAuthClientSecret,
+                                                         .OpenIdProvider = UpstreamConfig.HordeConfig.OpenIdProvider,
+                                                         .AccessToken = UpstreamConfig.HordeConfig.AccessToken};
+
+        // Storage endpoint (blob store) options + separate credentials
+        auto StorageOptions =
+            zen::CloudCacheClientOptions{.Name = EndpointName,
+                                         .ServiceUrl = UpstreamConfig.HordeConfig.StorageUrl,
+                                         .BlobStoreNamespace = UpstreamConfig.HordeConfig.Namespace,
+                                         .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+                                         .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)};
+
+        auto StorageAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.StorageOAuthUrl,
+                                                         .OAuthClientId = UpstreamConfig.HordeConfig.StorageOAuthClientId,
+                                                         .OAuthClientSecret = UpstreamConfig.HordeConfig.StorageOAuthClientSecret,
+                                                         .OpenIdProvider = UpstreamConfig.HordeConfig.StorageOpenIdProvider,
+                                                         .AccessToken = UpstreamConfig.HordeConfig.StorageAccessToken};
+
+        m_HttpFunctionService = std::make_unique<zen::HttpFunctionService>(*m_CidStore,
+                                                                           ComputeOptions,
+                                                                           StorageOptions,
+                                                                           ComputeAuthConfig,
+                                                                           StorageAuthConfig,
+                                                                           *m_AuthMgr);
+    }
+    else
+    {
+        ZEN_INFO("NOT instantiating compute service (missing Horde or Storage config)");
+    }
+}
+#endif // ZEN_WITH_COMPUTE_SERVICES
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Console entry path: owns a reference to the parsed server options and the
+// instance lock file, and runs the server to completion via Run().
+class ZenEntryPoint
+{
+public:
+    ZenEntryPoint(ZenServerOptions& ServerOptions);
+    ZenEntryPoint(const ZenEntryPoint&) = delete;
+    ZenEntryPoint& operator=(const ZenEntryPoint&) = delete;
+    // Runs the server; returns the process exit code.
+    int Run();
+
+private:
+    ZenServerOptions& m_ServerOptions;
+    // Lock file in the data directory, also used to publish pid/port/readiness
+    zen::LockFile m_LockFile;
+};
+
+ZenEntryPoint::ZenEntryPoint(ZenServerOptions& ServerOptions) : m_ServerOptions(ServerOptions)
+{
+}
+
+#if ZEN_USE_SENTRY
+// Sentry logger callback: formats the message into a small stack buffer,
+// retrying with a heap-allocated string when it does not fit, then forwards
+// it to the console logger at the matching severity.
+static void
+SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata)
+{
+    char LogMessageBuffer[160];
+    std::string LogMessage;
+    const char* MessagePtr = LogMessageBuffer;
+
+    // vsnprintf consumes its va_list; keep a copy so the retry below does not
+    // re-read an already-consumed list (that is undefined behavior and was a
+    // latent crash on the large-message path).
+    va_list ArgsCopy;
+    va_copy(ArgsCopy, Args);
+
+    int n = vsnprintf(LogMessageBuffer, sizeof LogMessageBuffer, Message, Args);
+
+    if (n < 0)
+    {
+        // Encoding error - fall back to the raw format string rather than
+        // printing an indeterminate buffer
+        MessagePtr = Message;
+    }
+    else if (n >= int(sizeof LogMessageBuffer))
+    {
+        LogMessage.resize(n + 1);
+
+        n = vsnprintf(LogMessage.data(), LogMessage.size(), Message, ArgsCopy);
+
+        MessagePtr = LogMessage.c_str();
+    }
+
+    va_end(ArgsCopy);
+
+    switch (Level)
+    {
+        case SENTRY_LEVEL_DEBUG:
+            ConsoleLog().debug("sentry: {}", MessagePtr);
+            break;
+
+        case SENTRY_LEVEL_INFO:
+            ConsoleLog().info("sentry: {}", MessagePtr);
+            break;
+
+        case SENTRY_LEVEL_WARNING:
+            ConsoleLog().warn("sentry: {}", MessagePtr);
+            break;
+
+        case SENTRY_LEVEL_ERROR:
+            ConsoleLog().error("sentry: {}", MessagePtr);
+            break;
+
+        case SENTRY_LEVEL_FATAL:
+            ConsoleLog().critical("sentry: {}", MessagePtr);
+            break;
+    }
+}
+#endif
+
+// Full server lifecycle for the console path: sentry setup, single-instance
+// arbitration via shared server state, lock-file creation, logging init,
+// server construction/run, and shutdown-signal monitoring. Returns 0 (errors
+// are logged; some early-exit paths call std::exit directly).
+int
+ZenEntryPoint::Run()
+{
+#if ZEN_USE_SENTRY
+    std::string SentryDatabasePath = PathToUtf8(m_ServerOptions.DataDir / ".sentry-native");
+    int SentryErrorCode = 0;
+    if (m_ServerOptions.NoSentry == false)
+    {
+        sentry_options_t* SentryOptions = sentry_options_new();
+        sentry_options_set_dsn(SentryOptions, "https://[email protected]/5919284");
+        // Strip the Windows extended-length path prefix; sentry does not
+        // expect it
+        if (SentryDatabasePath.starts_with("\\\\?\\"))
+        {
+            SentryDatabasePath = SentryDatabasePath.substr(4);
+        }
+        sentry_options_set_database_path(SentryOptions, SentryDatabasePath.c_str());
+        sentry_options_set_logger(SentryOptions, SentryLogFunction, this);
+        std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string();
+        if (SentryAttachmentPath.starts_with("\\\\?\\"))
+        {
+            SentryAttachmentPath = SentryAttachmentPath.substr(4);
+        }
+        sentry_options_add_attachment(SentryOptions, SentryAttachmentPath.c_str());
+        sentry_options_set_release(SentryOptions, ZEN_CFG_VERSION);
+        // sentry_options_set_debug(SentryOptions, 1);
+
+        SentryErrorCode = sentry_init(SentryOptions);
+
+        auto SentrySink = spdlog::create<utils::sentry_sink>("sentry");
+        zen::logging::SetErrorLog(std::move(SentrySink));
+    }
+
+    // NOTE(review): this guard calls sentry_close() even when NoSentry was set
+    // and sentry_init never ran - presumably a safe no-op; confirm.
+    auto _ = zen::MakeGuard([] {
+        zen::logging::SetErrorLog(std::shared_ptr<spdlog::logger>());
+        sentry_close();
+    });
+#endif
+
+    auto& ServerOptions = m_ServerOptions;
+
+    try
+    {
+        // Mutual exclusion and synchronization
+        ZenServerState ServerState;
+        ServerState.Initialize();
+        ServerState.Sweep();
+
+        ZenServerState::ZenServerEntry* Entry = ServerState.Lookup(ServerOptions.BasePort);
+
+        // Another instance already owns this port: either attach our owner pid
+        // to it and exit cleanly, or just exit with an error
+        if (Entry)
+        {
+            if (ServerOptions.OwnerPid)
+            {
+                ConsoleLog().info(
+                    "Looks like there is already a process listening to this port {} (pid: {}), attaching owner pid {} to running instance",
+                    ServerOptions.BasePort,
+                    Entry->Pid,
+                    ServerOptions.OwnerPid);
+
+                Entry->AddSponsorProcess(ServerOptions.OwnerPid);
+
+                std::exit(0);
+            }
+            else
+            {
+                ConsoleLog().warn("Exiting since there is already a process listening to port {} (pid: {})",
+                                  ServerOptions.BasePort,
+                                  Entry->Pid);
+                std::exit(1);
+            }
+        }
+
+        std::error_code Ec;
+
+        std::filesystem::path LockFilePath = ServerOptions.DataDir / ".lock";
+
+        bool IsReady = false;
+
+        // Lock-file payload: pid/data/port/session plus a readiness flag that
+        // is rewritten once the server is up
+        auto MakeLockData = [&] {
+            CbObjectWriter Cbo;
+            Cbo << "pid" << zen::GetCurrentProcessId() << "data" << PathToUtf8(ServerOptions.DataDir) << "port" << ServerOptions.BasePort
+                << "session_id" << GetSessionId() << "ready" << IsReady;
+            return Cbo.Save();
+        };
+
+        m_LockFile.Create(LockFilePath, MakeLockData(), Ec);
+
+        if (Ec)
+        {
+            ConsoleLog().warn("ERROR: Unable to grab lock at '{}' (error: '{}')", LockFilePath, Ec.message());
+
+            std::exit(99);
+        }
+
+        InitializeLogging(ServerOptions);
+
+#if ZEN_USE_SENTRY
+        if (m_ServerOptions.NoSentry == false)
+        {
+            if (SentryErrorCode == 0)
+            {
+                ZEN_INFO("sentry initialized");
+            }
+            else
+            {
+                ZEN_WARN("sentry_init returned failure! (error code: {})", SentryErrorCode);
+            }
+        }
+#endif
+
+        MaximizeOpenFileCount();
+
+        ZEN_INFO(ZEN_APP_NAME " - using lock file at '{}'", LockFilePath);
+
+        ZEN_INFO(ZEN_APP_NAME " - starting on port {}, version '{}'", ServerOptions.BasePort, ZEN_CFG_VERSION_BUILD_STRING_FULL);
+
+        Entry = ServerState.Register(ServerOptions.BasePort);
+
+        if (ServerOptions.OwnerPid)
+        {
+            Entry->AddSponsorProcess(ServerOptions.OwnerPid);
+        }
+
+        ZenServer Server;
+        Server.SetDataRoot(ServerOptions.DataDir);
+        Server.SetContentRoot(ServerOptions.ContentDir);
+        Server.SetTestMode(ServerOptions.IsTest);
+        Server.SetDedicatedMode(ServerOptions.IsDedicated);
+
+        int EffectiveBasePort = Server.Initialize(ServerOptions, Entry);
+
+        Entry->EffectiveListenPort = uint16_t(EffectiveBasePort);
+        if (EffectiveBasePort != ServerOptions.BasePort)
+        {
+            ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort);
+            ServerOptions.BasePort = EffectiveBasePort;
+        }
+
+        std::unique_ptr<std::thread> ShutdownThread;
+        std::unique_ptr<zen::NamedEvent> ShutdownEvent;
+
+        zen::ExtendableStringBuilder<64> ShutdownEventName;
+        ShutdownEventName << "Zen_" << ServerOptions.BasePort << "_Shutdown";
+        ShutdownEvent.reset(new zen::NamedEvent{ShutdownEventName});
+
+        // Monitor shutdown signals
+
+        ShutdownThread.reset(new std::thread{[&] {
+            ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}'", ShutdownEventName);
+            if (ShutdownEvent->Wait())
+            {
+                ZEN_INFO("shutdown signal received");
+                Server.RequestExit(0);
+            }
+            else
+            {
+                ZEN_INFO("shutdown signal wait() failed");
+            }
+        }});
+
+        // If we have a parent process, establish the mechanisms we need
+        // to be able to communicate readiness with the parent
+
+        Server.SetIsReadyFunc([&] {
+            IsReady = true;
+
+            m_LockFile.Update(MakeLockData(), Ec);
+
+            if (!ServerOptions.ChildId.empty())
+            {
+                zen::NamedEvent ParentEvent{ServerOptions.ChildId};
+                ParentEvent.Set();
+            }
+        });
+
+        Server.Run();
+        Server.Cleanup();
+
+        // Wake the monitor thread (it may still be waiting) so it can be joined
+        ShutdownEvent->Set();
+        ShutdownThread->join();
+    }
+    catch (std::exception& e)
+    {
+        SPDLOG_CRITICAL("Caught exception in main: {}", e.what());
+    }
+
+    ShutdownLogging();
+
+    return 0;
+}
+
+} // namespace zen
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_PLATFORM_WINDOWS
+
+// Adapts the console entry point so it can run as a Windows service: the
+// service framework's Run() override simply delegates to ZenEntryPoint::Run().
+class ZenWindowsService : public WindowsService
+{
+public:
+    ZenWindowsService(ZenServerOptions& ServerOptions) : m_EntryPoint(ServerOptions) {}
+
+    ZenWindowsService(const ZenWindowsService&) = delete;
+    ZenWindowsService& operator=(const ZenWindowsService&) = delete;
+
+    virtual int Run() override;
+
+private:
+    zen::ZenEntryPoint m_EntryPoint;
+};
+
+int
+ZenWindowsService::Run()
+{
+    return m_EntryPoint.Run();
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+// Entry point for "zenserver test ...": force-links the test registrations
+// from the library modules, sets up debug logging, and runs the doctest
+// runner with the remaining arguments.
+int
+test_main(int argc, char** argv)
+{
+    zen::zencore_forcelinktests();
+    zen::zenhttp_forcelinktests();
+    zen::zenstore_forcelinktests();
+    zen::z$_forcelink();
+    zen::z$service_forcelink();
+
+    zen::logging::InitializeLogging();
+    spdlog::set_level(spdlog::level::debug);
+
+    zen::MaximizeOpenFileCount();
+
+    return ZEN_RUN_TESTS(argc, argv);
+}
+#endif
+
+// Process entry point: dispatches to the test runner, parses CLI options,
+// initializes tracing, handles service (un)install on Windows, then runs the
+// server either as a Windows service or via the console entry point.
+int
+main(int argc, char* argv[])
+{
+    using namespace zen;
+
+#if ZEN_USE_MIMALLOC
+    // Touch mimalloc early to guarantee the allocator override is linked in
+    mi_version();
+#endif
+
+#if ZEN_WITH_TESTS
+    if (argc >= 2)
+    {
+        if (argv[1] == "test"sv)
+        {
+            return test_main(argc, argv);
+        }
+    }
+#endif
+
+    try
+    {
+        ZenServerOptions ServerOptions;
+        ParseCliOptions(argc, argv, ServerOptions);
+
+        // A missing data directory marks a first run (fresh manifest instead
+        // of a state wipe in InitializeState)
+        if (!std::filesystem::exists(ServerOptions.DataDir))
+        {
+            ServerOptions.IsFirstRun = true;
+            std::filesystem::create_directories(ServerOptions.DataDir);
+        }
+
+#if ZEN_WITH_TRACE
+        if (ServerOptions.TraceHost.size())
+        {
+            TraceInit(ServerOptions.TraceHost.c_str(), TraceType::Network);
+        }
+        else if (ServerOptions.TraceFile.size())
+        {
+            TraceInit(ServerOptions.TraceFile.c_str(), TraceType::File);
+        }
+        else
+        {
+            TraceInit(nullptr, TraceType::None);
+        }
+#endif // ZEN_WITH_TRACE
+
+#if ZEN_PLATFORM_WINDOWS
+        if (ServerOptions.InstallService)
+        {
+            WindowsService::Install();
+
+            std::exit(0);
+        }
+
+        if (ServerOptions.UninstallService)
+        {
+            WindowsService::Delete();
+
+            std::exit(0);
+        }
+
+        ZenWindowsService App(ServerOptions);
+        // NOTE(review): windows/service.h in this change declares SvcMain();
+        // confirm WindowsService also exposes a ServiceMain() entry point,
+        // otherwise this will not compile on Windows.
+        return App.ServiceMain();
+#else
+        if (ServerOptions.InstallService || ServerOptions.UninstallService)
+        {
+            throw std::runtime_error("Service mode is not supported on this platform");
+        }
+
+        ZenEntryPoint App(ServerOptions);
+        return App.Run();
+#endif // ZEN_PLATFORM_WINDOWS
+    }
+    catch (std::exception& Ex)
+    {
+        fprintf(stderr, "ERROR: Caught exception in main: '%s'", Ex.what());
+
+        return 1;
+    }
+}
diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc
new file mode 100644
index 000000000..6d31e2c6e
--- /dev/null
+++ b/src/zenserver/zenserver.rc
@@ -0,0 +1,105 @@
+// Microsoft Visual C++ generated resource script.
+//
+#include "resource.h"
+
+// zencore/config.h supplies the ZEN_CFG_VERSION* macros consumed by the
+// VERSIONINFO block at the bottom of this file.
+#include "zencore/config.h"
+
+#define APSTUDIO_READONLY_SYMBOLS
+/////////////////////////////////////////////////////////////////////////////
+//
+// Generated from the TEXTINCLUDE 2 resource.
+//
+#include "winres.h"
+
+/////////////////////////////////////////////////////////////////////////////
+#undef APSTUDIO_READONLY_SYMBOLS
+
+/////////////////////////////////////////////////////////////////////////////
+// English (United States) resources
+
+#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENU)
+LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US
+#pragma code_page(1252)
+
+/////////////////////////////////////////////////////////////////////////////
+//
+// Icon
+//
+
+// Icon with lowest ID value placed first to ensure application icon
+// remains consistent on all systems.
+IDI_ICON1 ICON "..\\UnrealEngine.ico"
+
+#endif // English (United States) resources
+/////////////////////////////////////////////////////////////////////////////
+
+
+/////////////////////////////////////////////////////////////////////////////
+// English (United Kingdom) resources
+
+#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENG)
+LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_UK
+#pragma code_page(1252)
+
+#ifdef APSTUDIO_INVOKED
+/////////////////////////////////////////////////////////////////////////////
+//
+// TEXTINCLUDE
+//
+// These blocks are only read by the Visual Studio resource editor
+// (APSTUDIO_INVOKED); they are invisible to the resource compiler proper.
+
+1 TEXTINCLUDE
+BEGIN
+ "resource.h\0"
+END
+
+2 TEXTINCLUDE
+BEGIN
+ "#include ""winres.h""\r\n"
+ "\0"
+END
+
+3 TEXTINCLUDE
+BEGIN
+ "\r\n"
+ "\0"
+END
+
+#endif // APSTUDIO_INVOKED
+
+#endif // English (United Kingdom) resources
+/////////////////////////////////////////////////////////////////////////////
+
+
+
+#ifndef APSTUDIO_INVOKED
+/////////////////////////////////////////////////////////////////////////////
+//
+// Generated from the TEXTINCLUDE 3 resource.
+//
+
+
+/////////////////////////////////////////////////////////////////////////////
+#endif // not APSTUDIO_INVOKED
+
+// Version resource shown in Explorer's file-properties dialog. The numeric
+// components come from the build configuration (zencore/config.h); the fourth
+// component is fixed at 0.
+VS_VERSION_INFO VERSIONINFO
+FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
+PRODUCTVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
+{
+ BLOCK "StringFileInfo"
+ {
+ // "040904b0" = language 0x0409 (U.S. English), code page 0x04B0 (1200, Unicode).
+ BLOCK "040904b0"
+ {
+ VALUE "CompanyName", "Epic Games Inc\0"
+ VALUE "FileDescription", "Local Storage Service for Unreal Engine\0"
+ VALUE "FileVersion", ZEN_CFG_VERSION "\0"
+ VALUE "LegalCopyright", "Copyright Epic Games Inc. All Rights Reserved\0"
+ VALUE "OriginalFilename", "zenserver.exe\0"
+ VALUE "ProductName", "Zen Storage Server\0"
+ VALUE "ProductVersion", ZEN_CFG_VERSION_BUILD_STRING_FULL "\0"
+ }
+ }
+ BLOCK "VarFileInfo"
+ {
+ // Must mirror the language/code-page pair of the StringFileInfo block above.
+ VALUE "Translation", 0x409, 1200
+ }
+}