aboutsummaryrefslogtreecommitdiff
path: root/src/zenserver/upstream
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src/zenserver/upstream
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src/zenserver/upstream')
-rw-r--r--src/zenserver/upstream/hordecompute.cpp1457
-rw-r--r--src/zenserver/upstream/jupiter.cpp965
-rw-r--r--src/zenserver/upstream/jupiter.h217
-rw-r--r--src/zenserver/upstream/upstream.h8
-rw-r--r--src/zenserver/upstream/upstreamapply.cpp459
-rw-r--r--src/zenserver/upstream/upstreamapply.h192
-rw-r--r--src/zenserver/upstream/upstreamcache.cpp2112
-rw-r--r--src/zenserver/upstream/upstreamcache.h252
-rw-r--r--src/zenserver/upstream/upstreamservice.cpp56
-rw-r--r--src/zenserver/upstream/upstreamservice.h27
-rw-r--r--src/zenserver/upstream/zen.cpp326
-rw-r--r--src/zenserver/upstream/zen.h125
12 files changed, 6196 insertions, 0 deletions
diff --git a/src/zenserver/upstream/hordecompute.cpp b/src/zenserver/upstream/hordecompute.cpp
new file mode 100644
index 000000000..64d9fff72
--- /dev/null
+++ b/src/zenserver/upstream/hordecompute.cpp
@@ -0,0 +1,1457 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamapply.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "jupiter.h"
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compactbinaryvalidation.h>
+# include <zencore/fmtutils.h>
+# include <zencore/session.h>
+# include <zencore/stream.h>
+# include <zencore/thread.h>
+# include <zencore/timer.h>
+# include <zencore/workthreadpool.h>
+
+# include <zenstore/cidstore.h>
+
+# include <auth/authmgr.h>
+# include <upstream/upstreamcache.h>
+
+# include "cache/structuredcachestore.h"
+# include "diag/logging.h"
+
+# include <fmt/format.h>
+
+# include <algorithm>
+# include <atomic>
+# include <set>
+# include <stack>
+
+namespace zen {
+
+using namespace std::literals;
+
+static const IoBuffer EmptyBuffer;
+static const IoHash EmptyBufferId = IoHash::HashBuffer(EmptyBuffer);
+
+namespace detail {
+
+ class HordeUpstreamApplyEndpoint final : public UpstreamApplyEndpoint
+ {
+ public:
+ HordeUpstreamApplyEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr)
+ : m_Log(logging::Get("upstream-apply"))
+ , m_CidStore(CidStore)
+ , m_AuthMgr(Mgr)
+ {
+ m_DisplayName = fmt::format("{} - '{}'+'{}'", ComputeOptions.Name, ComputeOptions.ServiceUrl, StorageOptions.ServiceUrl);
+ m_ChannelId = fmt::format("zen-{}", zen::GetSessionIdString());
+
+ {
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (ComputeAuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = ComputeAuthConfig.OAuthUrl,
+ .ClientId = ComputeAuthConfig.OAuthClientId,
+ .ClientSecret = ComputeAuthConfig.OAuthClientSecret});
+ }
+ else if (ComputeAuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(ComputeAuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(ComputeAuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_Client = new CloudCacheClient(ComputeOptions, std::move(TokenProvider));
+ }
+
+ {
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (StorageAuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = StorageAuthConfig.OAuthUrl,
+ .ClientId = StorageAuthConfig.OAuthClientId,
+ .ClientSecret = StorageAuthConfig.OAuthClientSecret});
+ }
+ else if (StorageAuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(StorageAuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(StorageAuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_StorageClient = new CloudCacheClient(StorageOptions, std::move(TokenProvider));
+ }
+ }
+
+ virtual ~HordeUpstreamApplyEndpoint() = default;
+
+ virtual UpstreamEndpointHealth Initialize() override { return CheckHealth(); }
+
+ virtual bool IsHealthy() const override { return m_HealthOk.load(); }
+
+ virtual UpstreamEndpointHealth CheckHealth() override
+ {
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result = Session.Authenticate();
+
+ m_HealthOk = Result.ErrorCode == 0;
+
+ return {.Reason = std::move(Result.Reason), .Ok = Result.Success};
+ }
+ catch (std::exception& Err)
+ {
+ return {.Reason = Err.what(), .Ok = false};
+ }
+ }
+
+ virtual std::string_view DisplayName() const override { return m_DisplayName; }
+
+ virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) override
+ {
+ PostUpstreamApplyResult ApplyResult{};
+ ApplyResult.Timepoints.merge(ApplyRecord.Timepoints);
+
+ try
+ {
+ UpstreamData UpstreamData;
+ if (!ProcessApplyKey(ApplyRecord, UpstreamData))
+ {
+ return {.Error{.ErrorCode = -1, .Reason = "Failed to generate task data"}};
+ }
+
+ {
+ ApplyResult.Timepoints["zen-storage-build-ref"] = DateTime::NowTicks();
+
+ bool AlreadyQueued;
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ AlreadyQueued = m_PendingTasks.contains(UpstreamData.TaskId);
+ }
+ if (AlreadyQueued)
+ {
+ // Pending task is already queued, return success
+ ApplyResult.Success = true;
+ return ApplyResult;
+ }
+ m_PendingTasks[UpstreamData.TaskId] = std::move(ApplyRecord);
+ }
+
+ CloudCacheSession ComputeSession(m_Client);
+ CloudCacheSession StorageSession(m_StorageClient);
+
+ {
+ CloudCacheResult Result = BatchPutBlobsIfMissing(StorageSession, UpstreamData.Blobs, UpstreamData.CasIds);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-blobs"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload blobs"};
+ return ApplyResult;
+ }
+ UpstreamData.Blobs.clear();
+ UpstreamData.CasIds.clear();
+ }
+
+ {
+ CloudCacheResult Result = BatchPutCompressedBlobsIfMissing(StorageSession, UpstreamData.Cids);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-compressed-blobs"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {
+ .ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload compressed blobs"};
+ return ApplyResult;
+ }
+ UpstreamData.Cids.clear();
+ }
+
+ {
+ CloudCacheResult Result = BatchPutObjectsIfMissing(StorageSession, UpstreamData.Objects);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-objects"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload objects"};
+ return ApplyResult;
+ }
+ }
+
+ {
+ PutRefResult RefResult = StorageSession.PutRef(StorageSession.Client().DefaultBlobStoreNamespace(),
+ "requests"sv,
+ UpstreamData.TaskId,
+ UpstreamData.Objects[UpstreamData.TaskId].GetBuffer().AsIoBuffer(),
+ ZenContentType::kCbObject);
+ Log().debug("Put ref {} Need={} Bytes={} Duration={}s Result={}",
+ UpstreamData.TaskId,
+ RefResult.Needs.size(),
+ RefResult.Bytes,
+ RefResult.ElapsedSeconds,
+ RefResult.Success);
+ ApplyResult.Bytes += RefResult.Bytes;
+ ApplyResult.ElapsedSeconds += RefResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-put-ref"] = DateTime::NowTicks();
+
+ if (RefResult.Needs.size() > 0)
+ {
+ Log().error("Failed to add task ref {} due to {} missing blobs", UpstreamData.TaskId, RefResult.Needs.size());
+ for (const auto& Hash : RefResult.Needs)
+ {
+ Log().debug("Task ref {} missing blob {}", UpstreamData.TaskId, Hash);
+ }
+
+ ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode,
+ .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason)
+ : "Failed to add task ref due to missing blob"};
+ return ApplyResult;
+ }
+
+ if (!RefResult.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode,
+ .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) : "Failed to add task ref"};
+ return ApplyResult;
+ }
+ UpstreamData.Objects.clear();
+ }
+
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("c"sv, m_ChannelId);
+ Writer.AddObjectAttachment("r"sv, UpstreamData.RequirementsId);
+ Writer.BeginArray("t"sv);
+ Writer.AddObjectAttachment(UpstreamData.TaskId);
+ Writer.EndArray();
+ CbObject TasksObject = Writer.Save();
+ IoBuffer TasksData = TasksObject.GetBuffer().AsIoBuffer();
+
+ CloudCacheResult Result = ComputeSession.PostComputeTasks(TasksData);
+ Log().debug("Post compute task {} Bytes={} Duration={}s Result={}",
+ TasksObject.GetHash(),
+ Result.Bytes,
+ Result.ElapsedSeconds,
+ Result.Success);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-horde-post-task"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_PendingTasks.erase(UpstreamData.TaskId);
+ }
+
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to post compute task"};
+ return ApplyResult;
+ }
+ }
+
+ Log().info("Task posted {}", UpstreamData.TaskId);
+ ApplyResult.Success = true;
+ return ApplyResult;
+ }
+ catch (std::exception& Err)
+ {
+ m_HealthOk = false;
+ return {.Error{.ErrorCode = -1, .Reason = Err.what()}};
+ }
+ }
+
+ [[nodiscard]] CloudCacheResult BatchPutBlobsIfMissing(CloudCacheSession& Session,
+ const std::map<IoHash, IoBuffer>& Blobs,
+ const std::set<IoHash>& CasIds)
+ {
+ if (Blobs.size() == 0 && CasIds.size() == 0)
+ {
+ return {.Success = true};
+ }
+
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ // Batch check for missing blobs
+ std::set<IoHash> Keys;
+ std::transform(Blobs.begin(), Blobs.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; });
+ Keys.insert(CasIds.begin(), CasIds.end());
+
+ CloudCacheExistsResult ExistsResult = Session.BlobExists(Session.Client().DefaultBlobStoreNamespace(), Keys);
+ Log().debug("Queried {} missing blobs Need={} Duration={}s Result={}",
+ Keys.size(),
+ ExistsResult.Needs.size(),
+ ExistsResult.ElapsedSeconds,
+ ExistsResult.Success);
+ ElapsedSeconds += ExistsResult.ElapsedSeconds;
+ if (!ExistsResult.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+ .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if blobs exist"};
+ }
+
+ for (const auto& Hash : ExistsResult.Needs)
+ {
+ IoBuffer DataBuffer;
+ if (Blobs.contains(Hash))
+ {
+ DataBuffer = Blobs.at(Hash);
+ }
+ else
+ {
+ DataBuffer = m_CidStore.FindChunkByCid(Hash);
+ if (!DataBuffer)
+ {
+ Log().warn("Put blob FAILED, input chunk '{}' missing", Hash);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put blobs"};
+ }
+ }
+
+ CloudCacheResult Result = Session.PutBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer);
+ Log().debug("Put blob {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success);
+ Bytes += Result.Bytes;
+ ElapsedSeconds += Result.ElapsedSeconds;
+ if (!Result.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put blobs"};
+ }
+ }
+
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ [[nodiscard]] CloudCacheResult BatchPutCompressedBlobsIfMissing(CloudCacheSession& Session, const std::set<IoHash>& Cids)
+ {
+ if (Cids.size() == 0)
+ {
+ return {.Success = true};
+ }
+
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ // Batch check for missing compressed blobs
+ CloudCacheExistsResult ExistsResult = Session.CompressedBlobExists(Session.Client().DefaultBlobStoreNamespace(), Cids);
+ Log().debug("Queried {} missing compressed blobs Need={} Duration={}s Result={}",
+ Cids.size(),
+ ExistsResult.Needs.size(),
+ ExistsResult.ElapsedSeconds,
+ ExistsResult.Success);
+ ElapsedSeconds += ExistsResult.ElapsedSeconds;
+ if (!ExistsResult.Success)
+ {
+ return {
+ .Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+ .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if compressed blobs exist"};
+ }
+
+ for (const auto& Hash : ExistsResult.Needs)
+ {
+ IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Hash);
+ if (!DataBuffer)
+ {
+ Log().warn("Put compressed blob FAILED, input CID chunk '{}' missing", Hash);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put compressed blobs"};
+ }
+
+ CloudCacheResult Result = Session.PutCompressedBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer);
+ Log().debug("Put compressed blob {} Bytes={} Duration={}s Result={}",
+ Hash,
+ Result.Bytes,
+ Result.ElapsedSeconds,
+ Result.Success);
+ Bytes += Result.Bytes;
+ ElapsedSeconds += Result.ElapsedSeconds;
+ if (!Result.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put compressed blobs"};
+ }
+ }
+
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ [[nodiscard]] CloudCacheResult BatchPutObjectsIfMissing(CloudCacheSession& Session, const std::map<IoHash, CbObject>& Objects)
+ {
+ if (Objects.size() == 0)
+ {
+ return {.Success = true};
+ }
+
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ // Batch check for missing objects
+ std::set<IoHash> Keys;
+ std::transform(Objects.begin(), Objects.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; });
+
+ CloudCacheExistsResult ExistsResult = Session.ObjectExists(Session.Client().DefaultBlobStoreNamespace(), Keys);
+ Log().debug("Queried {} missing objects Need={} Duration={}s Result={}",
+ Keys.size(),
+ ExistsResult.Needs.size(),
+ ExistsResult.ElapsedSeconds,
+ ExistsResult.Success);
+ ElapsedSeconds += ExistsResult.ElapsedSeconds;
+ if (!ExistsResult.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+ .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if objects exist"};
+ }
+
+ for (const auto& Hash : ExistsResult.Needs)
+ {
+ CloudCacheResult Result =
+ Session.PutObject(Session.Client().DefaultBlobStoreNamespace(), Hash, Objects.at(Hash).GetBuffer().AsIoBuffer());
+ Log().debug("Put object {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success);
+ Bytes += Result.Bytes;
+ ElapsedSeconds += Result.ElapsedSeconds;
+ if (!Result.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put objects"};
+ }
+ }
+
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ enum class ComputeTaskState : int32_t
+ {
+ Queued = 0,
+ Executing = 1,
+ Complete = 2,
+ };
+
+ enum class ComputeTaskOutcome : int32_t
+ {
+ Success = 0,
+ Failed = 1,
+ Cancelled = 2,
+ NoResult = 3,
+ Exipred = 4,
+ BlobNotFound = 5,
+ Exception = 6,
+ };
+
+ [[nodiscard]] static std::string_view ComputeTaskStateToString(const ComputeTaskState Outcome)
+ {
+ switch (Outcome)
+ {
+ case ComputeTaskState::Queued:
+ return "Queued"sv;
+ case ComputeTaskState::Executing:
+ return "Executing"sv;
+ case ComputeTaskState::Complete:
+ return "Complete"sv;
+ };
+ return "Unknown"sv;
+ }
+
+ [[nodiscard]] static std::string_view ComputeTaskOutcomeToString(const ComputeTaskOutcome Outcome)
+ {
+ switch (Outcome)
+ {
+ case ComputeTaskOutcome::Success:
+ return "Success"sv;
+ case ComputeTaskOutcome::Failed:
+ return "Failed"sv;
+ case ComputeTaskOutcome::Cancelled:
+ return "Cancelled"sv;
+ case ComputeTaskOutcome::NoResult:
+ return "NoResult"sv;
+ case ComputeTaskOutcome::Exipred:
+ return "Exipred"sv;
+ case ComputeTaskOutcome::BlobNotFound:
+ return "BlobNotFound"sv;
+ case ComputeTaskOutcome::Exception:
+ return "Exception"sv;
+ };
+ return "Unknown"sv;
+ }
+
+ virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) override
+ {
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ if (m_PendingTasks.empty())
+ {
+ if (m_CompletedTasks.empty())
+ {
+ // Nothing to do.
+ return {.Success = true};
+ }
+
+ UpstreamApplyCompleted CompletedTasks;
+ std::swap(CompletedTasks, m_CompletedTasks);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true};
+ }
+ }
+
+ try
+ {
+ CloudCacheSession ComputeSession(m_Client);
+
+ CloudCacheResult UpdatesResult = ComputeSession.GetComputeUpdates(m_ChannelId);
+ Log().debug("Get compute updates Bytes={} Duration={}s Result={}",
+ UpdatesResult.Bytes,
+ UpdatesResult.ElapsedSeconds,
+ UpdatesResult.Success);
+ Bytes += UpdatesResult.Bytes;
+ ElapsedSeconds += UpdatesResult.ElapsedSeconds;
+ if (!UpdatesResult.Success)
+ {
+ return {.Error{.ErrorCode = UpdatesResult.ErrorCode, .Reason = std::move(UpdatesResult.Reason)},
+ .Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds};
+ }
+
+ if (!UpdatesResult.Success)
+ {
+ return {.Error{.ErrorCode = -1, .Reason = "Failed get task updates"}, .Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds};
+ }
+
+ CbObject TaskStatus = LoadCompactBinaryObject(std::move(UpdatesResult.Response));
+
+ for (auto& It : TaskStatus["u"sv])
+ {
+ CbObjectView Status = It.AsObjectView();
+ IoHash TaskId = Status["h"sv].AsHash();
+ const ComputeTaskState State = (ComputeTaskState)Status["s"sv].AsInt32();
+ const ComputeTaskOutcome Outcome = (ComputeTaskOutcome)Status["o"sv].AsInt32();
+
+ Log().info("Task {} State={}", TaskId, ComputeTaskStateToString(State));
+
+ // Only completed tasks need to be processed
+ if (State != ComputeTaskState::Complete)
+ {
+ continue;
+ }
+
+ IoHash WorkerId{};
+ IoHash ActionId{};
+ UpstreamApplyType ApplyType{};
+
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ auto TaskIt = m_PendingTasks.find(TaskId);
+ if (TaskIt != m_PendingTasks.end())
+ {
+ WorkerId = TaskIt->second.WorkerDescriptor.GetHash();
+ ActionId = TaskIt->second.Action.GetHash();
+ ApplyType = TaskIt->second.Type;
+ m_PendingTasks.erase(TaskIt);
+ }
+ }
+
+ if (WorkerId == IoHash::Zero)
+ {
+ Log().warn("Task {} missing from pending tasks", TaskId);
+ continue;
+ }
+
+ std::map<std::string, uint64_t> Timepoints;
+ ProcessQueueTimings(Status["qs"sv].AsObjectView(), Timepoints);
+ ProcessExecuteTimings(Status["es"sv].AsObjectView(), Timepoints);
+
+ if (Outcome != ComputeTaskOutcome::Success)
+ {
+ const std::string_view Detail = Status["d"sv].AsString();
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_CompletedTasks[WorkerId][ActionId] = {
+ .Error{.ErrorCode = -1, .Reason = fmt::format("Task {} {}", ComputeTaskOutcomeToString(Outcome), Detail)},
+ .Timepoints = std::move(Timepoints)};
+ }
+ continue;
+ }
+
+ Timepoints["zen-complete-queue-added"] = DateTime::NowTicks();
+ ThreadPool.ScheduleWork([this,
+ ApplyType,
+ ResultHash = Status["r"sv].AsHash(),
+ Timepoints = std::move(Timepoints),
+ TaskId = std::move(TaskId),
+ WorkerId = std::move(WorkerId),
+ ActionId = std::move(ActionId)]() mutable {
+ Timepoints["zen-complete-queue-dispatched"] = DateTime::NowTicks();
+ GetUpstreamApplyResult Result = ProcessTaskStatus(ApplyType, ResultHash);
+ Timepoints["zen-complete-queue-complete"] = DateTime::NowTicks();
+ Result.Timepoints.merge(Timepoints);
+
+ Log().debug("Task Processed {} Files={} Attachments={} ExitCode={}",
+ TaskId,
+ Result.OutputFiles.size(),
+ Result.OutputPackage.GetAttachments().size(),
+ Result.Error.ErrorCode);
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_CompletedTasks[WorkerId][ActionId] = std::move(Result);
+ }
+ });
+ }
+
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ if (m_CompletedTasks.empty())
+ {
+ // Nothing to do.
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+ UpstreamApplyCompleted CompletedTasks;
+ std::swap(CompletedTasks, m_CompletedTasks);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_HealthOk = false;
+ return {
+ .Error{.ErrorCode = -1, .Reason = Err.what()},
+ .Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ };
+ }
+ }
+
+ virtual UpstreamApplyEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ CidStore& m_CidStore;
+ AuthMgr& m_AuthMgr;
+ std::string m_DisplayName;
+ RefPtr<CloudCacheClient> m_Client;
+ RefPtr<CloudCacheClient> m_StorageClient;
+ UpstreamApplyEndpointStats m_Stats;
+ std::atomic_bool m_HealthOk{false};
+ std::string m_ChannelId;
+
+ std::mutex m_TaskMutex;
+ std::unordered_map<IoHash, UpstreamApplyRecord> m_PendingTasks;
+ UpstreamApplyCompleted m_CompletedTasks;
+
+ struct UpstreamData
+ {
+ std::map<IoHash, IoBuffer> Blobs;
+ std::map<IoHash, CbObject> Objects;
+ std::set<IoHash> CasIds;
+ std::set<IoHash> Cids;
+ IoHash TaskId;
+ IoHash RequirementsId;
+ };
+
+ struct UpstreamDirectory
+ {
+ std::filesystem::path Path;
+ std::map<std::string, UpstreamDirectory> Directories;
+ std::set<std::string> Files;
+ };
+
+ static void ProcessQueueTimings(CbObjectView QueueStats, std::map<std::string, uint64_t>& Timepoints)
+ {
+ uint64_t Ticks = QueueStats["t"sv].AsDateTimeTicks();
+ if (Ticks == 0)
+ {
+ return;
+ }
+
+ // Scope is an array of miliseconds after start time
+ // TODO: cleanup
+ Timepoints["horde-queue-added"] = Ticks;
+ int Index = 0;
+ for (auto& Item : QueueStats["s"sv].AsArrayView())
+ {
+ Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond;
+ switch (Index)
+ {
+ case 0:
+ Timepoints["horde-queue-dispatched"] = Ticks;
+ break;
+ case 1:
+ Timepoints["horde-queue-complete"] = Ticks;
+ break;
+ }
+ Index++;
+ }
+ }
+
+ static void ProcessExecuteTimings(CbObjectView ExecutionStats, std::map<std::string, uint64_t>& Timepoints)
+ {
+ uint64_t Ticks = ExecutionStats["t"sv].AsDateTimeTicks();
+ if (Ticks == 0)
+ {
+ return;
+ }
+
+ // Scope is an array of miliseconds after start time
+ // TODO: cleanup
+ Timepoints["horde-execution-start"] = Ticks;
+ int Index = 0;
+ for (auto& Item : ExecutionStats["s"sv].AsArrayView())
+ {
+ Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond;
+ switch (Index)
+ {
+ case 0:
+ Timepoints["horde-execution-download-ref"] = Ticks;
+ break;
+ case 1:
+ Timepoints["horde-execution-download-input"] = Ticks;
+ break;
+ case 2:
+ Timepoints["horde-execution-execute"] = Ticks;
+ break;
+ case 3:
+ Timepoints["horde-execution-upload-log"] = Ticks;
+ break;
+ case 4:
+ Timepoints["horde-execution-upload-output"] = Ticks;
+ break;
+ case 5:
+ Timepoints["horde-execution-upload-ref"] = Ticks;
+ break;
+ }
+ Index++;
+ }
+ }
+
+ [[nodiscard]] GetUpstreamApplyResult ProcessTaskStatus(const UpstreamApplyType ApplyType, const IoHash& ResultHash)
+ {
+ try
+ {
+ CloudCacheSession Session(m_StorageClient);
+
+ GetUpstreamApplyResult ApplyResult{};
+
+ IoHash StdOutHash;
+ IoHash StdErrHash;
+ IoHash OutputHash;
+
+ std::map<IoHash, IoBuffer> BinaryData;
+
+ {
+ CloudCacheResult ObjectRefResult =
+ Session.GetRef(Session.Client().DefaultBlobStoreNamespace(), "responses"sv, ResultHash, ZenContentType::kCbObject);
+ Log().debug("Get ref {} Bytes={} Duration={}s Result={}",
+ ResultHash,
+ ObjectRefResult.Bytes,
+ ObjectRefResult.ElapsedSeconds,
+ ObjectRefResult.Success);
+ ApplyResult.Bytes += ObjectRefResult.Bytes;
+ ApplyResult.ElapsedSeconds += ObjectRefResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-get-ref"] = DateTime::NowTicks();
+
+ if (!ObjectRefResult.Success)
+ {
+ ApplyResult.Error.Reason = "Failed to get result object data";
+ return ApplyResult;
+ }
+
+ CbObject ResultObject = LoadCompactBinaryObject(ObjectRefResult.Response);
+ ApplyResult.Error.ErrorCode = ResultObject["e"sv].AsInt32();
+ StdOutHash = ResultObject["so"sv].AsBinaryAttachment();
+ StdErrHash = ResultObject["se"sv].AsBinaryAttachment();
+ OutputHash = ResultObject["o"sv].AsObjectAttachment();
+ }
+
+ {
+ std::set<IoHash> NeededData;
+ if (OutputHash != IoHash::Zero)
+ {
+ GetObjectReferencesResult ObjectReferenceResult =
+ Session.GetObjectReferences(Session.Client().DefaultBlobStoreNamespace(), OutputHash);
+ Log().debug("Get object references {} References={} Bytes={} Duration={}s Result={}",
+ ResultHash,
+ ObjectReferenceResult.References.size(),
+ ObjectReferenceResult.Bytes,
+ ObjectReferenceResult.ElapsedSeconds,
+ ObjectReferenceResult.Success);
+ ApplyResult.Bytes += ObjectReferenceResult.Bytes;
+ ApplyResult.ElapsedSeconds += ObjectReferenceResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-get-object-references"] = DateTime::NowTicks();
+
+ if (!ObjectReferenceResult.Success)
+ {
+ ApplyResult.Error.Reason = "Failed to get result object references";
+ return ApplyResult;
+ }
+
+ NeededData = std::move(ObjectReferenceResult.References);
+ }
+
+ NeededData.insert(OutputHash);
+ NeededData.insert(StdOutHash);
+ NeededData.insert(StdErrHash);
+
+ for (const auto& Hash : NeededData)
+ {
+ if (Hash == IoHash::Zero)
+ {
+ continue;
+ }
+ CloudCacheResult BlobResult = Session.GetBlob(Session.Client().DefaultBlobStoreNamespace(), Hash);
+ Log().debug("Get blob {} Bytes={} Duration={}s Result={}",
+ Hash,
+ BlobResult.Bytes,
+ BlobResult.ElapsedSeconds,
+ BlobResult.Success);
+ ApplyResult.Bytes += BlobResult.Bytes;
+ ApplyResult.ElapsedSeconds += BlobResult.ElapsedSeconds;
+ if (!BlobResult.Success)
+ {
+ ApplyResult.Error.Reason = "Failed to get blob";
+ return ApplyResult;
+ }
+ BinaryData[Hash] = std::move(BlobResult.Response);
+ }
+ ApplyResult.Timepoints["zen-storage-get-blobs"] = DateTime::NowTicks();
+ }
+
+ ApplyResult.StdOut = StdOutHash != IoHash::Zero
+ ? std::string((const char*)BinaryData[StdOutHash].GetData(), BinaryData[StdOutHash].GetSize())
+ : "";
+ ApplyResult.StdErr = StdErrHash != IoHash::Zero
+ ? std::string((const char*)BinaryData[StdErrHash].GetData(), BinaryData[StdErrHash].GetSize())
+ : "";
+
+ if (OutputHash == IoHash::Zero)
+ {
+ ApplyResult.Error.Reason = "Task completed with no output object";
+ return ApplyResult;
+ }
+
+ CbObject OutputObject = LoadCompactBinaryObject(BinaryData[OutputHash]);
+
+ switch (ApplyType)
+ {
+ case UpstreamApplyType::Simple:
+ {
+ ResolveMerkleTreeDirectory(""sv, OutputHash, BinaryData, ApplyResult.OutputFiles);
+ for (const auto& Pair : BinaryData)
+ {
+ ApplyResult.FileData[Pair.first] = std::move(BinaryData.at(Pair.first));
+ }
+
+ ApplyResult.Success = ApplyResult.Error.ErrorCode == 0;
+ return ApplyResult;
+ }
+ break;
+ case UpstreamApplyType::Asset:
+ {
+ if (ApplyResult.Error.ErrorCode != 0)
+ {
+ ApplyResult.Error.Reason = "Task completed with errors";
+ return ApplyResult;
+ }
+
+ // Get build.output
+ IoHash BuildOutputId;
+ IoBuffer BuildOutput;
+ for (auto& It : OutputObject["f"sv])
+ {
+ const CbObjectView FileObject = It.AsObjectView();
+ if (FileObject["n"sv].AsString() == "Build.output"sv)
+ {
+ BuildOutputId = FileObject["h"sv].AsBinaryAttachment();
+ BuildOutput = BinaryData[BuildOutputId];
+ break;
+ }
+ }
+
+ if (BuildOutput.GetSize() == 0)
+ {
+ ApplyResult.Error.Reason = "Build.output file not found in task results";
+ return ApplyResult;
+ }
+
+ // Get Output directory node
+ IoBuffer OutputDirectoryTree;
+ for (auto& It : OutputObject["d"sv])
+ {
+ const CbObjectView DirectoryObject = It.AsObjectView();
+ if (DirectoryObject["n"sv].AsString() == "Outputs"sv)
+ {
+ OutputDirectoryTree = BinaryData[DirectoryObject["h"sv].AsObjectAttachment()];
+ break;
+ }
+ }
+
+ if (OutputDirectoryTree.GetSize() == 0)
+ {
+ ApplyResult.Error.Reason = "Outputs directory not found in task results";
+ return ApplyResult;
+ }
+
+ // load build.output as CbObject
+
+ // Move Outputs from Horde to CbPackage
+
+ std::unordered_map<IoHash, IoHash> CidToCompressedId;
+ CbPackage OutputPackage;
+ CbObject OutputDirectoryTreeObject = LoadCompactBinaryObject(OutputDirectoryTree);
+
+ for (auto& It : OutputDirectoryTreeObject["f"sv])
+ {
+ CbObjectView FileObject = It.AsObjectView();
+ // Name is the uncompressed hash
+ IoHash DecompressedId = IoHash::FromHexString(FileObject["n"sv].AsString());
+ // Hash is the compressed data hash, and how it is stored in Horde
+ IoHash CompressedId = FileObject["h"sv].AsBinaryAttachment();
+
+ if (!BinaryData.contains(CompressedId))
+ {
+ Log().warn("Object attachment chunk not retrieved from Horde {}", CompressedId);
+ ApplyResult.Error.Reason = "Object attachment chunk not retrieved from Horde";
+ return ApplyResult;
+ }
+ CidToCompressedId[DecompressedId] = CompressedId;
+ }
+
+ // Iterate attachments, verify all chunks exist, and add to CbPackage
+ bool AnyErrors = false;
+ CbObject BuildOutputObject = LoadCompactBinaryObject(BuildOutput);
+ BuildOutputObject.IterateAttachments([&](CbFieldView Field) {
+ const IoHash DecompressedId = Field.AsHash();
+ if (!CidToCompressedId.contains(DecompressedId))
+ {
+ Log().warn("Attachment not found {}", DecompressedId);
+ AnyErrors = true;
+ return;
+ }
+ const IoHash& CompressedId = CidToCompressedId.at(DecompressedId);
+
+ if (!BinaryData.contains(CompressedId))
+ {
+ Log().warn("Missing output {} compressed {} uncompressed", CompressedId, DecompressedId);
+ AnyErrors = true;
+ return;
+ }
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer AttachmentBuffer =
+ CompressedBuffer::FromCompressed(SharedBuffer(BinaryData[CompressedId]), RawHash, RawSize);
+
+ if (!AttachmentBuffer || RawHash != DecompressedId)
+ {
+ Log().warn(
+ "Invalid output encountered (not valid CompressedBuffer format) {} compressed {} uncompressed",
+ CompressedId,
+ DecompressedId);
+ AnyErrors = true;
+ return;
+ }
+
+ ApplyResult.TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize();
+ ApplyResult.TotalRawAttachmentBytes += RawSize;
+
+ CbAttachment Attachment(AttachmentBuffer, DecompressedId);
+ OutputPackage.AddAttachment(Attachment);
+ });
+
+ if (AnyErrors)
+ {
+ ApplyResult.Error.Reason = "Failed to get result object attachment data";
+ return ApplyResult;
+ }
+
+ OutputPackage.SetObject(BuildOutputObject);
+ ApplyResult.OutputPackage = std::move(OutputPackage);
+
+ ApplyResult.Success = ApplyResult.Error.ErrorCode == 0;
+ return ApplyResult;
+ }
+ break;
+ }
+
+ ApplyResult.Error.Reason = "Unknown apply type";
+ return ApplyResult;
+ }
+ catch (std::exception& Err)
+ {
+ return {.Error{.ErrorCode = -1, .Reason = Err.what()}};
+ }
+ }
+
+ [[nodiscard]] bool ProcessApplyKey(const UpstreamApplyRecord& ApplyRecord, UpstreamData& Data)
+ {
+ std::string ExecutablePath;
+ std::string WorkingDirectory;
+ std::vector<std::string> Arguments;
+ std::map<std::string, std::string> Environment;
+ std::set<std::filesystem::path> InputFiles;
+ std::set<std::string> Outputs;
+ std::map<std::filesystem::path, IoHash> InputFileHashes;
+
+ ExecutablePath = ApplyRecord.WorkerDescriptor["path"sv].AsString();
+ if (ExecutablePath.empty())
+ {
+ Log().warn("process apply upstream FAILED, '{}', path missing from worker descriptor",
+ ApplyRecord.WorkerDescriptor.GetHash());
+ return false;
+ }
+
+ WorkingDirectory = ApplyRecord.WorkerDescriptor["workdir"sv].AsString();
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["executables"sv])
+ {
+ CbObjectView FileEntry = It.AsObjectView();
+ if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds))
+ {
+ return false;
+ }
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["files"sv])
+ {
+ CbObjectView FileEntry = It.AsObjectView();
+ if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds))
+ {
+ return false;
+ }
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["dirs"sv])
+ {
+ std::string_view Directory = It.AsString();
+ std::string DummyFile = fmt::format("{}/.zen_empty_file", Directory);
+ InputFiles.insert(DummyFile);
+ Data.Blobs[EmptyBufferId] = EmptyBuffer;
+ InputFileHashes[DummyFile] = EmptyBufferId;
+ }
+
+ if (!WorkingDirectory.empty())
+ {
+ std::string DummyFile = fmt::format("{}/.zen_empty_file", WorkingDirectory);
+ InputFiles.insert(DummyFile);
+ Data.Blobs[EmptyBufferId] = EmptyBuffer;
+ InputFileHashes[DummyFile] = EmptyBufferId;
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["environment"sv])
+ {
+ std::string_view Env = It.AsString();
+ auto Index = Env.find('=');
+ if (Index == std::string_view::npos)
+ {
+ Log().warn("process apply upstream FAILED, environment '{}' malformed", Env);
+ return false;
+ }
+
+ Environment[std::string(Env.substr(0, Index))] = Env.substr(Index + 1);
+ }
+
+ switch (ApplyRecord.Type)
+ {
+ case UpstreamApplyType::Simple:
+ {
+ for (auto& It : ApplyRecord.WorkerDescriptor["arguments"sv])
+ {
+ Arguments.push_back(std::string(It.AsString()));
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["outputs"sv])
+ {
+ Outputs.insert(std::string(It.AsString()));
+ }
+ }
+ break;
+ case UpstreamApplyType::Asset:
+ {
+ static const std::filesystem::path BuildActionPath = "Build.action"sv;
+ static const std::filesystem::path InputPath = "Inputs"sv;
+ const IoHash ActionId = ApplyRecord.Action.GetHash();
+
+ Arguments.push_back("-Build=build.action");
+ Outputs.insert("Build.output");
+ Outputs.insert("Outputs");
+
+ InputFiles.insert(BuildActionPath);
+ InputFileHashes[BuildActionPath] = ActionId;
+ Data.Blobs[ActionId] = IoBufferBuilder::MakeCloneFromMemory(ApplyRecord.Action.GetBuffer().GetData(),
+ ApplyRecord.Action.GetBuffer().GetSize());
+
+ bool AnyErrors = false;
+ ApplyRecord.Action.IterateAttachments([&](CbFieldView Field) {
+ const IoHash Cid = Field.AsHash();
+ const std::filesystem::path FilePath = {InputPath / Cid.ToHexString()};
+
+ if (!m_CidStore.ContainsChunk(Cid))
+ {
+ Log().warn("process apply upstream FAILED, input CID chunk '{}' missing", Cid);
+ AnyErrors = true;
+ return;
+ }
+
+ if (InputFiles.contains(FilePath))
+ {
+ return;
+ }
+
+ InputFiles.insert(FilePath);
+ InputFileHashes[FilePath] = Cid;
+ Data.Cids.insert(Cid);
+ });
+
+ if (AnyErrors)
+ {
+ return false;
+ }
+ }
+ break;
+ }
+
+ const UpstreamDirectory RootDirectory = BuildDirectoryTree(InputFiles);
+
+ CbObject Sandbox = BuildMerkleTreeDirectory(RootDirectory, InputFileHashes, Data.Cids, Data.Objects);
+ const IoHash SandboxHash = Sandbox.GetHash();
+ Data.Objects[SandboxHash] = std::move(Sandbox);
+
+ {
+ std::string_view HostPlatform = ApplyRecord.WorkerDescriptor["host"sv].AsString();
+ if (HostPlatform.empty())
+ {
+ Log().warn("process apply upstream FAILED, 'host' platform not provided");
+ return false;
+ }
+
+ int32_t LogicalCores = ApplyRecord.WorkerDescriptor["cores"sv].AsInt32();
+ int64_t Memory = ApplyRecord.WorkerDescriptor["memory"sv].AsInt64();
+ bool Exclusive = ApplyRecord.WorkerDescriptor["exclusive"sv].AsBool();
+
+ std::string Condition = fmt::format("Platform == '{}'", HostPlatform);
+ if (HostPlatform == "Win64")
+ {
+ // TODO
+ // Condition += " && Pool == 'Win-RemoteExec'";
+ }
+
+ std::map<std::string_view, int64_t> Resources;
+ if (LogicalCores > 0)
+ {
+ Resources["LogicalCores"sv] = LogicalCores;
+ }
+ if (Memory > 0)
+ {
+ Resources["RAM"sv] = std::max(Memory / 1024LL / 1024LL / 1024LL, 1LL);
+ }
+
+ CbObject Requirements = BuildRequirements(Condition, Resources, Exclusive);
+ const IoHash RequirementsId = Requirements.GetHash();
+ Data.Objects[RequirementsId] = std::move(Requirements);
+ Data.RequirementsId = RequirementsId;
+ }
+
+ CbObject Task = BuildTask(ExecutablePath, Arguments, Environment, WorkingDirectory, SandboxHash, Data.RequirementsId, Outputs);
+
+ const IoHash TaskId = Task.GetHash();
+ Data.Objects[TaskId] = std::move(Task);
+ Data.TaskId = TaskId;
+
+ return true;
+ }
+
+ [[nodiscard]] bool ProcessFileEntry(const CbObjectView& FileEntry,
+ std::set<std::filesystem::path>& InputFiles,
+ std::map<std::filesystem::path, IoHash>& InputFileHashes,
+ std::set<IoHash>& CasIds)
+ {
+ const std::filesystem::path FilePath = FileEntry["name"sv].AsString();
+ const IoHash ChunkId = FileEntry["hash"sv].AsHash();
+ const uint64_t Size = FileEntry["size"sv].AsUInt64();
+
+ if (!m_CidStore.ContainsChunk(ChunkId))
+ {
+ Log().warn("process apply upstream FAILED, worker CAS chunk '{}' missing", ChunkId);
+ return false;
+ }
+
+ if (InputFiles.contains(FilePath))
+ {
+ Log().warn("process apply upstream FAILED, worker CAS chunk '{}' size: {} duplicate filename {}", ChunkId, Size, FilePath);
+ return false;
+ }
+
+ InputFiles.insert(FilePath);
+ InputFileHashes[FilePath] = ChunkId;
+ CasIds.insert(ChunkId);
+ return true;
+ }
+
+ [[nodiscard]] UpstreamDirectory BuildDirectoryTree(const std::set<std::filesystem::path>& InputFiles)
+ {
+ static const std::filesystem::path RootPath;
+ std::map<std::filesystem::path, UpstreamDirectory*> AllDirectories;
+ UpstreamDirectory RootDirectory = {.Path = RootPath};
+
+ AllDirectories[RootPath] = &RootDirectory;
+
+ // Build tree from flat list
+ for (const auto& Path : InputFiles)
+ {
+ if (Path.has_parent_path())
+ {
+ if (!AllDirectories.contains(Path.parent_path()))
+ {
+ std::stack<std::string> PathSplit;
+ {
+ std::filesystem::path ParentPath = Path.parent_path();
+ PathSplit.push(ParentPath.filename().string());
+ while (ParentPath.has_parent_path())
+ {
+ ParentPath = ParentPath.parent_path();
+ PathSplit.push(ParentPath.filename().string());
+ }
+ }
+ UpstreamDirectory* ParentPtr = &RootDirectory;
+ while (!PathSplit.empty())
+ {
+ if (!ParentPtr->Directories.contains(PathSplit.top()))
+ {
+ std::filesystem::path NewParentPath = {ParentPtr->Path / PathSplit.top()};
+ ParentPtr->Directories[PathSplit.top()] = {.Path = NewParentPath};
+ AllDirectories[NewParentPath] = &ParentPtr->Directories[PathSplit.top()];
+ }
+ ParentPtr = &ParentPtr->Directories[PathSplit.top()];
+ PathSplit.pop();
+ }
+ }
+
+ AllDirectories[Path.parent_path()]->Files.insert(Path.filename().string());
+ }
+ else
+ {
+ RootDirectory.Files.insert(Path.filename().string());
+ }
+ }
+
+ return RootDirectory;
+ }
+
+ [[nodiscard]] CbObject BuildMerkleTreeDirectory(const UpstreamDirectory& RootDirectory,
+ const std::map<std::filesystem::path, IoHash>& InputFileHashes,
+ const std::set<IoHash>& Cids,
+ std::map<IoHash, CbObject>& Objects)
+ {
+ CbObjectWriter DirectoryTreeWriter;
+
+ if (!RootDirectory.Files.empty())
+ {
+ DirectoryTreeWriter.BeginArray("f"sv);
+ for (const auto& File : RootDirectory.Files)
+ {
+ const std::filesystem::path FilePath = {RootDirectory.Path / File};
+ const IoHash& FileHash = InputFileHashes.at(FilePath);
+ const bool Compressed = Cids.contains(FileHash);
+ DirectoryTreeWriter.BeginObject();
+ DirectoryTreeWriter.AddString("n"sv, File);
+ DirectoryTreeWriter.AddBinaryAttachment("h"sv, FileHash);
+ DirectoryTreeWriter.AddBool("c"sv, Compressed);
+ DirectoryTreeWriter.EndObject();
+ }
+ DirectoryTreeWriter.EndArray();
+ }
+
+ if (!RootDirectory.Directories.empty())
+ {
+ DirectoryTreeWriter.BeginArray("d"sv);
+ for (const auto& Item : RootDirectory.Directories)
+ {
+ CbObject Directory = BuildMerkleTreeDirectory(Item.second, InputFileHashes, Cids, Objects);
+ const IoHash DirectoryHash = Directory.GetHash();
+ Objects[DirectoryHash] = std::move(Directory);
+
+ DirectoryTreeWriter.BeginObject();
+ DirectoryTreeWriter.AddString("n"sv, Item.first);
+ DirectoryTreeWriter.AddObjectAttachment("h"sv, DirectoryHash);
+ DirectoryTreeWriter.EndObject();
+ }
+ DirectoryTreeWriter.EndArray();
+ }
+
+ return DirectoryTreeWriter.Save();
+ }
+
+ void ResolveMerkleTreeDirectory(const std::filesystem::path& ParentDirectory,
+ const IoHash& DirectoryHash,
+ const std::map<IoHash, IoBuffer>& Objects,
+ std::map<std::filesystem::path, IoHash>& OutputFiles)
+ {
+ CbObject Directory = LoadCompactBinaryObject(Objects.at(DirectoryHash));
+
+ for (auto& It : Directory["f"sv])
+ {
+ const CbObjectView FileObject = It.AsObjectView();
+ const std::filesystem::path Path = ParentDirectory / FileObject["n"sv].AsString();
+
+ OutputFiles[Path] = FileObject["h"sv].AsBinaryAttachment();
+ }
+
+ for (auto& It : Directory["d"sv])
+ {
+ const CbObjectView DirectoryObject = It.AsObjectView();
+
+ ResolveMerkleTreeDirectory(ParentDirectory / DirectoryObject["n"sv].AsString(),
+ DirectoryObject["h"sv].AsObjectAttachment(),
+ Objects,
+ OutputFiles);
+ }
+ }
+
+ [[nodiscard]] CbObject BuildRequirements(const std::string_view Condition,
+ const std::map<std::string_view, int64_t>& Resources,
+ const bool Exclusive)
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("c", Condition);
+ if (!Resources.empty())
+ {
+ Writer.BeginArray("r");
+ for (const auto& Resource : Resources)
+ {
+ Writer.BeginArray();
+ Writer.AddString(Resource.first);
+ Writer.AddInteger(Resource.second);
+ Writer.EndArray();
+ }
+ Writer.EndArray();
+ }
+ Writer.AddBool("e", Exclusive);
+ return Writer.Save();
+ }
+
+ [[nodiscard]] CbObject BuildTask(const std::string_view Executable,
+ const std::vector<std::string>& Arguments,
+ const std::map<std::string, std::string>& Environment,
+ const std::string_view WorkingDirectory,
+ const IoHash& SandboxHash,
+ const IoHash& RequirementsId,
+ const std::set<std::string>& Outputs)
+ {
+ CbObjectWriter TaskWriter;
+ TaskWriter.AddString("e"sv, Executable);
+
+ if (!Arguments.empty())
+ {
+ TaskWriter.BeginArray("a"sv);
+ for (const auto& Argument : Arguments)
+ {
+ TaskWriter.AddString(Argument);
+ }
+ TaskWriter.EndArray();
+ }
+
+ if (!Environment.empty())
+ {
+ TaskWriter.BeginArray("v"sv);
+ for (const auto& Env : Environment)
+ {
+ TaskWriter.BeginArray();
+ TaskWriter.AddString(Env.first);
+ TaskWriter.AddString(Env.second);
+ TaskWriter.EndArray();
+ }
+ TaskWriter.EndArray();
+ }
+
+ if (!WorkingDirectory.empty())
+ {
+ TaskWriter.AddString("w"sv, WorkingDirectory);
+ }
+
+ TaskWriter.AddObjectAttachment("s"sv, SandboxHash);
+ TaskWriter.AddObjectAttachment("r"sv, RequirementsId);
+
+ // Outputs
+ if (!Outputs.empty())
+ {
+ TaskWriter.BeginArray("o"sv);
+ for (const auto& Output : Outputs)
+ {
+ TaskWriter.AddString(Output);
+ }
+ TaskWriter.EndArray();
+ }
+
+ return TaskWriter.Save();
+ }
+ };
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+std::unique_ptr<UpstreamApplyEndpoint>
+UpstreamApplyEndpoint::CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr)
+{
+ return std::make_unique<detail::HordeUpstreamApplyEndpoint>(ComputeOptions,
+ ComputeAuthConfig,
+ StorageOptions,
+ StorageAuthConfig,
+ CidStore,
+ Mgr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/jupiter.cpp b/src/zenserver/upstream/jupiter.cpp
new file mode 100644
index 000000000..dbb185bec
--- /dev/null
+++ b/src/zenserver/upstream/jupiter.cpp
@@ -0,0 +1,965 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "jupiter.h"
+
+#include "diag/formatters.h"
+#include "diag/logging.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "Crypt32.lib")
+# pragma comment(lib, "Wldap32.lib")
+#endif
+
+#include <json11.hpp>
+
+using namespace std::literals;
+
+namespace zen {
+
+namespace detail {
+ struct CloudCacheSessionState
+ {
+ CloudCacheSessionState(CloudCacheClient& Client) : m_Client(Client) {}
+
+ const CloudCacheAccessToken& GetAccessToken(bool RefreshToken)
+ {
+ if (RefreshToken)
+ {
+ m_AccessToken = m_Client.AcquireAccessToken();
+ }
+
+ return m_AccessToken;
+ }
+
+ cpr::Session& GetSession() { return m_Session; }
+
+ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout)
+ {
+ m_Session.SetBody({});
+ m_Session.SetHeader({});
+ m_Session.SetConnectTimeout(ConnectTimeout);
+ m_Session.SetTimeout(Timeout);
+ }
+
+ private:
+ friend class zen::CloudCacheClient;
+
+ CloudCacheClient& m_Client;
+ CloudCacheAccessToken m_AccessToken;
+ cpr::Session m_Session;
+ };
+
+} // namespace detail
+
+CloudCacheSession::CloudCacheSession(CloudCacheClient* CacheClient) : m_Log(CacheClient->Logger()), m_CacheClient(CacheClient)
+{
+ m_SessionState = m_CacheClient->AllocSessionState();
+}
+
+CloudCacheSession::~CloudCacheSession()
+{
+ m_CacheClient->FreeSessionState(m_SessionState);
+}
+
+CloudCacheResult
+CloudCacheSession::Authenticate()
+{
+ const bool RefreshToken = true;
+ const CloudCacheAccessToken& AccessToken = GetAccessToken(RefreshToken);
+
+ return {.Success = AccessToken.IsValid()};
+}
+
+CloudCacheResult
+CloudCacheSession::GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType)
+{
+ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", ContentType}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetBlob(std::string_view Namespace, const IoHash& Key)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/octet-stream"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer =
+ Success && Response.text.size() > 0 ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetCompressedBlob(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-comp"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash)
+{
+ ZEN_TRACE_CPU("HordeClient::GetInlineBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-jupiter-inline"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+ if (auto It = Response.header.find("X-Jupiter-InlinePayloadHash"); It != Response.header.end())
+ {
+ const std::string& PayloadHashHeader = It->second;
+ if (PayloadHashHeader.length() == IoHash::StringLength)
+ {
+ OutPayloadHash = IoHash::FromHexString(PayloadHashHeader);
+ }
+ }
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetObject(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetObject");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+PutRefResult
+CloudCacheSession::PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType)
+{
+ ZEN_TRACE_CPU("HordeClient::PutRef");
+
+ IoHash Hash = IoHash::HashBuffer(Ref.Data(), Ref.Size());
+
+ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(
+ cpr::Header{{"Authorization", AccessToken.Value}, {"X-Jupiter-IoHash", Hash.ToHexString()}, {"Content-Type", ContentType}});
+ Session.SetBody(cpr::Body{(const char*)Ref.Data(), Ref.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ PutRefResult Result;
+ Result.ErrorCode = static_cast<int32_t>(Response.error.code);
+ Result.Reason = std::move(Response.error.message);
+ return Result;
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ PutRefResult Result;
+ Result.ErrorCode = 401;
+ Result.Reason = "Invalid access token"sv;
+ return Result;
+ }
+
+ PutRefResult Result;
+ Result.Success = (Response.status_code == 200 || Response.status_code == 201);
+ Result.Bytes = Response.uploaded_bytes;
+ Result.ElapsedSeconds = Response.elapsed;
+
+ if (Result.Success)
+ {
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ if (JsonError.empty())
+ {
+ json11::Json::array Needs = Json["needs"].array_items();
+ for (const auto& Need : Needs)
+ {
+ Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value()));
+ }
+ }
+ }
+
+ return Result;
+}
+
+FinalizeRefResult
+CloudCacheSession::FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHash)
+{
+ ZEN_TRACE_CPU("HordeClient::FinalizeRef");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString() << "/finalize/"
+ << RefHash.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value},
+ {"X-Jupiter-IoHash", RefHash.ToHexString()},
+ {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ FinalizeRefResult Result;
+ Result.ErrorCode = static_cast<int32_t>(Response.error.code);
+ Result.Reason = std::move(Response.error.message);
+ return Result;
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ FinalizeRefResult Result;
+ Result.ErrorCode = 401;
+ Result.Reason = "Invalid access token"sv;
+ return Result;
+ }
+
+ FinalizeRefResult Result;
+ Result.Success = (Response.status_code == 200 || Response.status_code == 201);
+ Result.Bytes = Response.uploaded_bytes;
+ Result.ElapsedSeconds = Response.elapsed;
+
+ if (Result.Success)
+ {
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ if (JsonError.empty())
+ {
+ json11::Json::array Needs = Json["needs"].array_items();
+ for (const auto& Need : Needs)
+ {
+ Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value()));
+ }
+ }
+ }
+
+ return Result;
+}
+
+CloudCacheResult
+CloudCacheSession::PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob)
+{
+ ZEN_TRACE_CPU("HordeClient::PutBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/octet-stream"}});
+ Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob)
+{
+ ZEN_TRACE_CPU("HordeClient::PutCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}});
+ Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Payload)
+{
+ ZEN_TRACE_CPU("HordeClient::PutCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}});
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object)
+{
+ ZEN_TRACE_CPU("HordeClient::PutObject");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{(const char*)Object.Data(), Object.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::RefExists");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Head();
+ ZEN_DEBUG("HEAD {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+GetObjectReferencesResult
+CloudCacheSession::GetObjectReferences(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetObjectReferences");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString() << "/references";
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}};
+ }
+
+ GetObjectReferencesResult Result{
+ CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}};
+
+ if (Result.Success)
+ {
+ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ const CbObject ReferencesResponse = LoadCompactBinaryObject(Buffer);
+ for (auto& Item : ReferencesResponse["references"sv])
+ {
+ Result.References.insert(Item.AsHash());
+ }
+ }
+
+ return Result;
+}
+
+CloudCacheResult
+CloudCacheSession::BlobExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "blobs"sv, Key);
+}
+
+CloudCacheResult
+CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "compressed-blobs"sv, Key);
+}
+
+CloudCacheResult
+CloudCacheSession::ObjectExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "objects"sv, Key);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "blobs"sv, Keys);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "compressed-blobs"sv, Keys);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "objects"sv, Keys);
+}
+
+CloudCacheResult
+CloudCacheSession::PostComputeTasks(IoBuffer TasksData)
+{
+ ZEN_TRACE_CPU("HordeClient::PostComputeTasks");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{(const char*)TasksData.Data(), TasksData.Size()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+CloudCacheResult
+CloudCacheSession::GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds)
+{
+ ZEN_TRACE_CPU("HordeClient::GetComputeUpdates");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster() << "/updates/" << ChannelId
+ << "?wait=" << WaitSeconds;
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+std::vector<IoHash>
+CloudCacheSession::Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl();
+ Uri << "/api/v1/s/" << Namespace;
+
+ ZEN_UNUSED(BucketId, ChunkHashes);
+
+ return {};
+}
+
+cpr::Session&
+CloudCacheSession::GetSession()
+{
+ return m_SessionState->GetSession();
+}
+
+CloudCacheAccessToken
+CloudCacheSession::GetAccessToken(bool RefreshToken)
+{
+ return m_SessionState->GetAccessToken(RefreshToken);
+}
+
+bool
+CloudCacheSession::VerifyAccessToken(long StatusCode)
+{
+ return StatusCode != 401;
+}
+
+CloudCacheResult
+CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::CacheTypeExists");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Head();
+ ZEN_DEBUG("HEAD {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+CloudCacheExistsResult
+CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys)
+{
+ ZEN_TRACE_CPU("HordeClient::CacheTypeExists");
+
+ ExtendableStringBuilder<256> Body;
+ Body << "[";
+ for (const auto& Key : Keys)
+ {
+ Body << (Body.Size() != 1 ? ",\"" : "\"") << Key.ToHexString() << "\"";
+ }
+ Body << "]";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/exist";
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(
+ cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}, {"Content-Type", "application/json"}});
+ Session.SetOption(cpr::Body(Body.ToString()));
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}};
+ }
+
+ CloudCacheExistsResult Result{
+ CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}};
+
+ if (Result.Success)
+ {
+ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ const CbObject ExistsResponse = LoadCompactBinaryObject(Buffer);
+ for (auto& Item : ExistsResponse["needs"sv])
+ {
+ Result.Needs.insert(Item.AsHash());
+ }
+ }
+
+ return Result;
+}
+
+/**
+ * An access token provider that holds a token that will never change.
+ */
+class StaticTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ StaticTokenProvider(CloudCacheAccessToken Token) : m_Token(std::move(Token)) {}
+
+ virtual ~StaticTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Token; }
+
+private:
+ CloudCacheAccessToken m_Token;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromStaticToken(CloudCacheAccessToken Token)
+{
+ return std::make_unique<StaticTokenProvider>(std::move(Token));
+}
+
+class OAuthClientCredentialsTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ OAuthClientCredentialsTokenProvider(const CloudCacheTokenProvider::OAuthClientCredentialsParams& Params)
+ {
+ m_Url = std::string(Params.Url);
+ m_ClientId = std::string(Params.ClientId);
+ m_ClientSecret = std::string(Params.ClientSecret);
+ }
+
+ virtual ~OAuthClientCredentialsTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override
+ {
+ using namespace std::chrono;
+
+ std::string Body =
+ fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}", m_ClientId, m_ClientSecret);
+
+ cpr::Response Response =
+ cpr::Post(cpr::Url{m_Url}, cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}, cpr::Body{std::move(Body)});
+
+ if (Response.error || Response.status_code != 200)
+ {
+ return {};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {};
+ }
+
+ std::string Token = Json["access_token"].string_value();
+ int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value());
+ CloudCacheAccessToken::TimePoint ExpireTime = CloudCacheAccessToken::Clock::now() + seconds(ExpiresInSeconds);
+
+ return {.Value = fmt::format("Bearer {}", Token), .ExpireTime = ExpireTime};
+ }
+
+private:
+ std::string m_Url;
+ std::string m_ClientId;
+ std::string m_ClientSecret;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params)
+{
+ return std::make_unique<OAuthClientCredentialsTokenProvider>(Params);
+}
+
+class CallbackTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {}
+
+ virtual ~CallbackTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); }
+
+private:
+ std::function<CloudCacheAccessToken()> m_Callback;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback)
+{
+ return std::make_unique<CallbackTokenProvider>(std::move(Callback));
+}
+
+CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider)
+: m_Log(zen::logging::Get("jupiter"))
+, m_ServiceUrl(Options.ServiceUrl)
+, m_DefaultDdcNamespace(Options.DdcNamespace)
+, m_DefaultBlobStoreNamespace(Options.BlobStoreNamespace)
+, m_ComputeCluster(Options.ComputeCluster)
+, m_ConnectTimeout(Options.ConnectTimeout)
+, m_Timeout(Options.Timeout)
+, m_TokenProvider(std::move(TokenProvider))
+{
+ ZEN_ASSERT(m_TokenProvider.get() != nullptr);
+}
+
+CloudCacheClient::~CloudCacheClient()
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+
+ for (auto State : m_SessionStateCache)
+ {
+ delete State;
+ }
+}
+
+CloudCacheAccessToken
+CloudCacheClient::AcquireAccessToken()
+{
+ ZEN_TRACE_CPU("HordeClient::AcquireAccessToken");
+
+ return m_TokenProvider->AcquireAccessToken();
+}
+
+detail::CloudCacheSessionState*
+CloudCacheClient::AllocSessionState()
+{
+ detail::CloudCacheSessionState* State = nullptr;
+
+ bool IsTokenValid = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+
+ if (m_SessionStateCache.empty() == false)
+ {
+ State = m_SessionStateCache.front();
+ IsTokenValid = State->m_AccessToken.IsValid();
+
+ m_SessionStateCache.pop_front();
+ }
+ }
+
+ if (State == nullptr)
+ {
+ State = new detail::CloudCacheSessionState(*this);
+ }
+
+ State->Reset(m_ConnectTimeout, m_Timeout);
+
+ if (IsTokenValid == false)
+ {
+ State->m_AccessToken = m_TokenProvider->AcquireAccessToken();
+ }
+
+ return State;
+}
+
+void
+CloudCacheClient::FreeSessionState(detail::CloudCacheSessionState* State)
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+ m_SessionStateCache.push_front(State);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/jupiter.h b/src/zenserver/upstream/jupiter.h
new file mode 100644
index 000000000..99e5c530f
--- /dev/null
+++ b/src/zenserver/upstream/jupiter.h
@@ -0,0 +1,217 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/refcount.h>
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+
+#include <atomic>
+#include <chrono>
+#include <list>
+#include <memory>
+#include <set>
+#include <vector>
+
+struct ZenCacheValue;
+
+namespace cpr {
+class Session;
+}
+
+namespace zen {
+namespace detail {
+ struct CloudCacheSessionState;
+}
+
+class CbObjectView;
+class CloudCacheClient;
+class IoBuffer;
+struct IoHash;
+
+/**
+ * Cached access token, for use with `Authorization:` header
+ */
+struct CloudCacheAccessToken
+{
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+
+ static constexpr int64_t ExpireMarginInSeconds = 30;
+
+ std::string Value;
+ TimePoint ExpireTime;
+
+ bool IsValid() const
+ {
+ return Value.empty() == false &&
+ ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count();
+ }
+};
+
+struct CloudCacheResult
+{
+ IoBuffer Response;
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ int32_t ErrorCode{};
+ std::string Reason;
+ bool Success = false;
+};
+
+struct PutRefResult : CloudCacheResult
+{
+ std::vector<IoHash> Needs;
+};
+
+struct FinalizeRefResult : CloudCacheResult
+{
+ std::vector<IoHash> Needs;
+};
+
+struct CloudCacheExistsResult : CloudCacheResult
+{
+ std::set<IoHash> Needs;
+};
+
+struct GetObjectReferencesResult : CloudCacheResult
+{
+ std::set<IoHash> References;
+};
+
+/**
+ * Context for performing Jupiter operations
+ *
+ * Maintains an HTTP connection so that subsequent operations don't need to go
+ * through the whole connection setup process
+ *
+ */
+class CloudCacheSession
+{
+public:
+ CloudCacheSession(CloudCacheClient* CacheClient);
+ ~CloudCacheSession();
+
+ CloudCacheResult Authenticate();
+ CloudCacheResult GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType);
+ CloudCacheResult GetBlob(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetCompressedBlob(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetObject(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash);
+
+ PutRefResult PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType);
+ CloudCacheResult PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob);
+ CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob);
+ CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Blob);
+ CloudCacheResult PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object);
+
+ FinalizeRefResult FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHah);
+
+ CloudCacheResult RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key);
+
+ GetObjectReferencesResult GetObjectReferences(std::string_view Namespace, const IoHash& Key);
+
+ CloudCacheResult BlobExists(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult CompressedBlobExists(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult ObjectExists(std::string_view Namespace, const IoHash& Key);
+
+ CloudCacheExistsResult BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+ CloudCacheExistsResult CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+ CloudCacheExistsResult ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+
+ CloudCacheResult PostComputeTasks(IoBuffer TasksData);
+ CloudCacheResult GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds = 0);
+
+ std::vector<IoHash> Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes);
+
+ CloudCacheClient& Client() { return *m_CacheClient; };
+
+private:
+ inline spdlog::logger& Log() { return m_Log; }
+ cpr::Session& GetSession();
+ CloudCacheAccessToken GetAccessToken(bool RefreshToken = false);
+ bool VerifyAccessToken(long StatusCode);
+
+ CloudCacheResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key);
+
+ CloudCacheExistsResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys);
+
+ spdlog::logger& m_Log;
+ RefPtr<CloudCacheClient> m_CacheClient;
+ detail::CloudCacheSessionState* m_SessionState;
+};
+
+/**
+ * Access token provider interface
+ */
+class CloudCacheTokenProvider
+{
+public:
+ virtual ~CloudCacheTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() = 0;
+
+ static std::unique_ptr<CloudCacheTokenProvider> CreateFromStaticToken(CloudCacheAccessToken Token);
+
+ struct OAuthClientCredentialsParams
+ {
+ std::string_view Url;
+ std::string_view ClientId;
+ std::string_view ClientSecret;
+ };
+
+ static std::unique_ptr<CloudCacheTokenProvider> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params);
+
+ static std::unique_ptr<CloudCacheTokenProvider> CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback);
+};
+
+struct CloudCacheClientOptions
+{
+ std::string_view Name;
+ std::string_view ServiceUrl;
+ std::string_view DdcNamespace;
+ std::string_view BlobStoreNamespace;
+ std::string_view ComputeCluster;
+ std::chrono::milliseconds ConnectTimeout{5000};
+ std::chrono::milliseconds Timeout{};
+};
+
+/**
+ * Jupiter upstream cache client
+ */
+class CloudCacheClient : public RefCounted
+{
+public:
+ CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider);
+ ~CloudCacheClient();
+
+ CloudCacheAccessToken AcquireAccessToken();
+ std::string_view DefaultDdcNamespace() const { return m_DefaultDdcNamespace; }
+ std::string_view DefaultBlobStoreNamespace() const { return m_DefaultBlobStoreNamespace; }
+ std::string_view ComputeCluster() const { return m_ComputeCluster; }
+ std::string_view ServiceUrl() const { return m_ServiceUrl; }
+
+ spdlog::logger& Logger() { return m_Log; }
+
+private:
+ spdlog::logger& m_Log;
+ std::string m_ServiceUrl;
+ std::string m_DefaultDdcNamespace;
+ std::string m_DefaultBlobStoreNamespace;
+ std::string m_ComputeCluster;
+ std::chrono::milliseconds m_ConnectTimeout{};
+ std::chrono::milliseconds m_Timeout{};
+ std::unique_ptr<CloudCacheTokenProvider> m_TokenProvider;
+
+ RwLock m_SessionStateLock;
+ std::list<detail::CloudCacheSessionState*> m_SessionStateCache;
+
+ detail::CloudCacheSessionState* AllocSessionState();
+ void FreeSessionState(detail::CloudCacheSessionState*);
+
+ friend class CloudCacheSession;
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstream.h b/src/zenserver/upstream/upstream.h
new file mode 100644
index 000000000..a57301206
--- /dev/null
+++ b/src/zenserver/upstream/upstream.h
@@ -0,0 +1,8 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <upstream/jupiter.h>
+#include <upstream/upstreamcache.h>
+#include <upstream/upstreamservice.h>
+#include <upstream/zen.h>
diff --git a/src/zenserver/upstream/upstreamapply.cpp b/src/zenserver/upstream/upstreamapply.cpp
new file mode 100644
index 000000000..c719b225d
--- /dev/null
+++ b/src/zenserver/upstream/upstreamapply.cpp
@@ -0,0 +1,459 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamapply.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/fmtutils.h>
+# include <zencore/stream.h>
+# include <zencore/timer.h>
+# include <zencore/workthreadpool.h>
+
+# include <zenstore/cidstore.h>
+
+# include "diag/logging.h"
+
+# include <fmt/format.h>
+
+# include <atomic>
+
+namespace zen {
+
+using namespace std::literals;
+
+struct UpstreamApplyStats
+{
+ static constexpr uint64_t MaxSampleCount = 1000ull;
+
+ UpstreamApplyStats(bool Enabled) : m_Enabled(Enabled) {}
+
+ void Add(UpstreamApplyEndpoint& Endpoint, const PostUpstreamApplyResult& Result)
+ {
+ UpstreamApplyEndpointStats& Stats = Endpoint.Stats();
+
+ if (Result.Error)
+ {
+ Stats.ErrorCount.Increment(1);
+ }
+ else if (Result.Success)
+ {
+ Stats.PostCount.Increment(1);
+ Stats.UpBytes.Increment(Result.Bytes / 1024 / 1024);
+ }
+ }
+
+ void Add(UpstreamApplyEndpoint& Endpoint, const GetUpstreamApplyUpdatesResult& Result)
+ {
+ UpstreamApplyEndpointStats& Stats = Endpoint.Stats();
+
+ if (Result.Error)
+ {
+ Stats.ErrorCount.Increment(1);
+ }
+ else if (Result.Success)
+ {
+ Stats.UpdateCount.Increment(1);
+ Stats.DownBytes.Increment(Result.Bytes / 1024 / 1024);
+ if (!Result.Completed.empty())
+ {
+ uint64_t Completed = 0;
+ for (auto& It : Result.Completed)
+ {
+ Completed += It.second.size();
+ }
+ Stats.CompleteCount.Increment(Completed);
+ }
+ }
+ }
+
+ bool m_Enabled;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamApplyImpl final : public UpstreamApply
+{
+public:
+ UpstreamApplyImpl(const UpstreamApplyOptions& Options, CidStore& CidStore)
+ : m_Log(logging::Get("upstream-apply"))
+ , m_Options(Options)
+ , m_CidStore(CidStore)
+ , m_Stats(Options.StatsEnabled)
+ , m_UpstreamAsyncWorkPool(Options.UpstreamThreadCount)
+ , m_DownstreamAsyncWorkPool(Options.DownstreamThreadCount)
+ {
+ }
+
+ virtual ~UpstreamApplyImpl() { Shutdown(); }
+
+ virtual bool Initialize() override
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ const UpstreamEndpointHealth Health = Endpoint->Initialize();
+ if (Health.Ok)
+ {
+ Log().info("initialize endpoint '{}' OK", Endpoint->DisplayName());
+ }
+ else
+ {
+ Log().warn("initialize endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason);
+ }
+ }
+
+ m_RunState.IsRunning = !m_Endpoints.empty();
+
+ if (m_RunState.IsRunning)
+ {
+ m_ShutdownEvent.Reset();
+
+ m_UpstreamUpdatesThread = std::thread(&UpstreamApplyImpl::ProcessUpstreamUpdates, this);
+
+ m_EndpointMonitorThread = std::thread(&UpstreamApplyImpl::MonitorEndpoints, this);
+ }
+
+ return m_RunState.IsRunning;
+ }
+
+ virtual bool IsHealthy() const override
+ {
+ if (m_RunState.IsRunning)
+ {
+ for (const auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) override
+ {
+ m_Endpoints.emplace_back(std::move(Endpoint));
+ }
+
+ virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) override
+ {
+ if (m_RunState.IsRunning)
+ {
+ const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash();
+ const IoHash ActionId = ApplyRecord.Action.GetHash();
+ const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300);
+
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ // Already in progress
+ return {.ApplyId = ActionId, .Success = true};
+ }
+
+ std::chrono::steady_clock::time_point ExpireTime =
+ TimeoutSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds)
+ : std::chrono::steady_clock::time_point::max();
+
+ m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)};
+ }
+
+ ApplyRecord.Timepoints["zen-queue-added"] = DateTime::NowTicks();
+ m_UpstreamAsyncWorkPool.ScheduleWork(
+ [this, ApplyRecord = std::move(ApplyRecord)]() { ProcessApplyRecord(std::move(ApplyRecord)); });
+
+ return {.ApplyId = ActionId, .Success = true};
+ }
+
+ return {};
+ }
+
+ virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) override
+ {
+ if (m_RunState.IsRunning)
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ return {.Status = *Status, .Success = true};
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "upstream_worker_threads" << m_Options.UpstreamThreadCount;
+ Status << "upstream_queue_count" << m_UpstreamAsyncWorkPool.PendingWork();
+ Status << "downstream_worker_threads" << m_Options.DownstreamThreadCount;
+ Status << "downstream_queue_count" << m_DownstreamAsyncWorkPool.PendingWork();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ Status.BeginObject();
+ Status << "name" << Ep->DisplayName();
+ Status << "health" << (Ep->IsHealthy() ? "ok"sv : "inactive"sv);
+
+ UpstreamApplyEndpointStats& Stats = Ep->Stats();
+ const uint64_t PostCount = Stats.PostCount.Value();
+ const uint64_t CompleteCount = Stats.CompleteCount.Value();
+ // const uint64_t UpdateCount = Stats.UpdateCount;
+ const double CompleteRate = CompleteCount > 0 ? (double(PostCount) / double(CompleteCount)) : 0.0;
+
+ Status << "post_count" << PostCount;
+ Status << "complete_count" << PostCount;
+ Status << "update_count" << Stats.UpdateCount.Value();
+
+ Status << "complete_ratio" << CompleteRate;
+ Status << "downloaded_mb" << Stats.DownBytes.Value();
+ Status << "uploaded_mb" << Stats.UpBytes.Value();
+ Status << "error_count" << Stats.ErrorCount.Value();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
+ // The caller is responsible for locking if required
+ UpstreamApplyStatus* FindStatus(const IoHash& WorkerId, const IoHash& ActionId)
+ {
+ if (auto It = m_ApplyTasks.find(WorkerId); It != m_ApplyTasks.end())
+ {
+ if (auto It2 = It->second.find(ActionId); It2 != It->second.end())
+ {
+ return &It2->second;
+ }
+ }
+ return nullptr;
+ }
+
+ void ProcessApplyRecord(UpstreamApplyRecord ApplyRecord)
+ {
+ const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash();
+ const IoHash ActionId = ApplyRecord.Action.GetHash();
+ try
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ ApplyRecord.Timepoints["zen-queue-dispatched"] = DateTime::NowTicks();
+ PostUpstreamApplyResult Result = Endpoint->PostApply(std::move(ApplyRecord));
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ Status->Timepoints.merge(Result.Timepoints);
+
+ if (Result.Success)
+ {
+ Status->State = UpstreamApplyState::Executing;
+ }
+ else
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = {.Error = std::move(Result.Error),
+ .Bytes = Result.Bytes,
+ .ElapsedSeconds = Result.ElapsedSeconds};
+ }
+ }
+ }
+ m_Stats.Add(*Endpoint, Result);
+ return;
+ }
+ }
+
+ Log().warn("process upstream apply ({}/{}) FAILED 'No available endpoint'", WorkerId, ActionId);
+
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = {.Error{.ErrorCode = -1, .Reason = "No available endpoint"}};
+ }
+ }
+ }
+ catch (std::exception& e)
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = {.Error{.ErrorCode = -1, .Reason = e.what()}};
+ }
+ Log().warn("process upstream apply ({}/{}) FAILED '{}'", WorkerId, ActionId, e.what());
+ }
+ }
+
+ void ProcessApplyUpdates()
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ GetUpstreamApplyUpdatesResult Result = Endpoint->GetUpdates(m_DownstreamAsyncWorkPool);
+ m_Stats.Add(*Endpoint, Result);
+
+ if (!Result.Success)
+ {
+ Log().warn("process upstream apply updates FAILED '{}'", Result.Error.Reason);
+ }
+
+ if (!Result.Completed.empty())
+ {
+ for (auto& It : Result.Completed)
+ {
+ for (auto& It2 : It.second)
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(It.first, It2.first); Status != nullptr)
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = std::move(It2.second);
+ Status->Result.Timepoints.merge(Status->Timepoints);
+ Status->Result.Timepoints["zen-queue-complete"] = DateTime::NowTicks();
+ Status->Timepoints.clear();
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ void ProcessUpstreamUpdates()
+ {
+ const auto& UpdateSleep = std::chrono::milliseconds(m_Options.UpdatesInterval);
+ while (!m_ShutdownEvent.Wait(uint32_t(UpdateSleep.count())))
+ {
+ if (!m_RunState.IsRunning)
+ {
+ break;
+ }
+
+ ProcessApplyUpdates();
+
+ // Remove any expired tasks, regardless of state
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ for (auto& WorkerIt : m_ApplyTasks)
+ {
+ const auto Count = std::erase_if(WorkerIt.second, [](const auto& Item) {
+ return Item.second.ExpireTime < std::chrono::steady_clock::now();
+ });
+ if (Count > 0)
+ {
+ Log().debug("Removed '{}' expired tasks", Count);
+ }
+ }
+ const auto Count = std::erase_if(m_ApplyTasks, [](const auto& Item) { return Item.second.empty(); });
+ if (Count > 0)
+ {
+ Log().debug("Removed '{}' empty task lists", Count);
+ }
+ }
+ }
+ }
+
+ void MonitorEndpoints()
+ {
+ for (;;)
+ {
+ {
+ std::unique_lock Lock(m_RunState.Mutex);
+ if (m_RunState.ExitSignal.wait_for(Lock, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
+ {
+ break;
+ }
+ }
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (!Endpoint->IsHealthy())
+ {
+ if (const UpstreamEndpointHealth Health = Endpoint->CheckHealth(); Health.Ok)
+ {
+ Log().warn("health check endpoint '{}' OK", Endpoint->DisplayName(), Health.Reason);
+ }
+ else
+ {
+ Log().warn("health check endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason);
+ }
+ }
+ }
+ }
+ }
+
+ void Shutdown()
+ {
+ if (m_RunState.Stop())
+ {
+ m_ShutdownEvent.Set();
+ m_EndpointMonitorThread.join();
+ m_UpstreamUpdatesThread.join();
+ m_Endpoints.clear();
+ }
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ struct RunState
+ {
+ std::mutex Mutex;
+ std::condition_variable ExitSignal;
+ std::atomic_bool IsRunning{false};
+
+ bool Stop()
+ {
+ bool Stopped = false;
+ {
+ std::scoped_lock Lock(Mutex);
+ Stopped = IsRunning.exchange(false);
+ }
+ if (Stopped)
+ {
+ ExitSignal.notify_all();
+ }
+ return Stopped;
+ }
+ };
+
+ spdlog::logger& m_Log;
+ UpstreamApplyOptions m_Options;
+ CidStore& m_CidStore;
+ UpstreamApplyStats m_Stats;
+ UpstreamApplyTasks m_ApplyTasks;
+ std::mutex m_ApplyTasksMutex;
+ std::vector<std::unique_ptr<UpstreamApplyEndpoint>> m_Endpoints;
+ Event m_ShutdownEvent;
+ WorkerThreadPool m_UpstreamAsyncWorkPool;
+ WorkerThreadPool m_DownstreamAsyncWorkPool;
+ std::thread m_UpstreamUpdatesThread;
+ std::thread m_EndpointMonitorThread;
+ RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+UpstreamApply::IsHealthy() const
+{
+ return false;
+}
+
+std::unique_ptr<UpstreamApply>
+UpstreamApply::Create(const UpstreamApplyOptions& Options, CidStore& CidStore)
+{
+ return std::make_unique<UpstreamApplyImpl>(Options, CidStore);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/upstreamapply.h b/src/zenserver/upstream/upstreamapply.h
new file mode 100644
index 000000000..4a095be6c
--- /dev/null
+++ b/src/zenserver/upstream/upstreamapply.h
@@ -0,0 +1,192 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinarypackage.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/stats.h>
+# include <zencore/zencore.h>
+
+# include <chrono>
+# include <map>
+# include <unordered_map>
+# include <unordered_set>
+
+namespace zen {
+
+class AuthMgr;
+class CbObjectWriter;
+class CidStore;
+class CloudCacheTokenProvider;
+class WorkerThreadPool;
+class ZenCacheNamespace;
+struct CloudCacheClientOptions;
+struct UpstreamAuthConfig;
+
+enum class UpstreamApplyState : int32_t
+{
+ Queued = 0,
+ Executing = 1,
+ Complete = 2,
+};
+
+enum class UpstreamApplyType
+{
+ Simple = 0,
+ Asset = 1,
+};
+
+struct UpstreamApplyRecord
+{
+ CbObject WorkerDescriptor;
+ CbObject Action;
+ UpstreamApplyType Type;
+ std::map<std::string, uint64_t> Timepoints{};
+};
+
+struct UpstreamApplyOptions
+{
+ std::chrono::seconds HealthCheckInterval{5};
+ std::chrono::seconds UpdatesInterval{5};
+ uint32_t UpstreamThreadCount = 4;
+ uint32_t DownstreamThreadCount = 4;
+ bool StatsEnabled = false;
+};
+
+struct UpstreamApplyError
+{
+ int32_t ErrorCode{};
+ std::string Reason{};
+
+ explicit operator bool() const { return ErrorCode != 0; }
+};
+
+struct PostUpstreamApplyResult
+{
+ UpstreamApplyError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ std::map<std::string, uint64_t> Timepoints{};
+ bool Success = false;
+};
+
+struct GetUpstreamApplyResult
+{
+ // UpstreamApplyType::Simple
+ std::map<std::filesystem::path, IoHash> OutputFiles{};
+ std::map<IoHash, IoBuffer> FileData{};
+
+ // UpstreamApplyType::Asset
+ CbPackage OutputPackage{};
+ int64_t TotalAttachmentBytes{};
+ int64_t TotalRawAttachmentBytes{};
+
+ UpstreamApplyError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ std::string StdOut{};
+ std::string StdErr{};
+ std::string Agent{};
+ std::string Detail{};
+ std::map<std::string, uint64_t> Timepoints{};
+ bool Success = false;
+};
+
+using UpstreamApplyCompleted = std::unordered_map<IoHash, std::unordered_map<IoHash, GetUpstreamApplyResult>>;
+
+struct GetUpstreamApplyUpdatesResult
+{
+ UpstreamApplyError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ UpstreamApplyCompleted Completed{};
+ bool Success = false;
+};
+
+struct UpstreamApplyStatus
+{
+ UpstreamApplyState State{};
+ GetUpstreamApplyResult Result{};
+ std::chrono::steady_clock::time_point ExpireTime{};
+ std::map<std::string, uint64_t> Timepoints{};
+};
+
+using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>;
+
+struct UpstreamEndpointHealth
+{
+ std::string Reason;
+ bool Ok = false;
+};
+
+struct UpstreamApplyEndpointStats
+{
+ metrics::Counter PostCount;
+ metrics::Counter CompleteCount;
+ metrics::Counter UpdateCount;
+ metrics::Counter ErrorCount;
+ metrics::Counter UpBytes;
+ metrics::Counter DownBytes;
+};
+
+/**
+ * The upstream apply endpoint is responsible for handling remote execution.
+ */
+class UpstreamApplyEndpoint
+{
+public:
+ virtual ~UpstreamApplyEndpoint() = default;
+
+ virtual UpstreamEndpointHealth Initialize() = 0;
+ virtual bool IsHealthy() const = 0;
+ virtual UpstreamEndpointHealth CheckHealth() = 0;
+ virtual std::string_view DisplayName() const = 0;
+ virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) = 0;
+ virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) = 0;
+ virtual UpstreamApplyEndpointStats& Stats() = 0;
+
+ static std::unique_ptr<UpstreamApplyEndpoint> CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr);
+};
+
+/**
+ * Manages one or more upstream compute endpoints.
+ */
+class UpstreamApply
+{
+public:
+ virtual ~UpstreamApply() = default;
+
+ virtual bool Initialize() = 0;
+ virtual bool IsHealthy() const = 0;
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0;
+
+ struct EnqueueResult
+ {
+ IoHash ApplyId{};
+ bool Success = false;
+ };
+
+ struct StatusResult
+ {
+ UpstreamApplyStatus Status{};
+ bool Success = false;
+ };
+
+ virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0;
+ virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0;
+ virtual void GetStatus(CbObjectWriter& CbO) = 0;
+
+ static std::unique_ptr<UpstreamApply> Create(const UpstreamApplyOptions& Options, CidStore& CidStore);
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp
new file mode 100644
index 000000000..e838b5fe2
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.cpp
@@ -0,0 +1,2112 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamcache.h"
+#include "jupiter.h"
+#include "zen.h"
+
+#include <zencore/blockingqueue.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/stats.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+
+#include <zenhttp/httpshared.h>
+
+#include <zenstore/cidstore.h>
+
+#include <auth/authmgr.h>
+#include "cache/structuredcache.h"
+#include "cache/structuredcachestore.h"
+#include "diag/logging.h"
+
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <atomic>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace detail {
+
+ class UpstreamStatus
+ {
+ public:
+ UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); }
+
+ UpstreamEndpointStatus EndpointStatus() const
+ {
+ const UpstreamEndpointState State = EndpointState();
+ {
+ std::unique_lock _(m_Mutex);
+ return {.Reason = m_ErrorText, .State = State};
+ }
+ }
+
+ void Set(UpstreamEndpointState NewState)
+ {
+ m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
+ {
+ std::unique_lock _(m_Mutex);
+ m_ErrorText.clear();
+ }
+ }
+
+ void Set(UpstreamEndpointState NewState, std::string ErrorText)
+ {
+ m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
+ {
+ std::unique_lock _(m_Mutex);
+ m_ErrorText = std::move(ErrorText);
+ }
+ }
+
+ void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText)
+ {
+ if (ErrorCode != 0)
+ {
+ Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText));
+ }
+ }
+
+ private:
+ mutable std::mutex m_Mutex;
+ std::string m_ErrorText;
+ std::atomic_uint32_t m_State;
+ };
+
+ class JupiterUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ public:
+ JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
+ : m_AuthMgr(Mgr)
+ , m_Log(zen::logging::Get("upstream"))
+ {
+ ZEN_ASSERT(!Options.Name.empty());
+ m_Info.Name = Options.Name;
+ m_Info.Url = Options.ServiceUrl;
+
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (AuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials(
+ {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret});
+ }
+ else if (AuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_Client = new CloudCacheClient(Options, std::move(TokenProvider));
+ }
+
+ virtual ~JupiterUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
+ virtual UpstreamEndpointStatus Initialize() override
+ {
+ try
+ {
+ if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
+ {
+ return {.State = UpstreamEndpointState::kOk};
+ }
+
+ CloudCacheSession Session(m_Client);
+ const CloudCacheResult Result = Session.Authenticate();
+
+ if (Result.Success)
+ {
+ m_Status.Set(UpstreamEndpointState::kOk);
+ }
+ else if (Result.ErrorCode != 0)
+ {
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+ }
+ else
+ {
+ m_Status.Set(UpstreamEndpointState::kUnauthorized);
+ }
+
+ return m_Status.EndpointStatus();
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = Err.what(), .State = GetState()};
+ }
+ }
+
+ std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace)
+ {
+ if (Namespace == ZenCacheStore::DefaultNamespace)
+ {
+ return Session.Client().DefaultDdcNamespace();
+ }
+ return Namespace;
+ }
+
+ std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace)
+ {
+ if (Namespace == ZenCacheStore::DefaultNamespace)
+ {
+ return Session.Client().DefaultBlobStoreNamespace();
+ }
+ return Namespace;
+ }
+
+ virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }
+
+ virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+
+ if (Type == ZenContentType::kCompressedBinary)
+ {
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+
+ if (Result.Success)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+ IoBuffer ContentBuffer;
+ int NumAttachments = 0;
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize))
+ {
+ Result.Response = AttachmentResult.Response;
+ ++NumAttachments;
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+ if (NumAttachments != 1)
+ {
+ Result.Success = false;
+ }
+ }
+ }
+ }
+ else
+ {
+ const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type;
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType);
+
+ if (Result.Success && Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash()));
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+
+ Package.SetObject(CacheRecord);
+ }
+
+ if (Result.Success)
+ {
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+
+ Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ }
+ }
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheKeyRequest* Request : Requests)
+ {
+ const CacheKey& CacheKey = Request->Key;
+ CbPackage Package;
+ CbObject Record;
+
+ double ElapsedSeconds = 0.0;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ CloudCacheResult RefResult =
+ Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+ AppendResult(RefResult, Result);
+ ElapsedSeconds = RefResult.ElapsedSeconds;
+
+ m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);
+
+ if (RefResult.ErrorCode == 0)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All);
+ if (ValidationResult == CbValidateError::None)
+ {
+ Record = LoadCompactBinaryObject(RefResult.Response);
+ Record.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+
+ if (BlobResult.ErrorCode == 0)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize))
+ {
+ if (RawHash == AttachmentHash.AsHash())
+ {
+ Package.AddAttachment(CbAttachment(Chunk, RawHash));
+ }
+ }
+ }
+ });
+ }
+ }
+ }
+
+ OnComplete(
+ {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info});
+ }
+
+ return Result;
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey&,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ CacheChunkRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+
+ double ElapsedSeconds = 0.0;
+ bool IsCompressed = false;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult BlobResult =
+ Request.ChunkId == IoHash::Zero
+ ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId)
+ : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId);
+ ElapsedSeconds = BlobResult.ElapsedSeconds;
+ Payload = BlobResult.Response;
+
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload && IsCompressedBinary(Payload.GetContentType()))
+ {
+ IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize);
+ }
+ }
+
+ if (IsCompressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = Payload,
+ .ElapsedSeconds = ElapsedSeconds,
+ .Source = &m_Info});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ return Result;
+ }
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ CacheValueRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+
+ double ElapsedSeconds = 0.0;
+ bool IsCompressed = false;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ IoHash PayloadHash;
+ const CloudCacheResult BlobResult =
+ Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash);
+ ElapsedSeconds = BlobResult.ElapsedSeconds;
+ Payload = BlobResult.Response;
+
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload)
+ {
+ if (IsCompressedBinary(Payload.GetContentType()))
+ {
+ IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash;
+ }
+ else
+ {
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload));
+ RawHash = Compressed.DecodeRawHash();
+ if (RawHash == PayloadHash)
+ {
+ IsCompressed = true;
+ }
+ else
+ {
+ ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header",
+ Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash.ToHexString(),
+ RawHash.ToHexString(),
+ PayloadHash.ToHexString());
+ }
+ }
+ }
+ }
+
+ if (IsCompressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = Payload,
+ .ElapsedSeconds = ElapsedSeconds,
+ .Source = &m_Info});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ return Result;
+ }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+
+ if (CacheRecord.Type == ZenContentType::kBinary)
+ {
+ CloudCacheResult Result;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace);
+ Result = Session.PutRef(BlobStoreNamespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ ZenContentType::kBinary);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = Result.Bytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Success = Result.Success};
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize))
+ {
+ return {.Reason = std::string("Invalid compressed value buffer"), .Success = false};
+ }
+
+ CbObjectWriter ReferencingObject;
+ ReferencingObject.AddBinaryAttachment("RawHash", RawHash);
+ ReferencingObject.AddInteger("RawSize", RawSize);
+
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ ReferencingObject.Save().GetBuffer().AsIoBuffer(),
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ if (ValueContentId != RawHash)
+ {
+ OutReason =
+ fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash);
+ return false;
+ }
+
+ OutBuffer = RecordValue;
+ return true;
+ });
+ }
+ else
+ {
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ RecordValue,
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ const auto It =
+ std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId);
+
+ if (It == std::end(CacheRecord.ValueContentIds))
+ {
+ OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId);
+ return false;
+ }
+
+ const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It);
+
+ OutBuffer = Values[Idx];
+ return true;
+ });
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out)
+ {
+ Out.Success &= Result.Success;
+ Out.Bytes += Result.Bytes;
+ Out.ElapsedSeconds += Result.ElapsedSeconds;
+
+ if (Result.ErrorCode)
+ {
+ Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)};
+ }
+ };
+
+ PutUpstreamCacheResult PerformStructuredPut(
+ CloudCacheSession& Session,
+ std::string_view Namespace,
+ const CacheKey& Key,
+ IoBuffer ObjectBuffer,
+ const int32_t MaxAttempts,
+ std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn)
+ {
+ int64_t TotalBytes = 0ull;
+ double TotalElapsedSeconds = 0.0;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool {
+ for (const IoHash& ValueContentId : ValueContentIds)
+ {
+ IoBuffer BlobBuffer;
+ if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason))
+ {
+ return false;
+ }
+
+ CloudCacheResult BlobResult;
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++)
+ {
+ BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer);
+ }
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+
+ if (!BlobResult.Success)
+ {
+ OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason);
+ return false;
+ }
+
+ TotalBytes += BlobResult.Bytes;
+ TotalElapsedSeconds += BlobResult.ElapsedSeconds;
+ }
+
+ return true;
+ };
+
+ PutRefResult RefResult;
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++)
+ {
+ RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject);
+ }
+
+ m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);
+
+ if (!RefResult.Success)
+ {
+ return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason),
+ .Success = false};
+ }
+
+ TotalBytes += RefResult.Bytes;
+ TotalElapsedSeconds += RefResult.ElapsedSeconds;
+
+ std::string Reason;
+ if (!PutBlobs(RefResult.Needs, Reason))
+ {
+ return {.Reason = std::move(Reason), .Success = false};
+ }
+
+ const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer);
+ FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);
+
+ m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);
+
+ if (!FinalizeResult.Success)
+ {
+ return {
+ .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
+ .Success = false};
+ }
+
+ if (!FinalizeResult.Needs.empty())
+ {
+ if (!PutBlobs(FinalizeResult.Needs, Reason))
+ {
+ return {.Reason = std::move(Reason), .Success = false};
+ }
+
+ FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);
+
+ m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);
+
+ if (!FinalizeResult.Success)
+ {
+ return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
+ .Success = false};
+ }
+
+ if (!FinalizeResult.Needs.empty())
+ {
+ ExtendableStringBuilder<256> Sb;
+ for (const IoHash& MissingHash : FinalizeResult.Needs)
+ {
+ Sb << MissingHash.ToHexString() << ",";
+ }
+
+ return {
+ .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()),
+ .Success = false};
+ }
+ }
+
+ TotalBytes += FinalizeResult.Bytes;
+ TotalElapsedSeconds += FinalizeResult.ElapsedSeconds;
+
+ return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true};
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthMgr& m_AuthMgr;
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ RefPtr<CloudCacheClient> m_Client;
+ };
+
+ class ZenUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ struct ZenEndpoint
+ {
+ std::string Url;
+ std::string Reason;
+ double Latency{};
+ bool Ok = false;
+
+ bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; }
+ };
+
+ public:
+ ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options)
+ : m_Log(zen::logging::Get("upstream"))
+ , m_ConnectTimeout(Options.ConnectTimeout)
+ , m_Timeout(Options.Timeout)
+ {
+ ZEN_ASSERT(!Options.Name.empty());
+ m_Info.Name = Options.Name;
+
+ for (const auto& Url : Options.Urls)
+ {
+ m_Endpoints.push_back({.Url = Url});
+ }
+ }
+
+ ~ZenUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
+ virtual UpstreamEndpointStatus Initialize() override
+ {
+ try
+ {
+ if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
+ {
+ return {.State = UpstreamEndpointState::kOk};
+ }
+
+ const ZenEndpoint& Ep = GetEndpoint();
+
+ if (m_Info.Url != Ep.Url)
+ {
+ ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url);
+ m_Info.Url = Ep.Url;
+ }
+
+ if (Ep.Ok)
+ {
+ RwLock::ExclusiveLockScope _(m_ClientLock);
+ m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout});
+ m_Status.Set(UpstreamEndpointState::kOk);
+ }
+ else
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Ep.Reason);
+ }
+
+ return m_Status.EndpointStatus();
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = Err.what(), .State = GetState()};
+ }
+ }
+
+ virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }
+
+ virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords");
+ ZEN_ASSERT(Requests.size() > 0);
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheRecords"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy();
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy);
+
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("Requests"sv);
+ for (CacheKeyRequest* Request : Requests)
+ {
+ BatchRequest.BeginObject();
+ {
+ const CacheKey& Key = Request->Key;
+ BatchRequest.BeginObject("Key"sv);
+ {
+ BatchRequest << "Bucket"sv << Key.Bucket;
+ BatchRequest << "Hash"sv << Key.Hash;
+ }
+ BatchRequest.EndObject();
+ if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy)
+ {
+ BatchRequest.SetName("Policy"sv);
+ Request->Policy.Save(BatchRequest);
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (Results.Num() != Requests.size())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t Index = 0; CbFieldView Record : Results)
+ {
+ CacheKeyRequest* Request = Requests[Index++];
+ OnComplete({.Request = *Request,
+ .Record = Record.AsObjectView(),
+ .Package = BatchResponse,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheKeyRequest* Request : Requests)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues");
+ ZEN_ASSERT(!CacheValueRequests.empty());
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheValues"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("Requests"sv);
+ {
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ const CacheValueRequest& Request = *RequestPtr;
+
+ BatchRequest.BeginObject();
+ {
+ BatchRequest.BeginObject("Key"sv);
+ BatchRequest << "Bucket"sv << Request.Key.Bucket;
+ BatchRequest << "Hash"sv << Request.Key.Hash;
+ BatchRequest.EndObject();
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheValueRequests.size() != Results.Num())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
+ {
+ CacheValueRequest& Request = *CacheValueRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
+ {
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.DecodeRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
+ }
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = std::move(Payload),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks");
+ ZEN_ASSERT(!CacheChunkRequests.empty());
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheChunks"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("ChunkRequests"sv);
+ {
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ const CacheChunkRequest& Request = *RequestPtr;
+
+ BatchRequest.BeginObject();
+ {
+ BatchRequest.BeginObject("Key"sv);
+ BatchRequest << "Bucket"sv << Request.Key.Bucket;
+ BatchRequest << "Hash"sv << Request.Key.Hash;
+ BatchRequest.EndObject();
+ if (Request.ValueId)
+ {
+ BatchRequest.AddObjectId("ValueId"sv, Request.ValueId);
+ }
+ if (Request.ChunkId != Request.ChunkId.Zero)
+ {
+ BatchRequest << "ChunkId"sv << Request.ChunkId;
+ }
+ if (Request.RawOffset != 0)
+ {
+ BatchRequest << "RawOffset"sv << Request.RawOffset;
+ }
+ if (Request.RawSize != UINT64_MAX)
+ {
+ BatchRequest << "RawSize"sv << Request.RawSize;
+ }
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheChunkRequests.size() != Results.Num())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
+ {
+ CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
+ {
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.DecodeRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
+ }
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = std::move(Payload),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ ZenCacheResult Result;
+ int64_t TotalBytes = 0ull;
+ double TotalElapsedSeconds = 0.0;
+
+ if (CacheRecord.Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+ Package.SetObject(CbObject(SharedBuffer(RecordValue)));
+
+ for (const IoBuffer& Value : Values)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash));
+ }
+ else
+ {
+ return {.Reason = std::string("Invalid value buffer"), .Success = false};
+ }
+ }
+
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+ IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size());
+
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheRecord(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ PackagePayload,
+ CacheRecord.Type);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes = Result.Bytes;
+ TotalElapsedSeconds = Result.ElapsedSeconds;
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize);
+ if (!Compressed)
+ {
+ return {.Reason = std::string("Invalid value compressed buffer"), .Success = false};
+ }
+
+ CbPackage BatchPackage;
+ CbObjectWriter BatchWriter;
+ BatchWriter << "Method"sv
+ << "PutCacheValues"sv;
+ BatchWriter << "Accept"sv << kCbPkgMagic;
+
+ BatchWriter.BeginObject("Params"sv);
+ {
+ // DefaultPolicy unspecified and expected to be Default
+
+ BatchWriter << "Namespace"sv << CacheRecord.Namespace;
+
+ BatchWriter.BeginArray("Requests"sv);
+ {
+ BatchWriter.BeginObject();
+ {
+ const CacheKey& Key = CacheRecord.Key;
+ BatchWriter.BeginObject("Key"sv);
+ {
+ BatchWriter << "Bucket"sv << Key.Bucket;
+ BatchWriter << "Hash"sv << Key.Hash;
+ }
+ BatchWriter.EndObject();
+ // Policy unspecified and expected to be Default
+ BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash);
+ BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash));
+ }
+ BatchWriter.EndObject();
+ }
+ BatchWriter.EndArray();
+ }
+ BatchWriter.EndObject();
+ BatchPackage.SetObject(BatchWriter.Save());
+
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.InvokeRpc(BatchPackage);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+ }
+ else
+ {
+ for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++)
+ {
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheValue(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ CacheRecord.ValueContentIds[Idx],
+ Values[Idx]);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+
+ if (!Result.Success)
+ {
+ return {.Reason = "Failed to upload value",
+ .Bytes = TotalBytes,
+ .ElapsedSeconds = TotalElapsedSeconds,
+ .Success = false};
+ }
+ }
+
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheRecord(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ CacheRecord.Type);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+ }
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = TotalBytes,
+ .ElapsedSeconds = TotalElapsedSeconds,
+ .Success = Result.Success};
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ Ref<ZenStructuredCacheClient> GetClientRef()
+ {
+ // m_Client can be modified at any time by a different thread.
+ // Make sure we safely bump the refcount inside a scope lock
+ RwLock::SharedLockScope _(m_ClientLock);
+ ZEN_ASSERT(m_Client);
+ Ref<ZenStructuredCacheClient> ClientRef(m_Client);
+ _.ReleaseNow();
+ return ClientRef;
+ }
+
+ const ZenEndpoint& GetEndpoint()
+ {
+ for (ZenEndpoint& Ep : m_Endpoints)
+ {
+ Ref<ZenStructuredCacheClient> Client(
+ new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)}));
+ ZenStructuredCacheSession Session(std::move(Client));
+ const int32_t SampleCount = 2;
+
+ Ep.Ok = false;
+ Ep.Latency = {};
+
+ for (int32_t Sample = 0; Sample < SampleCount; ++Sample)
+ {
+ ZenCacheResult Result = Session.CheckHealth();
+ Ep.Ok = Result.Success;
+ Ep.Reason = std::move(Result.Reason);
+ Ep.Latency += Result.ElapsedSeconds;
+ }
+ Ep.Latency /= double(SampleCount);
+ }
+
+ std::sort(std::begin(m_Endpoints), std::end(m_Endpoints));
+
+ for (const auto& Ep : m_Endpoints)
+ {
+ ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason);
+ }
+
+ return m_Endpoints.front();
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ std::vector<ZenEndpoint> m_Endpoints;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+ RwLock m_ClientLock;
+ RefPtr<ZenStructuredCacheClient> m_Client;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamCacheImpl final : public UpstreamCache
+{
+public:
+ UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
+ : m_Log(logging::Get("upstream"))
+ , m_Options(Options)
+ , m_CacheStore(CacheStore)
+ , m_CidStore(CidStore)
+ {
+ }
+
+ virtual ~UpstreamCacheImpl() { Shutdown(); }
+
+ virtual void Initialize() override
+ {
+ for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++)
+ {
+ m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this);
+ }
+
+ m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this);
+ m_RunState.IsRunning = true;
+ }
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override
+ {
+ const UpstreamEndpointStatus Status = Endpoint->Initialize();
+ const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
+
+ if (Status.State == UpstreamEndpointState::kOk)
+ {
+ ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
+ }
+ else
+ {
+ ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
+ }
+
+ // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s)
+ std::unique_lock<std::shared_mutex> _(m_EndpointsMutex);
+ m_Endpoints.emplace_back(std::move(Endpoint));
+ }
+
+ virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Ep : m_Endpoints)
+ {
+ if (!Fn(*Ep))
+ {
+ break;
+ }
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecord");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecords");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheKeyRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) {
+ if (Params.Record)
+ {
+ OnComplete(std::forward<CacheRecordGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheKeyRequest* Request : RemainingKeys)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+ }
+
+ virtual void GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunks");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheChunkRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheChunkGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheChunkRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunk");
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheValues");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheValueRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheValueGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheValueRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override
+ {
+ if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0)
+ {
+ if (!m_UpstreamThreads.empty())
+ {
+ m_UpstreamQueue.Enqueue(std::move(CacheRecord));
+ }
+ else
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ }
+ }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "reading" << m_Options.ReadUpstream;
+ Status << "writing" << m_Options.WriteUpstream;
+ Status << "worker_threads" << m_Options.ThreadCount;
+ Status << "queue_count" << m_UpstreamQueue.Size();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo();
+ const UpstreamEndpointStatus EpStatus = Ep->GetStatus();
+ UpstreamEndpointStats& EpStats = Ep->Stats();
+
+ Status.BeginObject();
+ Status << "name" << EpInfo.Name;
+ Status << "url" << EpInfo.Url;
+ Status << "state" << ToString(EpStatus.State);
+ Status << "reason" << EpStatus.Reason;
+
+ Status.BeginObject("cache"sv);
+ {
+ const int64_t GetCount = EpStats.CacheGetCount.Value();
+ const int64_t HitCount = EpStats.CacheHitCount.Value();
+ const int64_t ErrorCount = EpStats.CacheErrorCount.Value();
+ const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0;
+ const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0;
+
+ metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status);
+ Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value();
+ Status << "get_count" << GetCount;
+ Status << "hit_count" << HitCount;
+ Status << "hit_ratio" << HitRatio;
+ Status << "error_count" << ErrorCount;
+ Status << "error_ratio" << ErrorRatio;
+ metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status);
+ Status << "put_bytes" << EpStats.CachePutTotalBytes.Value();
+ }
+ Status.EndObject();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
+ void ProcessCacheRecord(UpstreamCacheRecord CacheRecord)
+ {
+ ZEN_TRACE_CPU("Upstream::ProcessCacheRecord");
+
+ ZenCacheValue CacheValue;
+ std::vector<IoBuffer> Payloads;
+
+ if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue))
+ {
+ ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash);
+ return;
+ }
+
+ for (const IoHash& ValueContentId : CacheRecord.ValueContentIds)
+ {
+ if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId))
+ {
+ Payloads.push_back(Payload);
+ }
+ else
+ {
+ ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ ValueContentId);
+ return;
+ }
+ }
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ PutUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming);
+ Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads));
+ }
+
+ Stats.CachePutTotalBytes.Increment(Result.Bytes);
+
+ if (!Result.Success)
+ {
+ ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Endpoint->GetEndpointInfo().Url,
+ Result.Reason);
+ }
+ }
+ }
+
+ void ProcessUpstreamQueue()
+ {
+ for (;;)
+ {
+ UpstreamCacheRecord CacheRecord;
+ if (m_UpstreamQueue.WaitAndDequeue(CacheRecord))
+ {
+ try
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Err.what());
+ }
+ }
+
+ if (!m_RunState.IsRunning)
+ {
+ break;
+ }
+ }
+ }
+
+ void MonitorEndpoints()
+ {
+ for (;;)
+ {
+ {
+ std::unique_lock lk(m_RunState.Mutex);
+ if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
+ {
+ break;
+ }
+ }
+
+ try
+ {
+ std::vector<UpstreamEndpoint*> Endpoints;
+
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ UpstreamEndpointState State = Endpoint->GetState();
+ if (State == UpstreamEndpointState::kError)
+ {
+ Endpoints.push_back(Endpoint.get());
+ ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'",
+ Endpoint->GetEndpointInfo().Name,
+ Endpoint->GetEndpointInfo().Url,
+ Endpoint->GetStatus().Reason);
+ }
+ if (State == UpstreamEndpointState::kUnauthorized)
+ {
+ Endpoints.push_back(Endpoint.get());
+ }
+ }
+ }
+
+ for (auto& Endpoint : Endpoints)
+ {
+ const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
+ const UpstreamEndpointStatus Status = Endpoint->Initialize();
+
+ if (Status.State == UpstreamEndpointState::kOk)
+ {
+ ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url);
+ }
+ else
+ {
+ const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason);
+ ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason);
+ }
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what());
+ }
+ }
+ }
+
+ void Shutdown()
+ {
+ if (m_RunState.Stop())
+ {
+ m_UpstreamQueue.CompleteAdding();
+ for (std::thread& Thread : m_UpstreamThreads)
+ {
+ Thread.join();
+ }
+
+ m_EndpointMonitorThread.join();
+ m_UpstreamThreads.clear();
+ m_Endpoints.clear();
+ }
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>;
+
+ struct RunState
+ {
+ std::mutex Mutex;
+ std::condition_variable ExitSignal;
+ std::atomic_bool IsRunning{false};
+
+ bool Stop()
+ {
+ bool Stopped = false;
+ {
+ std::lock_guard _(Mutex);
+ Stopped = IsRunning.exchange(false);
+ }
+ if (Stopped)
+ {
+ ExitSignal.notify_all();
+ }
+ return Stopped;
+ }
+ };
+
+ spdlog::logger& m_Log;
+ UpstreamCacheOptions m_Options;
+ ZenCacheStore& m_CacheStore;
+ CidStore& m_CidStore;
+ UpstreamQueue m_UpstreamQueue;
+ std::shared_mutex m_EndpointsMutex;
+ std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints;
+ std::vector<std::thread> m_UpstreamThreads;
+ std::thread m_EndpointMonitorThread;
+ RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+std::unique_ptr<UpstreamEndpoint>
+UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options)
+{
+ return std::make_unique<detail::ZenUpstreamEndpoint>(Options);
+}
+
+std::unique_ptr<UpstreamEndpoint>
+UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
+{
+ return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr);
+}
+
+std::unique_ptr<UpstreamCache>
+UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
+{
+ return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamcache.h b/src/zenserver/upstream/upstreamcache.h
new file mode 100644
index 000000000..695c06b32
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.h
@@ -0,0 +1,252 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/stats.h>
+#include <zencore/zencore.h>
+#include <zenutil/cache/cache.h>
+
+#include <atomic>
+#include <chrono>
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace zen {
+
+class CbObjectView;
+class AuthMgr;
+class CbObjectView;
+class CbPackage;
+class CbObjectWriter;
+class CidStore;
+class ZenCacheStore;
+struct CloudCacheClientOptions;
+class CloudCacheTokenProvider;
+struct ZenStructuredCacheClientOptions;
+
+struct UpstreamCacheRecord
+{
+ ZenContentType Type = ZenContentType::kBinary;
+ std::string Namespace;
+ CacheKey Key;
+ std::vector<IoHash> ValueContentIds;
+};
+
+struct UpstreamCacheOptions
+{
+ std::chrono::seconds HealthCheckInterval{5};
+ uint32_t ThreadCount = 4;
+ bool ReadUpstream = true;
+ bool WriteUpstream = true;
+};
+
+struct UpstreamError
+{
+ int32_t ErrorCode{};
+ std::string Reason{};
+
+ explicit operator bool() const { return ErrorCode != 0; }
+};
+
+struct UpstreamEndpointInfo
+{
+ std::string Name;
+ std::string Url;
+};
+
+struct GetUpstreamCacheResult
+{
+ UpstreamError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ bool Success = false;
+};
+
+struct GetUpstreamCacheSingleResult
+{
+ GetUpstreamCacheResult Status;
+ IoBuffer Value;
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+struct PutUpstreamCacheResult
+{
+ std::string Reason;
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ bool Success = false;
+};
+
+struct CacheRecordGetCompleteParams
+{
+ CacheKeyRequest& Request;
+ const CbObjectView& Record;
+ const CbPackage& Package;
+ double ElapsedSeconds{};
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>;
+
+struct CacheValueGetCompleteParams
+{
+ CacheValueRequest& Request;
+ IoHash RawHash;
+ uint64_t RawSize;
+ IoBuffer Value;
+ double ElapsedSeconds{};
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>;
+
+struct CacheChunkGetCompleteParams
+{
+ CacheChunkRequest& Request;
+ IoHash RawHash;
+ uint64_t RawSize;
+ IoBuffer Value;
+ double ElapsedSeconds{};
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+using OnCacheChunksGetComplete = std::function<void(CacheChunkGetCompleteParams&&)>;
+
+struct UpstreamEndpointStats
+{
+ metrics::OperationTiming CacheGetRequestTiming;
+ metrics::OperationTiming CachePutRequestTiming;
+ metrics::Counter CacheGetTotalBytes;
+ metrics::Counter CachePutTotalBytes;
+ metrics::Counter CacheGetCount;
+ metrics::Counter CacheHitCount;
+ metrics::Counter CacheErrorCount;
+};
+
+enum class UpstreamEndpointState : uint32_t
+{
+ kDisabled,
+ kUnauthorized,
+ kError,
+ kOk
+};
+
+inline std::string_view
+ToString(UpstreamEndpointState State)
+{
+ using namespace std::literals;
+
+ switch (State)
+ {
+ case UpstreamEndpointState::kDisabled:
+ return "Disabled"sv;
+ case UpstreamEndpointState::kUnauthorized:
+ return "Unauthorized"sv;
+ case UpstreamEndpointState::kError:
+ return "Error"sv;
+ case UpstreamEndpointState::kOk:
+ return "Ok"sv;
+ default:
+ return "Unknown"sv;
+ }
+}
+
+struct UpstreamAuthConfig
+{
+ std::string_view OAuthUrl;
+ std::string_view OAuthClientId;
+ std::string_view OAuthClientSecret;
+ std::string_view OpenIdProvider;
+ std::string_view AccessToken;
+};
+
+struct UpstreamEndpointStatus
+{
+ std::string Reason;
+ UpstreamEndpointState State;
+};
+
+/**
+ * The upstream endpoint is responsible for handling upload/downloading of cache records.
+ */
+class UpstreamEndpoint
+{
+public:
+ virtual ~UpstreamEndpoint() = default;
+
+ virtual UpstreamEndpointStatus Initialize() = 0;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0;
+
+ virtual UpstreamEndpointState GetState() = 0;
+ virtual UpstreamEndpointStatus GetStatus() = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0;
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) = 0;
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, const CacheKey& CacheKey, const IoHash& PayloadId) = 0;
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) = 0;
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Payloads) = 0;
+
+ virtual UpstreamEndpointStats& Stats() = 0;
+
+ static std::unique_ptr<UpstreamEndpoint> CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options);
+
+ static std::unique_ptr<UpstreamEndpoint> CreateJupiterEndpoint(const CloudCacheClientOptions& Options,
+ const UpstreamAuthConfig& AuthConfig,
+ AuthMgr& Mgr);
+};
+
+/**
+ * Manages one or more upstream cache endpoints.
+ */
+class UpstreamCache
+{
+public:
+ virtual ~UpstreamCache() = default;
+
+ virtual void Initialize() = 0;
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0;
+ virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0;
+ virtual void GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) = 0;
+
+ virtual void GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) = 0;
+ virtual void GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) = 0;
+
+ virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0;
+
+ virtual void GetStatus(CbObjectWriter& CbO) = 0;
+
+ static std::unique_ptr<UpstreamCache> Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore);
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamservice.cpp b/src/zenserver/upstream/upstreamservice.cpp
new file mode 100644
index 000000000..6db1357c5
--- /dev/null
+++ b/src/zenserver/upstream/upstreamservice.cpp
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+#include <upstream/upstreamservice.h>
+
+#include <auth/authmgr.h>
+#include <upstream/upstreamcache.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr)
+{
+ m_Router.RegisterRoute(
+ "endpoints",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Writer;
+ Writer.BeginArray("Endpoints"sv);
+ m_Upstream.IterateEndpoints([&Writer](UpstreamEndpoint& Ep) {
+ UpstreamEndpointInfo Info = Ep.GetEndpointInfo();
+ UpstreamEndpointStatus Status = Ep.GetStatus();
+
+ Writer.BeginObject();
+ Writer << "Name"sv << Info.Name;
+ Writer << "Url"sv << Info.Url;
+ Writer << "State"sv << ToString(Status.State);
+ Writer << "Reason"sv << Status.Reason;
+ Writer.EndObject();
+
+ return true;
+ });
+ Writer.EndArray();
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet);
+}
+
+HttpUpstreamService::~HttpUpstreamService()
+{
+}
+
+const char*
+HttpUpstreamService::BaseUri() const
+{
+ return "/upstream/";
+}
+
+void
+HttpUpstreamService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamservice.h b/src/zenserver/upstream/upstreamservice.h
new file mode 100644
index 000000000..f1da03c8c
--- /dev/null
+++ b/src/zenserver/upstream/upstreamservice.h
@@ -0,0 +1,27 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class AuthMgr;
+class UpstreamCache;
+
+class HttpUpstreamService final : public zen::HttpService
+{
+public:
+ HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr);
+ virtual ~HttpUpstreamService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ UpstreamCache& m_Upstream;
+ AuthMgr& m_AuthMgr;
+ HttpRequestRouter m_Router;
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/zen.cpp b/src/zenserver/upstream/zen.cpp
new file mode 100644
index 000000000..9e1212834
--- /dev/null
+++ b/src/zenserver/upstream/zen.cpp
@@ -0,0 +1,326 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zen.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zenhttp/httpcommon.h>
+#include <zenhttp/httpshared.h>
+
+#include "cache/structuredcachestore.h"
+#include "diag/formatters.h"
+#include "diag/logging.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <xxhash.h>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+namespace detail {
+ struct ZenCacheSessionState
+ {
+ ZenCacheSessionState(ZenStructuredCacheClient& Client) : OwnerClient(Client) {}
+ ~ZenCacheSessionState() {}
+
+ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout)
+ {
+ Session.SetBody({});
+ Session.SetHeader({});
+ Session.SetConnectTimeout(ConnectTimeout);
+ Session.SetTimeout(Timeout);
+ }
+
+ cpr::Session& GetSession() { return Session; }
+
+ private:
+ ZenStructuredCacheClient& OwnerClient;
+ cpr::Session Session;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenStructuredCacheClient::ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options)
+: m_Log(logging::Get(std::string_view("zenclient")))
+, m_ServiceUrl(Options.Url)
+, m_ConnectTimeout(Options.ConnectTimeout)
+, m_Timeout(Options.Timeout)
+{
+}
+
+ZenStructuredCacheClient::~ZenStructuredCacheClient()
+{
+}
+
+detail::ZenCacheSessionState*
+ZenStructuredCacheClient::AllocSessionState()
+{
+ detail::ZenCacheSessionState* State = nullptr;
+
+ if (RwLock::ExclusiveLockScope _(m_SessionStateLock); !m_SessionStateCache.empty())
+ {
+ State = m_SessionStateCache.front();
+ m_SessionStateCache.pop_front();
+ }
+
+ if (State == nullptr)
+ {
+ State = new detail::ZenCacheSessionState(*this);
+ }
+
+ State->Reset(m_ConnectTimeout, m_Timeout);
+
+ return State;
+}
+
+void
+ZenStructuredCacheClient::FreeSessionState(detail::ZenCacheSessionState* State)
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+ m_SessionStateCache.push_front(State);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+using namespace std::literals;
+
+ZenStructuredCacheSession::ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient)
+: m_Log(OuterClient->Log())
+, m_Client(std::move(OuterClient))
+{
+ m_SessionState = m_Client->AllocSessionState();
+}
+
+ZenStructuredCacheSession::~ZenStructuredCacheSession()
+{
+ m_Client->FreeSessionState(m_SessionState);
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::CheckHealth()
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/health/check";
+
+ cpr::Session& Session = m_SessionState->GetSession();
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ cpr::Response Response = Session.Get();
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ return {.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Accept", std::string{MapContentTypeToString(Type)}}});
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::GetCacheChunk(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ const IoHash& ValueContentId)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Accept", "application/x-ue-comp"}});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer,
+ .Bytes = Response.downloaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Reason = Response.reason,
+ .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::PutCacheRecord(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ IoBuffer Value,
+ ZenContentType Type)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type",
+ Type == ZenContentType::kCbPackage ? "application/x-ue-cbpkg"
+ : Type == ZenContentType::kCbObject ? "application/x-ue-cb"
+ : "application/octet-stream"}});
+ Session.SetBody(cpr::Body{static_cast<const char*>(Value.Data()), Value.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200 || Response.status_code == 201;
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::PutCacheValue(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ const IoHash& ValueContentId,
+ IoBuffer Payload)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-comp"}});
+ Session.SetBody(cpr::Body{static_cast<const char*>(Payload.Data()), Payload.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200 || Response.status_code == 201;
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/$rpc";
+
+ BinaryWriter Body;
+ Request.CopyTo(Body);
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}});
+ Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = std::move(Buffer),
+ .Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Reason = Response.reason,
+ .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::InvokeRpc(const CbPackage& Request)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/$rpc";
+
+ SharedBuffer Message = FormatPackageMessageBuffer(Request).Flatten();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}});
+ Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Message.GetData()), Message.GetSize()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = std::move(Buffer),
+ .Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Reason = Response.reason,
+ .Success = Success};
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/zen.h b/src/zenserver/upstream/zen.h
new file mode 100644
index 000000000..bfba8fa98
--- /dev/null
+++ b/src/zenserver/upstream/zen.h
@@ -0,0 +1,125 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+#include <zencore/zencore.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <list>
+
+struct ZenCacheValue;
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+class CbObjectWriter;
+class CbObjectView;
+class CbPackage;
+class ZenStructuredCacheClient;
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+ struct ZenCacheSessionState;
+}
+
+struct ZenCacheResult
+{
+ IoBuffer Response;
+ int64_t Bytes = {};
+ double ElapsedSeconds = {};
+ int32_t ErrorCode = {};
+ std::string Reason;
+ bool Success = false;
+};
+
+struct ZenStructuredCacheClientOptions
+{
+ std::string_view Name;
+ std::string_view Url;
+ std::span<std::string const> Urls;
+ std::chrono::milliseconds ConnectTimeout{};
+ std::chrono::milliseconds Timeout{};
+};
+
+/** Zen Structured Cache session
+ *
+ * This provides a context in which cache queries can be performed
+ *
+ * These are currently all synchronous. Will need to be made asynchronous
+ */
+class ZenStructuredCacheSession
+{
+public:
+ ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient);
+ ~ZenStructuredCacheSession();
+
+ ZenCacheResult CheckHealth();
+ ZenCacheResult GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type);
+ ZenCacheResult GetCacheChunk(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId);
+ ZenCacheResult PutCacheRecord(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ IoBuffer Value,
+ ZenContentType Type);
+ ZenCacheResult PutCacheValue(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ const IoHash& ValueContentId,
+ IoBuffer Payload);
+ ZenCacheResult InvokeRpc(const CbObjectView& Request);
+ ZenCacheResult InvokeRpc(const CbPackage& Package);
+
+private:
+ inline spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ Ref<ZenStructuredCacheClient> m_Client;
+ detail::ZenCacheSessionState* m_SessionState;
+};
+
+/** Zen Structured Cache client
+ *
+ * This represents an endpoint to query -- actual queries should be done via
+ * ZenStructuredCacheSession
+ */
+class ZenStructuredCacheClient : public RefCounted
+{
+public:
+ ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options);
+ ~ZenStructuredCacheClient();
+
+ std::string_view ServiceUrl() const { return m_ServiceUrl; }
+
+ inline spdlog::logger& Log() { return m_Log; }
+
+private:
+ spdlog::logger& m_Log;
+ std::string m_ServiceUrl;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+
+ RwLock m_SessionStateLock;
+ std::list<detail::ZenCacheSessionState*> m_SessionStateCache;
+
+ detail::ZenCacheSessionState* AllocSessionState();
+ void FreeSessionState(detail::ZenCacheSessionState*);
+
+ friend class ZenStructuredCacheSession;
+};
+
+} // namespace zen