diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 10:01:47 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 10:01:47 +0200 |
| commit | 075d17f8ada47e990fe94606c3d21df409223465 (patch) | |
| tree | e50549b766a2f3c354798a54ff73404217b4c9af /src/zenserver/upstream | |
| parent | fix: bundle shouldn't append content zip to zen (diff) | |
| download | zen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz zen-075d17f8ada47e990fe94606c3d21df409223465.zip | |
moved source directories into `/src` (#264)
* moved source directories into `/src`
* updated bundle.lua for new `src` path
* moved some docs, icon
* removed old test trees
Diffstat (limited to 'src/zenserver/upstream')
| -rw-r--r-- | src/zenserver/upstream/hordecompute.cpp | 1457 | ||||
| -rw-r--r-- | src/zenserver/upstream/jupiter.cpp | 965 | ||||
| -rw-r--r-- | src/zenserver/upstream/jupiter.h | 217 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstream.h | 8 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstreamapply.cpp | 459 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstreamapply.h | 192 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstreamcache.cpp | 2112 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstreamcache.h | 252 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstreamservice.cpp | 56 | ||||
| -rw-r--r-- | src/zenserver/upstream/upstreamservice.h | 27 | ||||
| -rw-r--r-- | src/zenserver/upstream/zen.cpp | 326 | ||||
| -rw-r--r-- | src/zenserver/upstream/zen.h | 125 |
12 files changed, 6196 insertions, 0 deletions
diff --git a/src/zenserver/upstream/hordecompute.cpp b/src/zenserver/upstream/hordecompute.cpp new file mode 100644 index 000000000..64d9fff72 --- /dev/null +++ b/src/zenserver/upstream/hordecompute.cpp @@ -0,0 +1,1457 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamapply.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "jupiter.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compactbinaryvalidation.h> +# include <zencore/fmtutils.h> +# include <zencore/session.h> +# include <zencore/stream.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/workthreadpool.h> + +# include <zenstore/cidstore.h> + +# include <auth/authmgr.h> +# include <upstream/upstreamcache.h> + +# include "cache/structuredcachestore.h" +# include "diag/logging.h" + +# include <fmt/format.h> + +# include <algorithm> +# include <atomic> +# include <set> +# include <stack> + +namespace zen { + +using namespace std::literals; + +static const IoBuffer EmptyBuffer; +static const IoHash EmptyBufferId = IoHash::HashBuffer(EmptyBuffer); + +namespace detail { + + class HordeUpstreamApplyEndpoint final : public UpstreamApplyEndpoint + { + public: + HordeUpstreamApplyEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr) + : m_Log(logging::Get("upstream-apply")) + , m_CidStore(CidStore) + , m_AuthMgr(Mgr) + { + m_DisplayName = fmt::format("{} - '{}'+'{}'", ComputeOptions.Name, ComputeOptions.ServiceUrl, StorageOptions.ServiceUrl); + m_ChannelId = fmt::format("zen-{}", zen::GetSessionIdString()); + + { + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (ComputeAuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = ComputeAuthConfig.OAuthUrl, + .ClientId = ComputeAuthConfig.OAuthClientId, + .ClientSecret = ComputeAuthConfig.OAuthClientSecret}); + } + else if (ComputeAuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(ComputeAuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(ComputeAuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(ComputeOptions, std::move(TokenProvider)); + } + + { + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (StorageAuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = StorageAuthConfig.OAuthUrl, + .ClientId = StorageAuthConfig.OAuthClientId, + .ClientSecret = StorageAuthConfig.OAuthClientSecret}); + } + else if (StorageAuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(StorageAuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(StorageAuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_StorageClient = new CloudCacheClient(StorageOptions, std::move(TokenProvider)); + } + } + + virtual ~HordeUpstreamApplyEndpoint() = default; + + virtual UpstreamEndpointHealth Initialize() override { return CheckHealth(); } + + virtual bool IsHealthy() const override { return m_HealthOk.load(); } + + virtual UpstreamEndpointHealth CheckHealth() override + { + try + { + CloudCacheSession Session(m_Client); + CloudCacheResult Result = Session.Authenticate(); + + m_HealthOk = Result.ErrorCode == 0; + + return {.Reason = std::move(Result.Reason), .Ok = Result.Success}; + } + catch (std::exception& Err) + { + return {.Reason = Err.what(), .Ok = false}; + } + } + + virtual std::string_view DisplayName() const override { return m_DisplayName; } + + virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) override + { + PostUpstreamApplyResult ApplyResult{}; + ApplyResult.Timepoints.merge(ApplyRecord.Timepoints); + + try + { + UpstreamData UpstreamData; + if (!ProcessApplyKey(ApplyRecord, UpstreamData)) + { + return {.Error{.ErrorCode = -1, .Reason = "Failed to generate task data"}}; + } + + { + ApplyResult.Timepoints["zen-storage-build-ref"] = DateTime::NowTicks(); + + bool AlreadyQueued; + { + std::scoped_lock Lock(m_TaskMutex); + AlreadyQueued = m_PendingTasks.contains(UpstreamData.TaskId); + } + if (AlreadyQueued) + { + // Pending task is already queued, return success + ApplyResult.Success = true; + return ApplyResult; + } + m_PendingTasks[UpstreamData.TaskId] = std::move(ApplyRecord); + } + + CloudCacheSession ComputeSession(m_Client); + CloudCacheSession StorageSession(m_StorageClient); + + { + CloudCacheResult Result = BatchPutBlobsIfMissing(StorageSession, UpstreamData.Blobs, UpstreamData.CasIds); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-blobs"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload blobs"}; + return ApplyResult; + } + UpstreamData.Blobs.clear(); + UpstreamData.CasIds.clear(); + } + + { + CloudCacheResult Result = BatchPutCompressedBlobsIfMissing(StorageSession, UpstreamData.Cids); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-compressed-blobs"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = { + .ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload compressed blobs"}; + return ApplyResult; + } + UpstreamData.Cids.clear(); + } + + { + CloudCacheResult Result = BatchPutObjectsIfMissing(StorageSession, UpstreamData.Objects); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-objects"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload objects"}; + return ApplyResult; + } + } + + { + PutRefResult RefResult = StorageSession.PutRef(StorageSession.Client().DefaultBlobStoreNamespace(), + "requests"sv, + UpstreamData.TaskId, + UpstreamData.Objects[UpstreamData.TaskId].GetBuffer().AsIoBuffer(), + ZenContentType::kCbObject); + Log().debug("Put ref {} Need={} Bytes={} Duration={}s Result={}", + UpstreamData.TaskId, + RefResult.Needs.size(), + RefResult.Bytes, + RefResult.ElapsedSeconds, + RefResult.Success); + ApplyResult.Bytes += RefResult.Bytes; + ApplyResult.ElapsedSeconds += RefResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-put-ref"] = DateTime::NowTicks(); + + if (RefResult.Needs.size() > 0) + { + Log().error("Failed to add task ref {} due to {} missing blobs", UpstreamData.TaskId, RefResult.Needs.size()); + for (const auto& Hash : RefResult.Needs) + { + Log().debug("Task ref {} missing blob {}", UpstreamData.TaskId, Hash); + } + + ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode, + .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) + : "Failed to add task ref due to missing blob"}; + return ApplyResult; + } + + if (!RefResult.Success) + { + ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode, + .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) : "Failed to add task ref"}; + return ApplyResult; + } + UpstreamData.Objects.clear(); + } + + { + CbObjectWriter Writer; + Writer.AddString("c"sv, m_ChannelId); + Writer.AddObjectAttachment("r"sv, UpstreamData.RequirementsId); + Writer.BeginArray("t"sv); + Writer.AddObjectAttachment(UpstreamData.TaskId); + Writer.EndArray(); + CbObject TasksObject = Writer.Save(); + IoBuffer TasksData = TasksObject.GetBuffer().AsIoBuffer(); + + CloudCacheResult Result = ComputeSession.PostComputeTasks(TasksData); + Log().debug("Post compute task {} Bytes={} Duration={}s Result={}", + TasksObject.GetHash(), + Result.Bytes, + Result.ElapsedSeconds, + Result.Success); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-horde-post-task"] = DateTime::NowTicks(); + if (!Result.Success) + { + { + std::scoped_lock Lock(m_TaskMutex); + m_PendingTasks.erase(UpstreamData.TaskId); + } + + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to post compute task"}; + return ApplyResult; + } + } + + Log().info("Task posted {}", UpstreamData.TaskId); + ApplyResult.Success = true; + return ApplyResult; + } + catch (std::exception& Err) + { + m_HealthOk = false; + return {.Error{.ErrorCode = -1, .Reason = Err.what()}}; + } + } + + [[nodiscard]] CloudCacheResult BatchPutBlobsIfMissing(CloudCacheSession& Session, + const std::map<IoHash, IoBuffer>& Blobs, + const std::set<IoHash>& CasIds) + { + if (Blobs.size() == 0 && CasIds.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing blobs + std::set<IoHash> Keys; + std::transform(Blobs.begin(), Blobs.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; }); + Keys.insert(CasIds.begin(), CasIds.end()); + + CloudCacheExistsResult ExistsResult = Session.BlobExists(Session.Client().DefaultBlobStoreNamespace(), Keys); + Log().debug("Queried {} missing blobs Need={} Duration={}s Result={}", + Keys.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if blobs exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + IoBuffer DataBuffer; + if (Blobs.contains(Hash)) + { + DataBuffer = Blobs.at(Hash); + } + else + { + DataBuffer = m_CidStore.FindChunkByCid(Hash); + if (!DataBuffer) + { + Log().warn("Put blob FAILED, input chunk '{}' missing", Hash); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put blobs"}; + } + } + + CloudCacheResult Result = Session.PutBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer); + Log().debug("Put blob {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put blobs"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + [[nodiscard]] CloudCacheResult BatchPutCompressedBlobsIfMissing(CloudCacheSession& Session, const std::set<IoHash>& Cids) + { + if (Cids.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing compressed blobs + CloudCacheExistsResult ExistsResult = Session.CompressedBlobExists(Session.Client().DefaultBlobStoreNamespace(), Cids); + Log().debug("Queried {} missing compressed blobs Need={} Duration={}s Result={}", + Cids.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return { + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if compressed blobs exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Hash); + if (!DataBuffer) + { + Log().warn("Put compressed blob FAILED, input CID chunk '{}' missing", Hash); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put compressed blobs"}; + } + + CloudCacheResult Result = Session.PutCompressedBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer); + Log().debug("Put compressed blob {} Bytes={} Duration={}s Result={}", + Hash, + Result.Bytes, + Result.ElapsedSeconds, + Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put compressed blobs"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + [[nodiscard]] CloudCacheResult BatchPutObjectsIfMissing(CloudCacheSession& Session, const std::map<IoHash, CbObject>& Objects) + { + if (Objects.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing objects + std::set<IoHash> Keys; + std::transform(Objects.begin(), Objects.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; }); + + CloudCacheExistsResult ExistsResult = Session.ObjectExists(Session.Client().DefaultBlobStoreNamespace(), Keys); + Log().debug("Queried {} missing objects Need={} Duration={}s Result={}", + Keys.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if objects exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + CloudCacheResult Result = + Session.PutObject(Session.Client().DefaultBlobStoreNamespace(), Hash, Objects.at(Hash).GetBuffer().AsIoBuffer()); + Log().debug("Put object {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put objects"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + enum class ComputeTaskState : int32_t + { + Queued = 0, + Executing = 1, + Complete = 2, + }; + + enum class ComputeTaskOutcome : int32_t + { + Success = 0, + Failed = 1, + Cancelled = 2, + NoResult = 3, + Exipred = 4, + BlobNotFound = 5, + Exception = 6, + }; + + [[nodiscard]] static std::string_view ComputeTaskStateToString(const ComputeTaskState Outcome) + { + switch (Outcome) + { + case ComputeTaskState::Queued: + return "Queued"sv; + case ComputeTaskState::Executing: + return "Executing"sv; + case ComputeTaskState::Complete: + return "Complete"sv; + }; + return "Unknown"sv; + } + + [[nodiscard]] static std::string_view ComputeTaskOutcomeToString(const ComputeTaskOutcome Outcome) + { + switch (Outcome) + { + case ComputeTaskOutcome::Success: + return "Success"sv; + case ComputeTaskOutcome::Failed: + return "Failed"sv; + case ComputeTaskOutcome::Cancelled: + return "Cancelled"sv; + case ComputeTaskOutcome::NoResult: + return "NoResult"sv; + case ComputeTaskOutcome::Exipred: + return "Exipred"sv; + case ComputeTaskOutcome::BlobNotFound: + return "BlobNotFound"sv; + case ComputeTaskOutcome::Exception: + return "Exception"sv; + }; + return "Unknown"sv; + } + + virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) override + { + int64_t Bytes{}; + double ElapsedSeconds{}; + + { + std::scoped_lock Lock(m_TaskMutex); + if (m_PendingTasks.empty()) + { + if (m_CompletedTasks.empty()) + { + // Nothing to do. + return {.Success = true}; + } + + UpstreamApplyCompleted CompletedTasks; + std::swap(CompletedTasks, m_CompletedTasks); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true}; + } + } + + try + { + CloudCacheSession ComputeSession(m_Client); + + CloudCacheResult UpdatesResult = ComputeSession.GetComputeUpdates(m_ChannelId); + Log().debug("Get compute updates Bytes={} Duration={}s Result={}", + UpdatesResult.Bytes, + UpdatesResult.ElapsedSeconds, + UpdatesResult.Success); + Bytes += UpdatesResult.Bytes; + ElapsedSeconds += UpdatesResult.ElapsedSeconds; + if (!UpdatesResult.Success) + { + return {.Error{.ErrorCode = UpdatesResult.ErrorCode, .Reason = std::move(UpdatesResult.Reason)}, + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds}; + } + + if (!UpdatesResult.Success) + { + return {.Error{.ErrorCode = -1, .Reason = "Failed get task updates"}, .Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds}; + } + + CbObject TaskStatus = LoadCompactBinaryObject(std::move(UpdatesResult.Response)); + + for (auto& It : TaskStatus["u"sv]) + { + CbObjectView Status = It.AsObjectView(); + IoHash TaskId = Status["h"sv].AsHash(); + const ComputeTaskState State = (ComputeTaskState)Status["s"sv].AsInt32(); + const ComputeTaskOutcome Outcome = (ComputeTaskOutcome)Status["o"sv].AsInt32(); + + Log().info("Task {} State={}", TaskId, ComputeTaskStateToString(State)); + + // Only completed tasks need to be processed + if (State != ComputeTaskState::Complete) + { + continue; + } + + IoHash WorkerId{}; + IoHash ActionId{}; + UpstreamApplyType ApplyType{}; + + { + std::scoped_lock Lock(m_TaskMutex); + auto TaskIt = m_PendingTasks.find(TaskId); + if (TaskIt != m_PendingTasks.end()) + { + WorkerId = TaskIt->second.WorkerDescriptor.GetHash(); + ActionId = TaskIt->second.Action.GetHash(); + ApplyType = TaskIt->second.Type; + m_PendingTasks.erase(TaskIt); + } + } + + if (WorkerId == IoHash::Zero) + { + Log().warn("Task {} missing from pending tasks", TaskId); + continue; + } + + std::map<std::string, uint64_t> Timepoints; + ProcessQueueTimings(Status["qs"sv].AsObjectView(), Timepoints); + ProcessExecuteTimings(Status["es"sv].AsObjectView(), Timepoints); + + if (Outcome != ComputeTaskOutcome::Success) + { + const std::string_view Detail = Status["d"sv].AsString(); + { + std::scoped_lock Lock(m_TaskMutex); + m_CompletedTasks[WorkerId][ActionId] = { + .Error{.ErrorCode = -1, .Reason = fmt::format("Task {} {}", ComputeTaskOutcomeToString(Outcome), Detail)}, + .Timepoints = std::move(Timepoints)}; + } + continue; + } + + Timepoints["zen-complete-queue-added"] = DateTime::NowTicks(); + ThreadPool.ScheduleWork([this, + ApplyType, + ResultHash = Status["r"sv].AsHash(), + Timepoints = std::move(Timepoints), + TaskId = std::move(TaskId), + WorkerId = std::move(WorkerId), + ActionId = std::move(ActionId)]() mutable { + Timepoints["zen-complete-queue-dispatched"] = DateTime::NowTicks(); + GetUpstreamApplyResult Result = ProcessTaskStatus(ApplyType, ResultHash); + Timepoints["zen-complete-queue-complete"] = DateTime::NowTicks(); + Result.Timepoints.merge(Timepoints); + + Log().debug("Task Processed {} Files={} Attachments={} ExitCode={}", + TaskId, + Result.OutputFiles.size(), + Result.OutputPackage.GetAttachments().size(), + Result.Error.ErrorCode); + { + std::scoped_lock Lock(m_TaskMutex); + m_CompletedTasks[WorkerId][ActionId] = std::move(Result); + } + }); + } + + { + std::scoped_lock Lock(m_TaskMutex); + if (m_CompletedTasks.empty()) + { + // Nothing to do. + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + UpstreamApplyCompleted CompletedTasks; + std::swap(CompletedTasks, m_CompletedTasks); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true}; + } + } + catch (std::exception& Err) + { + m_HealthOk = false; + return { + .Error{.ErrorCode = -1, .Reason = Err.what()}, + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + }; + } + } + + virtual UpstreamApplyEndpointStats& Stats() override { return m_Stats; } + + private: + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + CidStore& m_CidStore; + AuthMgr& m_AuthMgr; + std::string m_DisplayName; + RefPtr<CloudCacheClient> m_Client; + RefPtr<CloudCacheClient> m_StorageClient; + UpstreamApplyEndpointStats m_Stats; + std::atomic_bool m_HealthOk{false}; + std::string m_ChannelId; + + std::mutex m_TaskMutex; + std::unordered_map<IoHash, UpstreamApplyRecord> m_PendingTasks; + UpstreamApplyCompleted m_CompletedTasks; + + struct UpstreamData + { + std::map<IoHash, IoBuffer> Blobs; + std::map<IoHash, CbObject> Objects; + std::set<IoHash> CasIds; + std::set<IoHash> Cids; + IoHash TaskId; + IoHash RequirementsId; + }; + + struct UpstreamDirectory + { + std::filesystem::path Path; + std::map<std::string, UpstreamDirectory> Directories; + std::set<std::string> Files; + }; + + static void ProcessQueueTimings(CbObjectView QueueStats, std::map<std::string, uint64_t>& Timepoints) + { + uint64_t Ticks = QueueStats["t"sv].AsDateTimeTicks(); + if (Ticks == 0) + { + return; + } + + // Scope is an array of miliseconds after start time + // TODO: cleanup + Timepoints["horde-queue-added"] = Ticks; + int Index = 0; + for (auto& Item : QueueStats["s"sv].AsArrayView()) + { + Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond; + switch (Index) + { + case 0: + Timepoints["horde-queue-dispatched"] = Ticks; + break; + case 1: + Timepoints["horde-queue-complete"] = Ticks; + break; + } + Index++; + } + } + + static void ProcessExecuteTimings(CbObjectView ExecutionStats, std::map<std::string, uint64_t>& Timepoints) + { + uint64_t Ticks = ExecutionStats["t"sv].AsDateTimeTicks(); + if (Ticks == 0) + { + return; + } + + // Scope is an array of miliseconds after start time + // TODO: cleanup + Timepoints["horde-execution-start"] = Ticks; + int Index = 0; + for (auto& Item : ExecutionStats["s"sv].AsArrayView()) + { + Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond; + switch (Index) + { + case 0: + Timepoints["horde-execution-download-ref"] = Ticks; + break; + case 1: + Timepoints["horde-execution-download-input"] = Ticks; + break; + case 2: + Timepoints["horde-execution-execute"] = Ticks; + break; + case 3: + Timepoints["horde-execution-upload-log"] = Ticks; + break; + case 4: + Timepoints["horde-execution-upload-output"] = Ticks; + break; + case 5: + Timepoints["horde-execution-upload-ref"] = Ticks; + break; + } + Index++; + } + } + + [[nodiscard]] GetUpstreamApplyResult ProcessTaskStatus(const UpstreamApplyType ApplyType, const IoHash& ResultHash) + { + try + { + CloudCacheSession Session(m_StorageClient); + + GetUpstreamApplyResult ApplyResult{}; + + IoHash StdOutHash; + IoHash StdErrHash; + IoHash OutputHash; + + std::map<IoHash, IoBuffer> BinaryData; + + { + CloudCacheResult ObjectRefResult = + Session.GetRef(Session.Client().DefaultBlobStoreNamespace(), "responses"sv, ResultHash, ZenContentType::kCbObject); + Log().debug("Get ref {} Bytes={} Duration={}s Result={}", + ResultHash, + ObjectRefResult.Bytes, + ObjectRefResult.ElapsedSeconds, + ObjectRefResult.Success); + ApplyResult.Bytes += ObjectRefResult.Bytes; + ApplyResult.ElapsedSeconds += ObjectRefResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-get-ref"] = DateTime::NowTicks(); + + if (!ObjectRefResult.Success) + { + ApplyResult.Error.Reason = "Failed to get result object data"; + return ApplyResult; + } + + CbObject ResultObject = LoadCompactBinaryObject(ObjectRefResult.Response); + ApplyResult.Error.ErrorCode = ResultObject["e"sv].AsInt32(); + StdOutHash = ResultObject["so"sv].AsBinaryAttachment(); + StdErrHash = ResultObject["se"sv].AsBinaryAttachment(); + OutputHash = ResultObject["o"sv].AsObjectAttachment(); + } + + { + std::set<IoHash> NeededData; + if (OutputHash != IoHash::Zero) + { + GetObjectReferencesResult ObjectReferenceResult = + Session.GetObjectReferences(Session.Client().DefaultBlobStoreNamespace(), OutputHash); + Log().debug("Get object references {} References={} Bytes={} Duration={}s Result={}", + ResultHash, + ObjectReferenceResult.References.size(), + ObjectReferenceResult.Bytes, + ObjectReferenceResult.ElapsedSeconds, + ObjectReferenceResult.Success); + ApplyResult.Bytes += ObjectReferenceResult.Bytes; + ApplyResult.ElapsedSeconds += ObjectReferenceResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-get-object-references"] = DateTime::NowTicks(); + + if (!ObjectReferenceResult.Success) + { + ApplyResult.Error.Reason = "Failed to get result object references"; + return ApplyResult; + } + + NeededData = std::move(ObjectReferenceResult.References); + } + + NeededData.insert(OutputHash); + NeededData.insert(StdOutHash); + NeededData.insert(StdErrHash); + + for (const auto& Hash : NeededData) + { + if (Hash == IoHash::Zero) + { + continue; + } + CloudCacheResult BlobResult = Session.GetBlob(Session.Client().DefaultBlobStoreNamespace(), Hash); + Log().debug("Get blob {} Bytes={} Duration={}s Result={}", + Hash, + BlobResult.Bytes, + BlobResult.ElapsedSeconds, + BlobResult.Success); + ApplyResult.Bytes += BlobResult.Bytes; + ApplyResult.ElapsedSeconds += BlobResult.ElapsedSeconds; + if (!BlobResult.Success) + { + ApplyResult.Error.Reason = "Failed to get blob"; + return ApplyResult; + } + BinaryData[Hash] = std::move(BlobResult.Response); + } + ApplyResult.Timepoints["zen-storage-get-blobs"] = DateTime::NowTicks(); + } + + ApplyResult.StdOut = StdOutHash != IoHash::Zero + ? std::string((const char*)BinaryData[StdOutHash].GetData(), BinaryData[StdOutHash].GetSize()) + : ""; + ApplyResult.StdErr = StdErrHash != IoHash::Zero + ? std::string((const char*)BinaryData[StdErrHash].GetData(), BinaryData[StdErrHash].GetSize()) + : ""; + + if (OutputHash == IoHash::Zero) + { + ApplyResult.Error.Reason = "Task completed with no output object"; + return ApplyResult; + } + + CbObject OutputObject = LoadCompactBinaryObject(BinaryData[OutputHash]); + + switch (ApplyType) + { + case UpstreamApplyType::Simple: + { + ResolveMerkleTreeDirectory(""sv, OutputHash, BinaryData, ApplyResult.OutputFiles); + for (const auto& Pair : BinaryData) + { + ApplyResult.FileData[Pair.first] = std::move(BinaryData.at(Pair.first)); + } + + ApplyResult.Success = ApplyResult.Error.ErrorCode == 0; + return ApplyResult; + } + break; + case UpstreamApplyType::Asset: + { + if (ApplyResult.Error.ErrorCode != 0) + { + ApplyResult.Error.Reason = "Task completed with errors"; + return ApplyResult; + } + + // Get build.output + IoHash BuildOutputId; + IoBuffer BuildOutput; + for (auto& It : OutputObject["f"sv]) + { + const CbObjectView FileObject = It.AsObjectView(); + if (FileObject["n"sv].AsString() == "Build.output"sv) + { + BuildOutputId = FileObject["h"sv].AsBinaryAttachment(); + BuildOutput = BinaryData[BuildOutputId]; + break; + } + } + + if (BuildOutput.GetSize() == 0) + { + ApplyResult.Error.Reason = "Build.output file not found in task results"; + return ApplyResult; + } + + // Get Output directory node + IoBuffer OutputDirectoryTree; + for (auto& It : OutputObject["d"sv]) + { + const CbObjectView DirectoryObject = It.AsObjectView(); + if (DirectoryObject["n"sv].AsString() == "Outputs"sv) + { + OutputDirectoryTree = BinaryData[DirectoryObject["h"sv].AsObjectAttachment()]; + break; + } + } + + if (OutputDirectoryTree.GetSize() == 0) + { + ApplyResult.Error.Reason = "Outputs directory not found in task results"; + return ApplyResult; + } + + // load build.output as CbObject + + // Move Outputs from Horde to CbPackage + + std::unordered_map<IoHash, IoHash> CidToCompressedId; + CbPackage OutputPackage; + CbObject OutputDirectoryTreeObject = LoadCompactBinaryObject(OutputDirectoryTree); + + for (auto& It : OutputDirectoryTreeObject["f"sv]) + { + CbObjectView FileObject = It.AsObjectView(); + // Name is the uncompressed hash + IoHash DecompressedId = IoHash::FromHexString(FileObject["n"sv].AsString()); + // Hash is the compressed data hash, and how it is stored in Horde + IoHash CompressedId = FileObject["h"sv].AsBinaryAttachment(); + + if (!BinaryData.contains(CompressedId)) + { + Log().warn("Object attachment chunk not retrieved from Horde {}", CompressedId); + ApplyResult.Error.Reason = "Object attachment chunk not retrieved from Horde"; + return ApplyResult; + } + CidToCompressedId[DecompressedId] = CompressedId; + } + + // Iterate attachments, verify all chunks exist, and add to CbPackage + bool AnyErrors = false; + CbObject BuildOutputObject = LoadCompactBinaryObject(BuildOutput); + BuildOutputObject.IterateAttachments([&](CbFieldView Field) { + const IoHash DecompressedId = Field.AsHash(); + if (!CidToCompressedId.contains(DecompressedId)) + { + Log().warn("Attachment not found {}", DecompressedId); + AnyErrors = true; + return; + } + const IoHash& CompressedId = CidToCompressedId.at(DecompressedId); + + if (!BinaryData.contains(CompressedId)) + { + Log().warn("Missing output {} compressed {} uncompressed", CompressedId, DecompressedId); + AnyErrors = true; + return; + } + + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(BinaryData[CompressedId]), RawHash, RawSize); + + if (!AttachmentBuffer || RawHash != DecompressedId) + { + Log().warn( + "Invalid output encountered (not valid CompressedBuffer format) {} compressed {} uncompressed", + CompressedId, + DecompressedId); + AnyErrors = true; + return; + } + + ApplyResult.TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + ApplyResult.TotalRawAttachmentBytes += RawSize; + + CbAttachment Attachment(AttachmentBuffer, DecompressedId); + OutputPackage.AddAttachment(Attachment); + }); + + if (AnyErrors) + { + ApplyResult.Error.Reason = "Failed to get result object attachment data"; + return ApplyResult; + } + + OutputPackage.SetObject(BuildOutputObject); + ApplyResult.OutputPackage = std::move(OutputPackage); + + ApplyResult.Success = ApplyResult.Error.ErrorCode == 0; + return ApplyResult; + } + break; + } + + ApplyResult.Error.Reason = "Unknown apply type"; + return ApplyResult; + } + catch (std::exception& Err) + { + return {.Error{.ErrorCode = -1, .Reason = Err.what()}}; + } + } + + [[nodiscard]] bool ProcessApplyKey(const UpstreamApplyRecord& ApplyRecord, UpstreamData& Data) + { + std::string ExecutablePath; + std::string WorkingDirectory; + std::vector<std::string> Arguments; + std::map<std::string, std::string> Environment; + std::set<std::filesystem::path> InputFiles; + std::set<std::string> Outputs; + std::map<std::filesystem::path, IoHash> InputFileHashes; + + ExecutablePath = ApplyRecord.WorkerDescriptor["path"sv].AsString(); + if (ExecutablePath.empty()) + { + Log().warn("process apply upstream FAILED, '{}', path missing from worker descriptor", + ApplyRecord.WorkerDescriptor.GetHash()); + return false; + } + + WorkingDirectory = ApplyRecord.WorkerDescriptor["workdir"sv].AsString(); + + for (auto& It : ApplyRecord.WorkerDescriptor["executables"sv]) + { + CbObjectView FileEntry = It.AsObjectView(); + if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds)) + { + return false; + } + } + + for (auto& It : ApplyRecord.WorkerDescriptor["files"sv]) + { + CbObjectView FileEntry = It.AsObjectView(); + if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds)) + { + return false; + } + } + + for (auto& It : ApplyRecord.WorkerDescriptor["dirs"sv]) + { + std::string_view Directory = It.AsString(); + std::string DummyFile = fmt::format("{}/.zen_empty_file", Directory); + InputFiles.insert(DummyFile); + Data.Blobs[EmptyBufferId] = EmptyBuffer; + InputFileHashes[DummyFile] = EmptyBufferId; + } + + if (!WorkingDirectory.empty()) + { + std::string DummyFile = fmt::format("{}/.zen_empty_file", WorkingDirectory); + InputFiles.insert(DummyFile); + Data.Blobs[EmptyBufferId] = EmptyBuffer; + InputFileHashes[DummyFile] = EmptyBufferId; + } + + for (auto& It : ApplyRecord.WorkerDescriptor["environment"sv]) + { + std::string_view Env = It.AsString(); + auto Index = Env.find('='); + if (Index == std::string_view::npos) + { + Log().warn("process apply upstream FAILED, environment '{}' malformed", Env); + return false; + } + + Environment[std::string(Env.substr(0, Index))] = Env.substr(Index + 1); + } + + switch (ApplyRecord.Type) + { + case UpstreamApplyType::Simple: + { + for (auto& It : ApplyRecord.WorkerDescriptor["arguments"sv]) + { + Arguments.push_back(std::string(It.AsString())); + } + + for (auto& It : ApplyRecord.WorkerDescriptor["outputs"sv]) + { + Outputs.insert(std::string(It.AsString())); + } + } + break; + case UpstreamApplyType::Asset: + { + static const std::filesystem::path BuildActionPath = "Build.action"sv; + static const std::filesystem::path InputPath = "Inputs"sv; + const IoHash ActionId = ApplyRecord.Action.GetHash(); + + Arguments.push_back("-Build=build.action"); + Outputs.insert("Build.output"); + Outputs.insert("Outputs"); + + InputFiles.insert(BuildActionPath); + InputFileHashes[BuildActionPath] = ActionId; + Data.Blobs[ActionId] = IoBufferBuilder::MakeCloneFromMemory(ApplyRecord.Action.GetBuffer().GetData(), + ApplyRecord.Action.GetBuffer().GetSize()); + + bool AnyErrors = false; + ApplyRecord.Action.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + const std::filesystem::path FilePath = {InputPath / Cid.ToHexString()}; + + if (!m_CidStore.ContainsChunk(Cid)) + { + Log().warn("process apply upstream FAILED, input CID chunk '{}' missing", Cid); + AnyErrors = true; + return; + } + + if (InputFiles.contains(FilePath)) + { + return; + } + + InputFiles.insert(FilePath); + InputFileHashes[FilePath] = Cid; + Data.Cids.insert(Cid); + }); + + if (AnyErrors) + { + return false; + } + } + break; + } + + const UpstreamDirectory RootDirectory = BuildDirectoryTree(InputFiles); + + CbObject Sandbox = BuildMerkleTreeDirectory(RootDirectory, InputFileHashes, Data.Cids, Data.Objects); + const IoHash SandboxHash = Sandbox.GetHash(); + Data.Objects[SandboxHash] = std::move(Sandbox); + + { + std::string_view HostPlatform = ApplyRecord.WorkerDescriptor["host"sv].AsString(); + if (HostPlatform.empty()) + { + Log().warn("process apply upstream FAILED, 'host' platform not provided"); + return false; + } + + int32_t LogicalCores = ApplyRecord.WorkerDescriptor["cores"sv].AsInt32(); + int64_t Memory = ApplyRecord.WorkerDescriptor["memory"sv].AsInt64(); + bool Exclusive = ApplyRecord.WorkerDescriptor["exclusive"sv].AsBool(); + + std::string Condition = fmt::format("Platform == '{}'", HostPlatform); + if (HostPlatform == "Win64") + { + // TODO + // Condition += " && Pool == 'Win-RemoteExec'"; + } + + std::map<std::string_view, int64_t> Resources; + if (LogicalCores > 0) + { + Resources["LogicalCores"sv] = LogicalCores; + } + if (Memory > 0) + { + Resources["RAM"sv] = std::max(Memory / 1024LL / 1024LL / 1024LL, 1LL); + } + + CbObject Requirements = BuildRequirements(Condition, Resources, Exclusive); + const IoHash RequirementsId = Requirements.GetHash(); + Data.Objects[RequirementsId] = std::move(Requirements); + Data.RequirementsId = RequirementsId; + } + + CbObject Task = BuildTask(ExecutablePath, Arguments, Environment, WorkingDirectory, SandboxHash, Data.RequirementsId, Outputs); + + const IoHash TaskId = Task.GetHash(); + Data.Objects[TaskId] = std::move(Task); + Data.TaskId = TaskId; + + return true; + } + + [[nodiscard]] bool ProcessFileEntry(const CbObjectView& FileEntry, + std::set<std::filesystem::path>& InputFiles, + std::map<std::filesystem::path, IoHash>& InputFileHashes, + std::set<IoHash>& CasIds) + { + const std::filesystem::path FilePath = FileEntry["name"sv].AsString(); + const IoHash ChunkId = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); + + if (!m_CidStore.ContainsChunk(ChunkId)) + { + Log().warn("process apply upstream FAILED, worker CAS chunk '{}' missing", ChunkId); + return false; + } + + if (InputFiles.contains(FilePath)) + { + Log().warn("process apply upstream FAILED, worker CAS chunk '{}' size: {} duplicate filename {}", ChunkId, Size, FilePath); + return false; + } + + InputFiles.insert(FilePath); + InputFileHashes[FilePath] = ChunkId; + CasIds.insert(ChunkId); + return true; + } + + [[nodiscard]] UpstreamDirectory BuildDirectoryTree(const std::set<std::filesystem::path>& InputFiles) + { + static const std::filesystem::path RootPath; + std::map<std::filesystem::path, UpstreamDirectory*> AllDirectories; + UpstreamDirectory RootDirectory = {.Path = RootPath}; + + AllDirectories[RootPath] = &RootDirectory; + + // Build tree from flat list + for (const auto& Path : InputFiles) + { + if (Path.has_parent_path()) + { + if (!AllDirectories.contains(Path.parent_path())) + { + std::stack<std::string> PathSplit; + { + std::filesystem::path ParentPath = Path.parent_path(); + PathSplit.push(ParentPath.filename().string()); + while (ParentPath.has_parent_path()) + { + ParentPath = ParentPath.parent_path(); + PathSplit.push(ParentPath.filename().string()); + } + } + UpstreamDirectory* ParentPtr = &RootDirectory; + while (!PathSplit.empty()) + { + if (!ParentPtr->Directories.contains(PathSplit.top())) + { + std::filesystem::path NewParentPath = {ParentPtr->Path / PathSplit.top()}; + ParentPtr->Directories[PathSplit.top()] = {.Path = NewParentPath}; + AllDirectories[NewParentPath] = &ParentPtr->Directories[PathSplit.top()]; + } + ParentPtr = &ParentPtr->Directories[PathSplit.top()]; + PathSplit.pop(); + } + } + + AllDirectories[Path.parent_path()]->Files.insert(Path.filename().string()); + } + else + { + RootDirectory.Files.insert(Path.filename().string()); + } + } + + return RootDirectory; + } + + [[nodiscard]] CbObject BuildMerkleTreeDirectory(const UpstreamDirectory& RootDirectory, + const std::map<std::filesystem::path, IoHash>& InputFileHashes, + const std::set<IoHash>& Cids, + std::map<IoHash, CbObject>& Objects) + { + CbObjectWriter DirectoryTreeWriter; + + if (!RootDirectory.Files.empty()) + { + DirectoryTreeWriter.BeginArray("f"sv); + for (const auto& File : RootDirectory.Files) + { + const std::filesystem::path FilePath = {RootDirectory.Path / File}; + const IoHash& FileHash = InputFileHashes.at(FilePath); + const bool Compressed = Cids.contains(FileHash); + DirectoryTreeWriter.BeginObject(); + DirectoryTreeWriter.AddString("n"sv, File); + DirectoryTreeWriter.AddBinaryAttachment("h"sv, FileHash); + DirectoryTreeWriter.AddBool("c"sv, Compressed); + DirectoryTreeWriter.EndObject(); + } + DirectoryTreeWriter.EndArray(); + } + + if (!RootDirectory.Directories.empty()) + { + DirectoryTreeWriter.BeginArray("d"sv); + for (const auto& Item : RootDirectory.Directories) + { + CbObject Directory = BuildMerkleTreeDirectory(Item.second, InputFileHashes, Cids, Objects); + const IoHash DirectoryHash = Directory.GetHash(); + Objects[DirectoryHash] = std::move(Directory); + + DirectoryTreeWriter.BeginObject(); + DirectoryTreeWriter.AddString("n"sv, Item.first); + DirectoryTreeWriter.AddObjectAttachment("h"sv, DirectoryHash); + DirectoryTreeWriter.EndObject(); + } + DirectoryTreeWriter.EndArray(); + } + + return DirectoryTreeWriter.Save(); + } + + void ResolveMerkleTreeDirectory(const std::filesystem::path& ParentDirectory, + const IoHash& DirectoryHash, + const std::map<IoHash, IoBuffer>& Objects, + std::map<std::filesystem::path, IoHash>& OutputFiles) + { + CbObject Directory = LoadCompactBinaryObject(Objects.at(DirectoryHash)); + + for (auto& It : Directory["f"sv]) + { + const CbObjectView FileObject = It.AsObjectView(); + const std::filesystem::path Path = ParentDirectory / FileObject["n"sv].AsString(); + + OutputFiles[Path] = FileObject["h"sv].AsBinaryAttachment(); + } + + for (auto& It : Directory["d"sv]) + { + const CbObjectView DirectoryObject = It.AsObjectView(); + + ResolveMerkleTreeDirectory(ParentDirectory / DirectoryObject["n"sv].AsString(), + DirectoryObject["h"sv].AsObjectAttachment(), + Objects, + OutputFiles); + } + } + + [[nodiscard]] CbObject BuildRequirements(const std::string_view Condition, + const std::map<std::string_view, int64_t>& Resources, + const bool Exclusive) + { + CbObjectWriter Writer; + Writer.AddString("c", Condition); + if (!Resources.empty()) + { + Writer.BeginArray("r"); + for (const auto& Resource : Resources) + { + Writer.BeginArray(); + Writer.AddString(Resource.first); + Writer.AddInteger(Resource.second); + Writer.EndArray(); + } + Writer.EndArray(); + } + Writer.AddBool("e", Exclusive); + return Writer.Save(); + } + + [[nodiscard]] CbObject BuildTask(const std::string_view Executable, + const std::vector<std::string>& Arguments, + const std::map<std::string, std::string>& Environment, + const std::string_view WorkingDirectory, + const IoHash& SandboxHash, + const IoHash& RequirementsId, + const std::set<std::string>& Outputs) + { + CbObjectWriter TaskWriter; + TaskWriter.AddString("e"sv, Executable); + + if (!Arguments.empty()) + { + TaskWriter.BeginArray("a"sv); + for (const auto& Argument : Arguments) + { + TaskWriter.AddString(Argument); + } + TaskWriter.EndArray(); + } + + if (!Environment.empty()) + { + TaskWriter.BeginArray("v"sv); + for (const auto& Env : Environment) + { + TaskWriter.BeginArray(); + TaskWriter.AddString(Env.first); + TaskWriter.AddString(Env.second); + TaskWriter.EndArray(); + } + TaskWriter.EndArray(); + } + + if (!WorkingDirectory.empty()) + { + TaskWriter.AddString("w"sv, WorkingDirectory); + } + + TaskWriter.AddObjectAttachment("s"sv, SandboxHash); + TaskWriter.AddObjectAttachment("r"sv, RequirementsId); + + // Outputs + if (!Outputs.empty()) + { + TaskWriter.BeginArray("o"sv); + for (const auto& Output : Outputs) + { + TaskWriter.AddString(Output); + } + TaskWriter.EndArray(); + } + + return TaskWriter.Save(); + } + }; +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<UpstreamApplyEndpoint> +UpstreamApplyEndpoint::CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr) +{ + return std::make_unique<detail::HordeUpstreamApplyEndpoint>(ComputeOptions, + ComputeAuthConfig, + StorageOptions, + StorageAuthConfig, + CidStore, + Mgr); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/jupiter.cpp b/src/zenserver/upstream/jupiter.cpp new file mode 100644 index 000000000..dbb185bec --- /dev/null +++ b/src/zenserver/upstream/jupiter.cpp @@ -0,0 +1,965 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "jupiter.h" + +#include "diag/formatters.h" +#include "diag/logging.h" + +#include <zencore/compactbinary.h> +#include <zencore/compositebuffer.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "Crypt32.lib") +# pragma comment(lib, "Wldap32.lib") +#endif + +#include <json11.hpp> + +using namespace std::literals; + +namespace zen { + +namespace detail { + struct CloudCacheSessionState + { + CloudCacheSessionState(CloudCacheClient& Client) : m_Client(Client) {} + + const CloudCacheAccessToken& GetAccessToken(bool RefreshToken) + { + if (RefreshToken) + { + m_AccessToken = m_Client.AcquireAccessToken(); + } + + return m_AccessToken; + } + + cpr::Session& GetSession() { return m_Session; } + + void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout) + { + m_Session.SetBody({}); + m_Session.SetHeader({}); + m_Session.SetConnectTimeout(ConnectTimeout); + m_Session.SetTimeout(Timeout); + } + + private: + friend class zen::CloudCacheClient; + + CloudCacheClient& m_Client; + CloudCacheAccessToken m_AccessToken; + cpr::Session m_Session; + }; + +} // namespace detail + +CloudCacheSession::CloudCacheSession(CloudCacheClient* CacheClient) : m_Log(CacheClient->Logger()), m_CacheClient(CacheClient) +{ + m_SessionState = m_CacheClient->AllocSessionState(); +} + +CloudCacheSession::~CloudCacheSession() +{ + m_CacheClient->FreeSessionState(m_SessionState); +} + +CloudCacheResult +CloudCacheSession::Authenticate() +{ + const bool RefreshToken = true; + const CloudCacheAccessToken& AccessToken = GetAccessToken(RefreshToken); + + return {.Success = AccessToken.IsValid()}; +} + +CloudCacheResult +CloudCacheSession::GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType) +{ + const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", ContentType}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetBlob(std::string_view Namespace, const IoHash& Key) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/octet-stream"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = + Success && Response.text.size() > 0 ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetCompressedBlob(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-comp"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash) +{ + ZEN_TRACE_CPU("HordeClient::GetInlineBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-jupiter-inline"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + if (auto It = Response.header.find("X-Jupiter-InlinePayloadHash"); It != Response.header.end()) + { + const std::string& PayloadHashHeader = It->second; + if (PayloadHashHeader.length() == IoHash::StringLength) + { + OutPayloadHash = IoHash::FromHexString(PayloadHashHeader); + } + } + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetObject(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetObject"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +PutRefResult +CloudCacheSession::PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType) +{ + ZEN_TRACE_CPU("HordeClient::PutRef"); + + IoHash Hash = IoHash::HashBuffer(Ref.Data(), Ref.Size()); + + const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption( + cpr::Header{{"Authorization", AccessToken.Value}, {"X-Jupiter-IoHash", Hash.ToHexString()}, {"Content-Type", ContentType}}); + Session.SetBody(cpr::Body{(const char*)Ref.Data(), Ref.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + PutRefResult Result; + Result.ErrorCode = static_cast<int32_t>(Response.error.code); + Result.Reason = std::move(Response.error.message); + return Result; + } + else if (!VerifyAccessToken(Response.status_code)) + { + PutRefResult Result; + Result.ErrorCode = 401; + Result.Reason = "Invalid access token"sv; + return Result; + } + + PutRefResult Result; + Result.Success = (Response.status_code == 200 || Response.status_code == 201); + Result.Bytes = Response.uploaded_bytes; + Result.ElapsedSeconds = Response.elapsed; + + if (Result.Success) + { + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + if (JsonError.empty()) + { + json11::Json::array Needs = Json["needs"].array_items(); + for (const auto& Need : Needs) + { + Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value())); + } + } + } + + return Result; +} + +FinalizeRefResult +CloudCacheSession::FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHash) +{ + ZEN_TRACE_CPU("HordeClient::FinalizeRef"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString() << "/finalize/" + << RefHash.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, + {"X-Jupiter-IoHash", RefHash.ToHexString()}, + {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + FinalizeRefResult Result; + Result.ErrorCode = static_cast<int32_t>(Response.error.code); + Result.Reason = std::move(Response.error.message); + return Result; + } + else if (!VerifyAccessToken(Response.status_code)) + { + FinalizeRefResult Result; + Result.ErrorCode = 401; + Result.Reason = "Invalid access token"sv; + return Result; + } + + FinalizeRefResult Result; + Result.Success = (Response.status_code == 200 || Response.status_code == 201); + Result.Bytes = Response.uploaded_bytes; + Result.ElapsedSeconds = Response.elapsed; + + if (Result.Success) + { + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + if (JsonError.empty()) + { + json11::Json::array Needs = Json["needs"].array_items(); + for (const auto& Need : Needs) + { + Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value())); + } + } + } + + return Result; +} + +CloudCacheResult +CloudCacheSession::PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob) +{ + ZEN_TRACE_CPU("HordeClient::PutBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/octet-stream"}}); + Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob) +{ + ZEN_TRACE_CPU("HordeClient::PutCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}}); + Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Payload) +{ + ZEN_TRACE_CPU("HordeClient::PutCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}}); + uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object) +{ + ZEN_TRACE_CPU("HordeClient::PutObject"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{(const char*)Object.Data(), Object.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::RefExists"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Head(); + ZEN_DEBUG("HEAD {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +GetObjectReferencesResult +CloudCacheSession::GetObjectReferences(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetObjectReferences"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString() << "/references"; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}}; + } + + GetObjectReferencesResult Result{ + CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}}; + + if (Result.Success) + { + IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + const CbObject ReferencesResponse = LoadCompactBinaryObject(Buffer); + for (auto& Item : ReferencesResponse["references"sv]) + { + Result.References.insert(Item.AsHash()); + } + } + + return Result; +} + +CloudCacheResult +CloudCacheSession::BlobExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "blobs"sv, Key); +} + +CloudCacheResult +CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "compressed-blobs"sv, Key); +} + +CloudCacheResult +CloudCacheSession::ObjectExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "objects"sv, Key); +} + +CloudCacheExistsResult +CloudCacheSession::BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "blobs"sv, Keys); +} + +CloudCacheExistsResult +CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "compressed-blobs"sv, Keys); +} + +CloudCacheExistsResult +CloudCacheSession::ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "objects"sv, Keys); +} + +CloudCacheResult +CloudCacheSession::PostComputeTasks(IoBuffer TasksData) +{ + ZEN_TRACE_CPU("HordeClient::PostComputeTasks"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{(const char*)TasksData.Data(), TasksData.Size()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +CloudCacheResult +CloudCacheSession::GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds) +{ + ZEN_TRACE_CPU("HordeClient::GetComputeUpdates"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster() << "/updates/" << ChannelId + << "?wait=" << WaitSeconds; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +std::vector<IoHash> +CloudCacheSession::Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl(); + Uri << "/api/v1/s/" << Namespace; + + ZEN_UNUSED(BucketId, ChunkHashes); + + return {}; +} + +cpr::Session& +CloudCacheSession::GetSession() +{ + return m_SessionState->GetSession(); +} + +CloudCacheAccessToken +CloudCacheSession::GetAccessToken(bool RefreshToken) +{ + return m_SessionState->GetAccessToken(RefreshToken); +} + +bool +CloudCacheSession::VerifyAccessToken(long StatusCode) +{ + return StatusCode != 401; +} + +CloudCacheResult +CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::CacheTypeExists"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Head(); + ZEN_DEBUG("HEAD {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +CloudCacheExistsResult +CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys) +{ + ZEN_TRACE_CPU("HordeClient::CacheTypeExists"); + + ExtendableStringBuilder<256> Body; + Body << "["; + for (const auto& Key : Keys) + { + Body << (Body.Size() != 1 ? ",\"" : "\"") << Key.ToHexString() << "\""; + } + Body << "]"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/exist"; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption( + cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}, {"Content-Type", "application/json"}}); + Session.SetOption(cpr::Body(Body.ToString())); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}}; + } + + CloudCacheExistsResult Result{ + CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}}; + + if (Result.Success) + { + IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + const CbObject ExistsResponse = LoadCompactBinaryObject(Buffer); + for (auto& Item : ExistsResponse["needs"sv]) + { + Result.Needs.insert(Item.AsHash()); + } + } + + return Result; +} + +/** + * An access token provider that holds a token that will never change. + */ +class StaticTokenProvider final : public CloudCacheTokenProvider +{ +public: + StaticTokenProvider(CloudCacheAccessToken Token) : m_Token(std::move(Token)) {} + + virtual ~StaticTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Token; } + +private: + CloudCacheAccessToken m_Token; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromStaticToken(CloudCacheAccessToken Token) +{ + return std::make_unique<StaticTokenProvider>(std::move(Token)); +} + +class OAuthClientCredentialsTokenProvider final : public CloudCacheTokenProvider +{ +public: + OAuthClientCredentialsTokenProvider(const CloudCacheTokenProvider::OAuthClientCredentialsParams& Params) + { + m_Url = std::string(Params.Url); + m_ClientId = std::string(Params.ClientId); + m_ClientSecret = std::string(Params.ClientSecret); + } + + virtual ~OAuthClientCredentialsTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override + { + using namespace std::chrono; + + std::string Body = + fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}", m_ClientId, m_ClientSecret); + + cpr::Response Response = + cpr::Post(cpr::Url{m_Url}, cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}, cpr::Body{std::move(Body)}); + + if (Response.error || Response.status_code != 200) + { + return {}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {}; + } + + std::string Token = Json["access_token"].string_value(); + int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()); + CloudCacheAccessToken::TimePoint ExpireTime = CloudCacheAccessToken::Clock::now() + seconds(ExpiresInSeconds); + + return {.Value = fmt::format("Bearer {}", Token), .ExpireTime = ExpireTime}; + } + +private: + std::string m_Url; + std::string m_ClientId; + std::string m_ClientSecret; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params) +{ + return std::make_unique<OAuthClientCredentialsTokenProvider>(Params); +} + +class CallbackTokenProvider final : public CloudCacheTokenProvider +{ +public: + CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {} + + virtual ~CallbackTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); } + +private: + std::function<CloudCacheAccessToken()> m_Callback; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback) +{ + return std::make_unique<CallbackTokenProvider>(std::move(Callback)); +} + +CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider) +: m_Log(zen::logging::Get("jupiter")) +, m_ServiceUrl(Options.ServiceUrl) +, m_DefaultDdcNamespace(Options.DdcNamespace) +, m_DefaultBlobStoreNamespace(Options.BlobStoreNamespace) +, m_ComputeCluster(Options.ComputeCluster) +, m_ConnectTimeout(Options.ConnectTimeout) +, m_Timeout(Options.Timeout) +, m_TokenProvider(std::move(TokenProvider)) +{ + ZEN_ASSERT(m_TokenProvider.get() != nullptr); +} + +CloudCacheClient::~CloudCacheClient() +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + + for (auto State : m_SessionStateCache) + { + delete State; + } +} + +CloudCacheAccessToken +CloudCacheClient::AcquireAccessToken() +{ + ZEN_TRACE_CPU("HordeClient::AcquireAccessToken"); + + return m_TokenProvider->AcquireAccessToken(); +} + +detail::CloudCacheSessionState* +CloudCacheClient::AllocSessionState() +{ + detail::CloudCacheSessionState* State = nullptr; + + bool IsTokenValid = false; + + { + RwLock::ExclusiveLockScope _(m_SessionStateLock); + + if (m_SessionStateCache.empty() == false) + { + State = m_SessionStateCache.front(); + IsTokenValid = State->m_AccessToken.IsValid(); + + m_SessionStateCache.pop_front(); + } + } + + if (State == nullptr) + { + State = new detail::CloudCacheSessionState(*this); + } + + State->Reset(m_ConnectTimeout, m_Timeout); + + if (IsTokenValid == false) + { + State->m_AccessToken = m_TokenProvider->AcquireAccessToken(); + } + + return State; +} + +void +CloudCacheClient::FreeSessionState(detail::CloudCacheSessionState* State) +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + m_SessionStateCache.push_front(State); +} + +} // namespace zen diff --git a/src/zenserver/upstream/jupiter.h b/src/zenserver/upstream/jupiter.h new file mode 100644 index 000000000..99e5c530f --- /dev/null +++ b/src/zenserver/upstream/jupiter.h @@ -0,0 +1,217 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +#include <atomic> +#include <chrono> +#include <list> +#include <memory> +#include <set> +#include <vector> + +struct ZenCacheValue; + +namespace cpr { +class Session; +} + +namespace zen { +namespace detail { + struct CloudCacheSessionState; +} + +class CbObjectView; +class CloudCacheClient; +class IoBuffer; +struct IoHash; + +/** + * Cached access token, for use with `Authorization:` header + */ +struct CloudCacheAccessToken +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + + static constexpr int64_t ExpireMarginInSeconds = 30; + + std::string Value; + TimePoint ExpireTime; + + bool IsValid() const + { + return Value.empty() == false && + ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count(); + } +}; + +struct CloudCacheResult +{ + IoBuffer Response; + int64_t Bytes{}; + double ElapsedSeconds{}; + int32_t ErrorCode{}; + std::string Reason; + bool Success = false; +}; + +struct PutRefResult : CloudCacheResult +{ + std::vector<IoHash> Needs; +}; + +struct FinalizeRefResult : CloudCacheResult +{ + std::vector<IoHash> Needs; +}; + +struct CloudCacheExistsResult : CloudCacheResult +{ + std::set<IoHash> Needs; +}; + +struct GetObjectReferencesResult : CloudCacheResult +{ + std::set<IoHash> References; +}; + +/** + * Context for performing Jupiter operations + * + * Maintains an HTTP connection so that subsequent operations don't need to go + * through the whole connection setup process + * + */ +class CloudCacheSession +{ +public: + CloudCacheSession(CloudCacheClient* CacheClient); + ~CloudCacheSession(); + + CloudCacheResult Authenticate(); + CloudCacheResult GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType); + CloudCacheResult GetBlob(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetCompressedBlob(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetObject(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash); + + PutRefResult PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType); + CloudCacheResult PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob); + CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob); + CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Blob); + CloudCacheResult PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object); + + FinalizeRefResult FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHah); + + CloudCacheResult RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key); + + GetObjectReferencesResult GetObjectReferences(std::string_view Namespace, const IoHash& Key); + + CloudCacheResult BlobExists(std::string_view Namespace, const IoHash& Key); + CloudCacheResult CompressedBlobExists(std::string_view Namespace, const IoHash& Key); + CloudCacheResult ObjectExists(std::string_view Namespace, const IoHash& Key); + + CloudCacheExistsResult BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys); + CloudCacheExistsResult CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys); + CloudCacheExistsResult ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys); + + CloudCacheResult PostComputeTasks(IoBuffer TasksData); + CloudCacheResult GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds = 0); + + std::vector<IoHash> Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes); + + CloudCacheClient& Client() { return *m_CacheClient; }; + +private: + inline spdlog::logger& Log() { return m_Log; } + cpr::Session& GetSession(); + CloudCacheAccessToken GetAccessToken(bool RefreshToken = false); + bool VerifyAccessToken(long StatusCode); + + CloudCacheResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key); + + CloudCacheExistsResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys); + + spdlog::logger& m_Log; + RefPtr<CloudCacheClient> m_CacheClient; + detail::CloudCacheSessionState* m_SessionState; +}; + +/** + * Access token provider interface + */ +class CloudCacheTokenProvider +{ +public: + virtual ~CloudCacheTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() = 0; + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromStaticToken(CloudCacheAccessToken Token); + + struct OAuthClientCredentialsParams + { + std::string_view Url; + std::string_view ClientId; + std::string_view ClientSecret; + }; + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params); + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback); +}; + +struct CloudCacheClientOptions +{ + std::string_view Name; + std::string_view ServiceUrl; + std::string_view DdcNamespace; + std::string_view BlobStoreNamespace; + std::string_view ComputeCluster; + std::chrono::milliseconds ConnectTimeout{5000}; + std::chrono::milliseconds Timeout{}; +}; + +/** + * Jupiter upstream cache client + */ +class CloudCacheClient : public RefCounted +{ +public: + CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider); + ~CloudCacheClient(); + + CloudCacheAccessToken AcquireAccessToken(); + std::string_view DefaultDdcNamespace() const { return m_DefaultDdcNamespace; } + std::string_view DefaultBlobStoreNamespace() const { return m_DefaultBlobStoreNamespace; } + std::string_view ComputeCluster() const { return m_ComputeCluster; } + std::string_view ServiceUrl() const { return m_ServiceUrl; } + + spdlog::logger& Logger() { return m_Log; } + +private: + spdlog::logger& m_Log; + std::string m_ServiceUrl; + std::string m_DefaultDdcNamespace; + std::string m_DefaultBlobStoreNamespace; + std::string m_ComputeCluster; + std::chrono::milliseconds m_ConnectTimeout{}; + std::chrono::milliseconds m_Timeout{}; + std::unique_ptr<CloudCacheTokenProvider> m_TokenProvider; + + RwLock m_SessionStateLock; + std::list<detail::CloudCacheSessionState*> m_SessionStateCache; + + detail::CloudCacheSessionState* AllocSessionState(); + void FreeSessionState(detail::CloudCacheSessionState*); + + friend class CloudCacheSession; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/upstream.h b/src/zenserver/upstream/upstream.h new file mode 100644 index 000000000..a57301206 --- /dev/null +++ b/src/zenserver/upstream/upstream.h @@ -0,0 +1,8 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <upstream/jupiter.h> +#include <upstream/upstreamcache.h> +#include <upstream/upstreamservice.h> +#include <upstream/zen.h> diff --git a/src/zenserver/upstream/upstreamapply.cpp b/src/zenserver/upstream/upstreamapply.cpp new file mode 100644 index 000000000..c719b225d --- /dev/null +++ b/src/zenserver/upstream/upstreamapply.cpp @@ -0,0 +1,459 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamapply.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/fmtutils.h> +# include <zencore/stream.h> +# include <zencore/timer.h> +# include <zencore/workthreadpool.h> + +# include <zenstore/cidstore.h> + +# include "diag/logging.h" + +# include <fmt/format.h> + +# include <atomic> + +namespace zen { + +using namespace std::literals; + +struct UpstreamApplyStats +{ + static constexpr uint64_t MaxSampleCount = 1000ull; + + UpstreamApplyStats(bool Enabled) : m_Enabled(Enabled) {} + + void Add(UpstreamApplyEndpoint& Endpoint, const PostUpstreamApplyResult& Result) + { + UpstreamApplyEndpointStats& Stats = Endpoint.Stats(); + + if (Result.Error) + { + Stats.ErrorCount.Increment(1); + } + else if (Result.Success) + { + Stats.PostCount.Increment(1); + Stats.UpBytes.Increment(Result.Bytes / 1024 / 1024); + } + } + + void Add(UpstreamApplyEndpoint& Endpoint, const GetUpstreamApplyUpdatesResult& Result) + { + UpstreamApplyEndpointStats& Stats = Endpoint.Stats(); + + if (Result.Error) + { + Stats.ErrorCount.Increment(1); + } + else if (Result.Success) + { + Stats.UpdateCount.Increment(1); + Stats.DownBytes.Increment(Result.Bytes / 1024 / 1024); + if (!Result.Completed.empty()) + { + uint64_t Completed = 0; + for (auto& It : Result.Completed) + { + Completed += It.second.size(); + } + Stats.CompleteCount.Increment(Completed); + } + } + } + + bool m_Enabled; +}; + +////////////////////////////////////////////////////////////////////////// + +class UpstreamApplyImpl final : public UpstreamApply +{ +public: + UpstreamApplyImpl(const UpstreamApplyOptions& Options, CidStore& CidStore) + : m_Log(logging::Get("upstream-apply")) + , m_Options(Options) + , m_CidStore(CidStore) + , m_Stats(Options.StatsEnabled) + , m_UpstreamAsyncWorkPool(Options.UpstreamThreadCount) + , m_DownstreamAsyncWorkPool(Options.DownstreamThreadCount) + { + } + + virtual ~UpstreamApplyImpl() { Shutdown(); } + + virtual bool Initialize() override + { + for (auto& Endpoint : m_Endpoints) + { + const UpstreamEndpointHealth Health = Endpoint->Initialize(); + if (Health.Ok) + { + Log().info("initialize endpoint '{}' OK", Endpoint->DisplayName()); + } + else + { + Log().warn("initialize endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason); + } + } + + m_RunState.IsRunning = !m_Endpoints.empty(); + + if (m_RunState.IsRunning) + { + m_ShutdownEvent.Reset(); + + m_UpstreamUpdatesThread = std::thread(&UpstreamApplyImpl::ProcessUpstreamUpdates, this); + + m_EndpointMonitorThread = std::thread(&UpstreamApplyImpl::MonitorEndpoints, this); + } + + return m_RunState.IsRunning; + } + + virtual bool IsHealthy() const override + { + if (m_RunState.IsRunning) + { + for (const auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + return true; + } + } + } + + return false; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) override + { + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) override + { + if (m_RunState.IsRunning) + { + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300); + + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + // Already in progress + return {.ApplyId = ActionId, .Success = true}; + } + + std::chrono::steady_clock::time_point ExpireTime = + TimeoutSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds) + : std::chrono::steady_clock::time_point::max(); + + m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)}; + } + + ApplyRecord.Timepoints["zen-queue-added"] = DateTime::NowTicks(); + m_UpstreamAsyncWorkPool.ScheduleWork( + [this, ApplyRecord = std::move(ApplyRecord)]() { ProcessApplyRecord(std::move(ApplyRecord)); }); + + return {.ApplyId = ActionId, .Success = true}; + } + + return {}; + } + + virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) override + { + if (m_RunState.IsRunning) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + return {.Status = *Status, .Success = true}; + } + } + + return {}; + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "upstream_worker_threads" << m_Options.UpstreamThreadCount; + Status << "upstream_queue_count" << m_UpstreamAsyncWorkPool.PendingWork(); + Status << "downstream_worker_threads" << m_Options.DownstreamThreadCount; + Status << "downstream_queue_count" << m_DownstreamAsyncWorkPool.PendingWork(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + Status.BeginObject(); + Status << "name" << Ep->DisplayName(); + Status << "health" << (Ep->IsHealthy() ? "ok"sv : "inactive"sv); + + UpstreamApplyEndpointStats& Stats = Ep->Stats(); + const uint64_t PostCount = Stats.PostCount.Value(); + const uint64_t CompleteCount = Stats.CompleteCount.Value(); + // const uint64_t UpdateCount = Stats.UpdateCount; + const double CompleteRate = CompleteCount > 0 ? (double(PostCount) / double(CompleteCount)) : 0.0; + + Status << "post_count" << PostCount; + Status << "complete_count" << PostCount; + Status << "update_count" << Stats.UpdateCount.Value(); + + Status << "complete_ratio" << CompleteRate; + Status << "downloaded_mb" << Stats.DownBytes.Value(); + Status << "uploaded_mb" << Stats.UpBytes.Value(); + Status << "error_count" << Stats.ErrorCount.Value(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + // The caller is responsible for locking if required + UpstreamApplyStatus* FindStatus(const IoHash& WorkerId, const IoHash& ActionId) + { + if (auto It = m_ApplyTasks.find(WorkerId); It != m_ApplyTasks.end()) + { + if (auto It2 = It->second.find(ActionId); It2 != It->second.end()) + { + return &It2->second; + } + } + return nullptr; + } + + void ProcessApplyRecord(UpstreamApplyRecord ApplyRecord) + { + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + try + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + ApplyRecord.Timepoints["zen-queue-dispatched"] = DateTime::NowTicks(); + PostUpstreamApplyResult Result = Endpoint->PostApply(std::move(ApplyRecord)); + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->Timepoints.merge(Result.Timepoints); + + if (Result.Success) + { + Status->State = UpstreamApplyState::Executing; + } + else + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error = std::move(Result.Error), + .Bytes = Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + } + } + m_Stats.Add(*Endpoint, Result); + return; + } + } + + Log().warn("process upstream apply ({}/{}) FAILED 'No available endpoint'", WorkerId, ActionId); + + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error{.ErrorCode = -1, .Reason = "No available endpoint"}}; + } + } + } + catch (std::exception& e) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error{.ErrorCode = -1, .Reason = e.what()}}; + } + Log().warn("process upstream apply ({}/{}) FAILED '{}'", WorkerId, ActionId, e.what()); + } + } + + void ProcessApplyUpdates() + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + GetUpstreamApplyUpdatesResult Result = Endpoint->GetUpdates(m_DownstreamAsyncWorkPool); + m_Stats.Add(*Endpoint, Result); + + if (!Result.Success) + { + Log().warn("process upstream apply updates FAILED '{}'", Result.Error.Reason); + } + + if (!Result.Completed.empty()) + { + for (auto& It : Result.Completed) + { + for (auto& It2 : It.second) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(It.first, It2.first); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = std::move(It2.second); + Status->Result.Timepoints.merge(Status->Timepoints); + Status->Result.Timepoints["zen-queue-complete"] = DateTime::NowTicks(); + Status->Timepoints.clear(); + } + } + } + } + } + } + } + + void ProcessUpstreamUpdates() + { + const auto& UpdateSleep = std::chrono::milliseconds(m_Options.UpdatesInterval); + while (!m_ShutdownEvent.Wait(uint32_t(UpdateSleep.count()))) + { + if (!m_RunState.IsRunning) + { + break; + } + + ProcessApplyUpdates(); + + // Remove any expired tasks, regardless of state + { + std::scoped_lock Lock(m_ApplyTasksMutex); + for (auto& WorkerIt : m_ApplyTasks) + { + const auto Count = std::erase_if(WorkerIt.second, [](const auto& Item) { + return Item.second.ExpireTime < std::chrono::steady_clock::now(); + }); + if (Count > 0) + { + Log().debug("Removed '{}' expired tasks", Count); + } + } + const auto Count = std::erase_if(m_ApplyTasks, [](const auto& Item) { return Item.second.empty(); }); + if (Count > 0) + { + Log().debug("Removed '{}' empty task lists", Count); + } + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock Lock(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(Lock, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + for (auto& Endpoint : m_Endpoints) + { + if (!Endpoint->IsHealthy()) + { + if (const UpstreamEndpointHealth Health = Endpoint->CheckHealth(); Health.Ok) + { + Log().warn("health check endpoint '{}' OK", Endpoint->DisplayName(), Health.Reason); + } + else + { + Log().warn("health check endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason); + } + } + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_ShutdownEvent.Set(); + m_EndpointMonitorThread.join(); + m_UpstreamUpdatesThread.join(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::scoped_lock Lock(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamApplyOptions m_Options; + CidStore& m_CidStore; + UpstreamApplyStats m_Stats; + UpstreamApplyTasks m_ApplyTasks; + std::mutex m_ApplyTasksMutex; + std::vector<std::unique_ptr<UpstreamApplyEndpoint>> m_Endpoints; + Event m_ShutdownEvent; + WorkerThreadPool m_UpstreamAsyncWorkPool; + WorkerThreadPool m_DownstreamAsyncWorkPool; + std::thread m_UpstreamUpdatesThread; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +bool +UpstreamApply::IsHealthy() const +{ + return false; +} + +std::unique_ptr<UpstreamApply> +UpstreamApply::Create(const UpstreamApplyOptions& Options, CidStore& CidStore) +{ + return std::make_unique<UpstreamApplyImpl>(Options, CidStore); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/upstreamapply.h b/src/zenserver/upstream/upstreamapply.h new file mode 100644 index 000000000..4a095be6c --- /dev/null +++ b/src/zenserver/upstream/upstreamapply.h @@ -0,0 +1,192 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarypackage.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/stats.h> +# include <zencore/zencore.h> + +# include <chrono> +# include <map> +# include <unordered_map> +# include <unordered_set> + +namespace zen { + +class AuthMgr; +class CbObjectWriter; +class CidStore; +class CloudCacheTokenProvider; +class WorkerThreadPool; +class ZenCacheNamespace; +struct CloudCacheClientOptions; +struct UpstreamAuthConfig; + +enum class UpstreamApplyState : int32_t +{ + Queued = 0, + Executing = 1, + Complete = 2, +}; + +enum class UpstreamApplyType +{ + Simple = 0, + Asset = 1, +}; + +struct UpstreamApplyRecord +{ + CbObject WorkerDescriptor; + CbObject Action; + UpstreamApplyType Type; + std::map<std::string, uint64_t> Timepoints{}; +}; + +struct UpstreamApplyOptions +{ + std::chrono::seconds HealthCheckInterval{5}; + std::chrono::seconds UpdatesInterval{5}; + uint32_t UpstreamThreadCount = 4; + uint32_t DownstreamThreadCount = 4; + bool StatsEnabled = false; +}; + +struct UpstreamApplyError +{ + int32_t ErrorCode{}; + std::string Reason{}; + + explicit operator bool() const { return ErrorCode != 0; } +}; + +struct PostUpstreamApplyResult +{ + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + std::map<std::string, uint64_t> Timepoints{}; + bool Success = false; +}; + +struct GetUpstreamApplyResult +{ + // UpstreamApplyType::Simple + std::map<std::filesystem::path, IoHash> OutputFiles{}; + std::map<IoHash, IoBuffer> FileData{}; + + // UpstreamApplyType::Asset + CbPackage OutputPackage{}; + int64_t TotalAttachmentBytes{}; + int64_t TotalRawAttachmentBytes{}; + + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + std::string StdOut{}; + std::string StdErr{}; + std::string Agent{}; + std::string Detail{}; + std::map<std::string, uint64_t> Timepoints{}; + bool Success = false; +}; + +using UpstreamApplyCompleted = std::unordered_map<IoHash, std::unordered_map<IoHash, GetUpstreamApplyResult>>; + +struct GetUpstreamApplyUpdatesResult +{ + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + UpstreamApplyCompleted Completed{}; + bool Success = false; +}; + +struct UpstreamApplyStatus +{ + UpstreamApplyState State{}; + GetUpstreamApplyResult Result{}; + std::chrono::steady_clock::time_point ExpireTime{}; + std::map<std::string, uint64_t> Timepoints{}; +}; + +using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>; + +struct UpstreamEndpointHealth +{ + std::string Reason; + bool Ok = false; +}; + +struct UpstreamApplyEndpointStats +{ + metrics::Counter PostCount; + metrics::Counter CompleteCount; + metrics::Counter UpdateCount; + metrics::Counter ErrorCount; + metrics::Counter UpBytes; + metrics::Counter DownBytes; +}; + +/** + * The upstream apply endpoint is responsible for handling remote execution. + */ +class UpstreamApplyEndpoint +{ +public: + virtual ~UpstreamApplyEndpoint() = default; + + virtual UpstreamEndpointHealth Initialize() = 0; + virtual bool IsHealthy() const = 0; + virtual UpstreamEndpointHealth CheckHealth() = 0; + virtual std::string_view DisplayName() const = 0; + virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) = 0; + virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) = 0; + virtual UpstreamApplyEndpointStats& Stats() = 0; + + static std::unique_ptr<UpstreamApplyEndpoint> CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr); +}; + +/** + * Manages one or more upstream compute endpoints. + */ +class UpstreamApply +{ +public: + virtual ~UpstreamApply() = default; + + virtual bool Initialize() = 0; + virtual bool IsHealthy() const = 0; + virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0; + + struct EnqueueResult + { + IoHash ApplyId{}; + bool Success = false; + }; + + struct StatusResult + { + UpstreamApplyStatus Status{}; + bool Success = false; + }; + + virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0; + virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0; + virtual void GetStatus(CbObjectWriter& CbO) = 0; + + static std::unique_ptr<UpstreamApply> Create(const UpstreamApplyOptions& Options, CidStore& CidStore); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp new file mode 100644 index 000000000..e838b5fe2 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.cpp @@ -0,0 +1,2112 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamcache.h" +#include "jupiter.h" +#include "zen.h" + +#include <zencore/blockingqueue.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/stats.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpshared.h> + +#include <zenstore/cidstore.h> + +#include <auth/authmgr.h> +#include "cache/structuredcache.h" +#include "cache/structuredcachestore.h" +#include "diag/logging.h" + +#include <fmt/format.h> + +#include <algorithm> +#include <atomic> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +namespace zen { + +using namespace std::literals; + +namespace detail { + + class UpstreamStatus + { + public: + UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); } + + UpstreamEndpointStatus EndpointStatus() const + { + const UpstreamEndpointState State = EndpointState(); + { + std::unique_lock _(m_Mutex); + return {.Reason = m_ErrorText, .State = State}; + } + } + + void Set(UpstreamEndpointState NewState) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText.clear(); + } + } + + void Set(UpstreamEndpointState NewState, std::string ErrorText) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText = std::move(ErrorText); + } + } + + void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText) + { + if (ErrorCode != 0) + { + Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText)); + } + } + + private: + mutable std::mutex m_Mutex; + std::string m_ErrorText; + std::atomic_uint32_t m_State; + }; + + class JupiterUpstreamEndpoint final : public UpstreamEndpoint + { + public: + JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) + : m_AuthMgr(Mgr) + , m_Log(zen::logging::Get("upstream")) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + m_Info.Url = Options.ServiceUrl; + + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (AuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials( + {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret}); + } + else if (AuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(Options, std::move(TokenProvider)); + } + + virtual ~JupiterUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + CloudCacheSession Session(m_Client); + const CloudCacheResult Result = Session.Authenticate(); + + if (Result.Success) + { + m_Status.Set(UpstreamEndpointState::kOk); + } + else if (Result.ErrorCode != 0) + { + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + } + else + { + m_Status.Set(UpstreamEndpointState::kUnauthorized); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultDdcNamespace(); + } + return Namespace; + } + + std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultBlobStoreNamespace(); + } + return Namespace; + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord"); + + try + { + CloudCacheSession Session(m_Client); + CloudCacheResult Result; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + + if (Type == ZenContentType::kCompressedBinary) + { + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + + if (Result.Success) + { + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + IoBuffer ContentBuffer; + int NumAttachments = 0; + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize)) + { + Result.Response = AttachmentResult.Response; + ++NumAttachments; + } + else + { + Result.Success = false; + } + }); + if (NumAttachments != 1) + { + Result.Success = false; + } + } + } + } + else + { + const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type; + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType); + + if (Result.Success && Type == ZenContentType::kCbPackage) + { + CbPackage Package; + + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash())); + } + else + { + Result.Success = false; + } + }); + + Package.SetObject(CacheRecord); + } + + if (Result.Success) + { + BinaryWriter MemStream; + Package.Save(MemStream); + + Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + } + } + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheKeyRequest* Request : Requests) + { + const CacheKey& CacheKey = Request->Key; + CbPackage Package; + CbObject Record; + + double ElapsedSeconds = 0.0; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + CloudCacheResult RefResult = + Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + AppendResult(RefResult, Result); + ElapsedSeconds = RefResult.ElapsedSeconds; + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (RefResult.ErrorCode == 0) + { + const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All); + if (ValidationResult == CbValidateError::None) + { + Record = LoadCompactBinaryObject(RefResult.Response); + Record.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (BlobResult.ErrorCode == 0) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize)) + { + if (RawHash == AttachmentHash.AsHash()) + { + Package.AddAttachment(CbAttachment(Chunk, RawHash)); + } + } + } + }); + } + } + } + + OnComplete( + {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info}); + } + + return Result; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey&, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk"); + + try + { + CloudCacheSession Session(m_Client); + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + CacheChunkRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult BlobResult = + Request.ChunkId == IoHash::Zero + ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId) + : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload && IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize); + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + CacheValueRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + IoHash PayloadHash; + const CloudCacheResult BlobResult = + Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload) + { + if (IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash; + } + else + { + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload)); + RawHash = Compressed.DecodeRawHash(); + if (RawHash == PayloadHash) + { + IsCompressed = true; + } + else + { + ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header", + Namespace, + Request.Key.Bucket, + Request.Key.Hash.ToHexString(), + RawHash.ToHexString(), + PayloadHash.ToHexString()); + } + } + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + CloudCacheSession Session(m_Client); + + if (CacheRecord.Type == ZenContentType::kBinary) + { + CloudCacheResult Result; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace); + Result = Session.PutRef(BlobStoreNamespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + ZenContentType::kBinary); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + return {.Reason = std::move(Result.Reason), + .Bytes = Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Success = Result.Success}; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize)) + { + return {.Reason = std::string("Invalid compressed value buffer"), .Success = false}; + } + + CbObjectWriter ReferencingObject; + ReferencingObject.AddBinaryAttachment("RawHash", RawHash); + ReferencingObject.AddInteger("RawSize", RawSize); + + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + ReferencingObject.Save().GetBuffer().AsIoBuffer(), + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + if (ValueContentId != RawHash) + { + OutReason = + fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash); + return false; + } + + OutBuffer = RecordValue; + return true; + }); + } + else + { + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + RecordValue, + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + const auto It = + std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId); + + if (It == std::end(CacheRecord.ValueContentIds)) + { + OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId); + return false; + } + + const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It); + + OutBuffer = Values[Idx]; + return true; + }); + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out) + { + Out.Success &= Result.Success; + Out.Bytes += Result.Bytes; + Out.ElapsedSeconds += Result.ElapsedSeconds; + + if (Result.ErrorCode) + { + Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}; + } + }; + + PutUpstreamCacheResult PerformStructuredPut( + CloudCacheSession& Session, + std::string_view Namespace, + const CacheKey& Key, + IoBuffer ObjectBuffer, + const int32_t MaxAttempts, + std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn) + { + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool { + for (const IoHash& ValueContentId : ValueContentIds) + { + IoBuffer BlobBuffer; + if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason)) + { + return false; + } + + CloudCacheResult BlobResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++) + { + BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer); + } + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (!BlobResult.Success) + { + OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason); + return false; + } + + TotalBytes += BlobResult.Bytes; + TotalElapsedSeconds += BlobResult.ElapsedSeconds; + } + + return true; + }; + + PutRefResult RefResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++) + { + RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject); + } + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (!RefResult.Success) + { + return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason), + .Success = false}; + } + + TotalBytes += RefResult.Bytes; + TotalElapsedSeconds += RefResult.ElapsedSeconds; + + std::string Reason; + if (!PutBlobs(RefResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer); + FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return { + .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + if (!PutBlobs(FinalizeResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + ExtendableStringBuilder<256> Sb; + for (const IoHash& MissingHash : FinalizeResult.Needs) + { + Sb << MissingHash.ToHexString() << ","; + } + + return { + .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()), + .Success = false}; + } + } + + TotalBytes += FinalizeResult.Bytes; + TotalElapsedSeconds += FinalizeResult.ElapsedSeconds; + + return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true}; + } + + spdlog::logger& Log() { return m_Log; } + + AuthMgr& m_AuthMgr; + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + RefPtr<CloudCacheClient> m_Client; + }; + + class ZenUpstreamEndpoint final : public UpstreamEndpoint + { + struct ZenEndpoint + { + std::string Url; + std::string Reason; + double Latency{}; + bool Ok = false; + + bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; } + }; + + public: + ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options) + : m_Log(zen::logging::Get("upstream")) + , m_ConnectTimeout(Options.ConnectTimeout) + , m_Timeout(Options.Timeout) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + + for (const auto& Url : Options.Urls) + { + m_Endpoints.push_back({.Url = Url}); + } + } + + ~ZenUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + const ZenEndpoint& Ep = GetEndpoint(); + + if (m_Info.Url != Ep.Url) + { + ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url); + m_Info.Url = Ep.Url; + } + + if (Ep.Ok) + { + RwLock::ExclusiveLockScope _(m_ClientLock); + m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout}); + m_Status.Set(UpstreamEndpointState::kOk); + } + else + { + m_Status.Set(UpstreamEndpointState::kError, Ep.Reason); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords"); + ZEN_ASSERT(Requests.size() > 0); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheRecords"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy(); + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy); + + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + for (CacheKeyRequest* Request : Requests) + { + BatchRequest.BeginObject(); + { + const CacheKey& Key = Request->Key; + BatchRequest.BeginObject("Key"sv); + { + BatchRequest << "Bucket"sv << Key.Bucket; + BatchRequest << "Hash"sv << Key.Hash; + } + BatchRequest.EndObject(); + if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy) + { + BatchRequest.SetName("Policy"sv); + Request->Policy.Save(BatchRequest); + } + } + BatchRequest.EndObject(); + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (Results.Num() != Requests.size()) + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream."); + } + else + { + for (size_t Index = 0; CbFieldView Record : Results) + { + CacheKeyRequest* Request = Requests[Index++]; + OnComplete({.Request = *Request, + .Record = Record.AsObjectView(), + .Package = BatchResponse, + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream."); + } + } + + for (CacheKeyRequest* Request : Requests) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues"); + ZEN_ASSERT(!CacheValueRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheValues"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + { + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + const CacheValueRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheValueRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheValueRequest& Request = *CacheValueRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream."); + } + } + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks"); + ZEN_ASSERT(!CacheChunkRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheChunks"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("ChunkRequests"sv); + { + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + const CacheChunkRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.ValueId) + { + BatchRequest.AddObjectId("ValueId"sv, Request.ValueId); + } + if (Request.ChunkId != Request.ChunkId.Zero) + { + BatchRequest << "ChunkId"sv << Request.ChunkId; + } + if (Request.RawOffset != 0) + { + BatchRequest << "RawOffset"sv << Request.RawOffset; + } + if (Request.RawSize != UINT64_MAX) + { + BatchRequest << "RawSize"sv << Request.RawSize; + } + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheChunkRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream."); + } + } + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + ZenCacheResult Result; + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + if (CacheRecord.Type == ZenContentType::kCbPackage) + { + CbPackage Package; + Package.SetObject(CbObject(SharedBuffer(RecordValue))); + + for (const IoBuffer& Value : Values) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash)); + } + else + { + return {.Reason = std::string("Invalid value buffer"), .Success = false}; + } + } + + BinaryWriter MemStream; + Package.Save(MemStream); + IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size()); + + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + PackagePayload, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes = Result.Bytes; + TotalElapsedSeconds = Result.ElapsedSeconds; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize); + if (!Compressed) + { + return {.Reason = std::string("Invalid value compressed buffer"), .Success = false}; + } + + CbPackage BatchPackage; + CbObjectWriter BatchWriter; + BatchWriter << "Method"sv + << "PutCacheValues"sv; + BatchWriter << "Accept"sv << kCbPkgMagic; + + BatchWriter.BeginObject("Params"sv); + { + // DefaultPolicy unspecified and expected to be Default + + BatchWriter << "Namespace"sv << CacheRecord.Namespace; + + BatchWriter.BeginArray("Requests"sv); + { + BatchWriter.BeginObject(); + { + const CacheKey& Key = CacheRecord.Key; + BatchWriter.BeginObject("Key"sv); + { + BatchWriter << "Bucket"sv << Key.Bucket; + BatchWriter << "Hash"sv << Key.Hash; + } + BatchWriter.EndObject(); + // Policy unspecified and expected to be Default + BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash); + BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash)); + } + BatchWriter.EndObject(); + } + BatchWriter.EndArray(); + } + BatchWriter.EndObject(); + BatchPackage.SetObject(BatchWriter.Save()); + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.InvokeRpc(BatchPackage); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + else + { + for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++) + { + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheValue(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + CacheRecord.ValueContentIds[Idx], + Values[Idx]); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + + if (!Result.Success) + { + return {.Reason = "Failed to upload value", + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = false}; + } + } + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + + return {.Reason = std::move(Result.Reason), + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = Result.Success}; + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + Ref<ZenStructuredCacheClient> GetClientRef() + { + // m_Client can be modified at any time by a different thread. + // Make sure we safely bump the refcount inside a scope lock + RwLock::SharedLockScope _(m_ClientLock); + ZEN_ASSERT(m_Client); + Ref<ZenStructuredCacheClient> ClientRef(m_Client); + _.ReleaseNow(); + return ClientRef; + } + + const ZenEndpoint& GetEndpoint() + { + for (ZenEndpoint& Ep : m_Endpoints) + { + Ref<ZenStructuredCacheClient> Client( + new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)})); + ZenStructuredCacheSession Session(std::move(Client)); + const int32_t SampleCount = 2; + + Ep.Ok = false; + Ep.Latency = {}; + + for (int32_t Sample = 0; Sample < SampleCount; ++Sample) + { + ZenCacheResult Result = Session.CheckHealth(); + Ep.Ok = Result.Success; + Ep.Reason = std::move(Result.Reason); + Ep.Latency += Result.ElapsedSeconds; + } + Ep.Latency /= double(SampleCount); + } + + std::sort(std::begin(m_Endpoints), std::end(m_Endpoints)); + + for (const auto& Ep : m_Endpoints) + { + ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason); + } + + return m_Endpoints.front(); + } + + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + std::vector<ZenEndpoint> m_Endpoints; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + RwLock m_ClientLock; + RefPtr<ZenStructuredCacheClient> m_Client; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +class UpstreamCacheImpl final : public UpstreamCache +{ +public: + UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) + : m_Log(logging::Get("upstream")) + , m_Options(Options) + , m_CacheStore(CacheStore) + , m_CidStore(CidStore) + { + } + + virtual ~UpstreamCacheImpl() { Shutdown(); } + + virtual void Initialize() override + { + for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++) + { + m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this); + } + + m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this); + m_RunState.IsRunning = true; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override + { + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + else + { + ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + + // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s) + std::unique_lock<std::shared_mutex> _(m_EndpointsMutex); + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Ep : m_Endpoints) + { + if (!Fn(*Ep)) + { + break; + } + } + } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::GetCacheRecord"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheRecords"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheKeyRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + + Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheKeyRequest* Request : RemainingKeys) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + } + + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheChunks"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheChunkRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheChunkGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheChunkRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::GetCacheChunk"); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheValues"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheValueRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheValueGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheValueRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override + { + if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0) + { + if (!m_UpstreamThreads.empty()) + { + m_UpstreamQueue.Enqueue(std::move(CacheRecord)); + } + else + { + ProcessCacheRecord(std::move(CacheRecord)); + } + } + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "reading" << m_Options.ReadUpstream; + Status << "writing" << m_Options.WriteUpstream; + Status << "worker_threads" << m_Options.ThreadCount; + Status << "queue_count" << m_UpstreamQueue.Size(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo(); + const UpstreamEndpointStatus EpStatus = Ep->GetStatus(); + UpstreamEndpointStats& EpStats = Ep->Stats(); + + Status.BeginObject(); + Status << "name" << EpInfo.Name; + Status << "url" << EpInfo.Url; + Status << "state" << ToString(EpStatus.State); + Status << "reason" << EpStatus.Reason; + + Status.BeginObject("cache"sv); + { + const int64_t GetCount = EpStats.CacheGetCount.Value(); + const int64_t HitCount = EpStats.CacheHitCount.Value(); + const int64_t ErrorCount = EpStats.CacheErrorCount.Value(); + const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0; + const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0; + + metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status); + Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value(); + Status << "get_count" << GetCount; + Status << "hit_count" << HitCount; + Status << "hit_ratio" << HitRatio; + Status << "error_count" << ErrorCount; + Status << "error_ratio" << ErrorRatio; + metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status); + Status << "put_bytes" << EpStats.CachePutTotalBytes.Value(); + } + Status.EndObject(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + void ProcessCacheRecord(UpstreamCacheRecord CacheRecord) + { + ZEN_TRACE_CPU("Upstream::ProcessCacheRecord"); + + ZenCacheValue CacheValue; + std::vector<IoBuffer> Payloads; + + if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue)) + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash); + return; + } + + for (const IoHash& ValueContentId : CacheRecord.ValueContentIds) + { + if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId)) + { + Payloads.push_back(Payload); + } + else + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + ValueContentId); + return; + } + } + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + PutUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming); + Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads)); + } + + Stats.CachePutTotalBytes.Increment(Result.Bytes); + + if (!Result.Success) + { + ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Endpoint->GetEndpointInfo().Url, + Result.Reason); + } + } + } + + void ProcessUpstreamQueue() + { + for (;;) + { + UpstreamCacheRecord CacheRecord; + if (m_UpstreamQueue.WaitAndDequeue(CacheRecord)) + { + try + { + ProcessCacheRecord(std::move(CacheRecord)); + } + catch (std::exception& Err) + { + ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Err.what()); + } + } + + if (!m_RunState.IsRunning) + { + break; + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock lk(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + try + { + std::vector<UpstreamEndpoint*> Endpoints; + + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + UpstreamEndpointState State = Endpoint->GetState(); + if (State == UpstreamEndpointState::kError) + { + Endpoints.push_back(Endpoint.get()); + ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'", + Endpoint->GetEndpointInfo().Name, + Endpoint->GetEndpointInfo().Url, + Endpoint->GetStatus().Reason); + } + if (State == UpstreamEndpointState::kUnauthorized) + { + Endpoints.push_back(Endpoint.get()); + } + } + } + + for (auto& Endpoint : Endpoints) + { + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url); + } + else + { + const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason); + ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason); + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what()); + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_UpstreamQueue.CompleteAdding(); + for (std::thread& Thread : m_UpstreamThreads) + { + Thread.join(); + } + + m_EndpointMonitorThread.join(); + m_UpstreamThreads.clear(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>; + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::lock_guard _(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamCacheOptions m_Options; + ZenCacheStore& m_CacheStore; + CidStore& m_CidStore; + UpstreamQueue m_UpstreamQueue; + std::shared_mutex m_EndpointsMutex; + std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints; + std::vector<std::thread> m_UpstreamThreads; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options) +{ + return std::make_unique<detail::ZenUpstreamEndpoint>(Options); +} + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) +{ + return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr); +} + +std::unique_ptr<UpstreamCache> +UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) +{ + return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore); +} + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamcache.h b/src/zenserver/upstream/upstreamcache.h new file mode 100644 index 000000000..695c06b32 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.h @@ -0,0 +1,252 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/stats.h> +#include <zencore/zencore.h> +#include <zenutil/cache/cache.h> + +#include <atomic> +#include <chrono> +#include <functional> +#include <memory> +#include <vector> + +namespace zen { + +class CbObjectView; +class AuthMgr; +class CbObjectView; +class CbPackage; +class CbObjectWriter; +class CidStore; +class ZenCacheStore; +struct CloudCacheClientOptions; +class CloudCacheTokenProvider; +struct ZenStructuredCacheClientOptions; + +struct UpstreamCacheRecord +{ + ZenContentType Type = ZenContentType::kBinary; + std::string Namespace; + CacheKey Key; + std::vector<IoHash> ValueContentIds; +}; + +struct UpstreamCacheOptions +{ + std::chrono::seconds HealthCheckInterval{5}; + uint32_t ThreadCount = 4; + bool ReadUpstream = true; + bool WriteUpstream = true; +}; + +struct UpstreamError +{ + int32_t ErrorCode{}; + std::string Reason{}; + + explicit operator bool() const { return ErrorCode != 0; } +}; + +struct UpstreamEndpointInfo +{ + std::string Name; + std::string Url; +}; + +struct GetUpstreamCacheResult +{ + UpstreamError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + bool Success = false; +}; + +struct GetUpstreamCacheSingleResult +{ + GetUpstreamCacheResult Status; + IoBuffer Value; + const UpstreamEndpointInfo* Source = nullptr; +}; + +struct PutUpstreamCacheResult +{ + std::string Reason; + int64_t Bytes{}; + double ElapsedSeconds{}; + bool Success = false; +}; + +struct CacheRecordGetCompleteParams +{ + CacheKeyRequest& Request; + const CbObjectView& Record; + const CbPackage& Package; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>; + +struct CacheValueGetCompleteParams +{ + CacheValueRequest& Request; + IoHash RawHash; + uint64_t RawSize; + IoBuffer Value; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>; + +struct CacheChunkGetCompleteParams +{ + CacheChunkRequest& Request; + IoHash RawHash; + uint64_t RawSize; + IoBuffer Value; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheChunksGetComplete = std::function<void(CacheChunkGetCompleteParams&&)>; + +struct UpstreamEndpointStats +{ + metrics::OperationTiming CacheGetRequestTiming; + metrics::OperationTiming CachePutRequestTiming; + metrics::Counter CacheGetTotalBytes; + metrics::Counter CachePutTotalBytes; + metrics::Counter CacheGetCount; + metrics::Counter CacheHitCount; + metrics::Counter CacheErrorCount; +}; + +enum class UpstreamEndpointState : uint32_t +{ + kDisabled, + kUnauthorized, + kError, + kOk +}; + +inline std::string_view +ToString(UpstreamEndpointState State) +{ + using namespace std::literals; + + switch (State) + { + case UpstreamEndpointState::kDisabled: + return "Disabled"sv; + case UpstreamEndpointState::kUnauthorized: + return "Unauthorized"sv; + case UpstreamEndpointState::kError: + return "Error"sv; + case UpstreamEndpointState::kOk: + return "Ok"sv; + default: + return "Unknown"sv; + } +} + +struct UpstreamAuthConfig +{ + std::string_view OAuthUrl; + std::string_view OAuthClientId; + std::string_view OAuthClientSecret; + std::string_view OpenIdProvider; + std::string_view AccessToken; +}; + +struct UpstreamEndpointStatus +{ + std::string Reason; + UpstreamEndpointState State; +}; + +/** + * The upstream endpoint is responsible for handling upload/downloading of cache records. + */ +class UpstreamEndpoint +{ +public: + virtual ~UpstreamEndpoint() = default; + + virtual UpstreamEndpointStatus Initialize() = 0; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0; + + virtual UpstreamEndpointState GetState() = 0; + virtual UpstreamEndpointStatus GetStatus() = 0; + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0; + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, const CacheKey& CacheKey, const IoHash& PayloadId) = 0; + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) = 0; + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Payloads) = 0; + + virtual UpstreamEndpointStats& Stats() = 0; + + static std::unique_ptr<UpstreamEndpoint> CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options); + + static std::unique_ptr<UpstreamEndpoint> CreateJupiterEndpoint(const CloudCacheClientOptions& Options, + const UpstreamAuthConfig& AuthConfig, + AuthMgr& Mgr); +}; + +/** + * Manages one or more upstream cache endpoints. + */ +class UpstreamCache +{ +public: + virtual ~UpstreamCache() = default; + + virtual void Initialize() = 0; + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0; + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0; + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) = 0; + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) = 0; + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0; + + virtual void GetStatus(CbObjectWriter& CbO) = 0; + + static std::unique_ptr<UpstreamCache> Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore); +}; + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamservice.cpp b/src/zenserver/upstream/upstreamservice.cpp new file mode 100644 index 000000000..6db1357c5 --- /dev/null +++ b/src/zenserver/upstream/upstreamservice.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#include <upstream/upstreamservice.h> + +#include <auth/authmgr.h> +#include <upstream/upstreamcache.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> + +namespace zen { + +using namespace std::literals; + +HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr) +{ + m_Router.RegisterRoute( + "endpoints", + [this](HttpRouterRequest& Req) { + CbObjectWriter Writer; + Writer.BeginArray("Endpoints"sv); + m_Upstream.IterateEndpoints([&Writer](UpstreamEndpoint& Ep) { + UpstreamEndpointInfo Info = Ep.GetEndpointInfo(); + UpstreamEndpointStatus Status = Ep.GetStatus(); + + Writer.BeginObject(); + Writer << "Name"sv << Info.Name; + Writer << "Url"sv << Info.Url; + Writer << "State"sv << ToString(Status.State); + Writer << "Reason"sv << Status.Reason; + Writer.EndObject(); + + return true; + }); + Writer.EndArray(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet); +} + +HttpUpstreamService::~HttpUpstreamService() +{ +} + +const char* +HttpUpstreamService::BaseUri() const +{ + return "/upstream/"; +} + +void +HttpUpstreamService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamservice.h b/src/zenserver/upstream/upstreamservice.h new file mode 100644 index 000000000..f1da03c8c --- /dev/null +++ b/src/zenserver/upstream/upstreamservice.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +class AuthMgr; +class UpstreamCache; + +class HttpUpstreamService final : public zen::HttpService +{ +public: + HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr); + virtual ~HttpUpstreamService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + UpstreamCache& m_Upstream; + AuthMgr& m_AuthMgr; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/zen.cpp b/src/zenserver/upstream/zen.cpp new file mode 100644 index 000000000..9e1212834 --- /dev/null +++ b/src/zenserver/upstream/zen.cpp @@ -0,0 +1,326 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zen.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/httpcommon.h> +#include <zenhttp/httpshared.h> + +#include "cache/structuredcachestore.h" +#include "diag/formatters.h" +#include "diag/logging.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <xxhash.h> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +namespace detail { + struct ZenCacheSessionState + { + ZenCacheSessionState(ZenStructuredCacheClient& Client) : OwnerClient(Client) {} + ~ZenCacheSessionState() {} + + void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout) + { + Session.SetBody({}); + Session.SetHeader({}); + Session.SetConnectTimeout(ConnectTimeout); + Session.SetTimeout(Timeout); + } + + cpr::Session& GetSession() { return Session; } + + private: + ZenStructuredCacheClient& OwnerClient; + cpr::Session Session; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +ZenStructuredCacheClient::ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options) +: m_Log(logging::Get(std::string_view("zenclient"))) +, m_ServiceUrl(Options.Url) +, m_ConnectTimeout(Options.ConnectTimeout) +, m_Timeout(Options.Timeout) +{ +} + +ZenStructuredCacheClient::~ZenStructuredCacheClient() +{ +} + +detail::ZenCacheSessionState* +ZenStructuredCacheClient::AllocSessionState() +{ + detail::ZenCacheSessionState* State = nullptr; + + if (RwLock::ExclusiveLockScope _(m_SessionStateLock); !m_SessionStateCache.empty()) + { + State = m_SessionStateCache.front(); + m_SessionStateCache.pop_front(); + } + + if (State == nullptr) + { + State = new detail::ZenCacheSessionState(*this); + } + + State->Reset(m_ConnectTimeout, m_Timeout); + + return State; +} + +void +ZenStructuredCacheClient::FreeSessionState(detail::ZenCacheSessionState* State) +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + m_SessionStateCache.push_front(State); +} + +////////////////////////////////////////////////////////////////////////// + +using namespace std::literals; + +ZenStructuredCacheSession::ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient) +: m_Log(OuterClient->Log()) +, m_Client(std::move(OuterClient)) +{ + m_SessionState = m_Client->AllocSessionState(); +} + +ZenStructuredCacheSession::~ZenStructuredCacheSession() +{ + m_Client->FreeSessionState(m_SessionState); +} + +ZenCacheResult +ZenStructuredCacheSession::CheckHealth() +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/health/check"; + + cpr::Session& Session = m_SessionState->GetSession(); + Session.SetOption(cpr::Url{Uri.c_str()}); + cpr::Response Response = Session.Get(); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + return {.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +ZenCacheResult +ZenStructuredCacheSession::GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Accept", std::string{MapContentTypeToString(Type)}}}); + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::GetCacheChunk(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Accept", "application/x-ue-comp"}}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, + .Bytes = Response.downloaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::PutCacheRecord(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + IoBuffer Value, + ZenContentType Type) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", + Type == ZenContentType::kCbPackage ? "application/x-ue-cbpkg" + : Type == ZenContentType::kCbObject ? "application/x-ue-cb" + : "application/octet-stream"}}); + Session.SetBody(cpr::Body{static_cast<const char*>(Value.Data()), Value.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200 || Response.status_code == 201; + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::PutCacheValue(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId, + IoBuffer Payload) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-comp"}}); + Session.SetBody(cpr::Body{static_cast<const char*>(Payload.Data()), Payload.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200 || Response.status_code == 201; + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/$rpc"; + + BinaryWriter Body; + Request.CopyTo(Body); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), + .Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbPackage& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/$rpc"; + + SharedBuffer Message = FormatPackageMessageBuffer(Request).Flatten(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Message.GetData()), Message.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), + .Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +} // namespace zen diff --git a/src/zenserver/upstream/zen.h b/src/zenserver/upstream/zen.h new file mode 100644 index 000000000..bfba8fa98 --- /dev/null +++ b/src/zenserver/upstream/zen.h @@ -0,0 +1,125 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/thread.h> +#include <zencore/uid.h> +#include <zencore/zencore.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <chrono> +#include <list> + +struct ZenCacheValue; + +namespace spdlog { +class logger; +} + +namespace zen { + +class CbObjectWriter; +class CbObjectView; +class CbPackage; +class ZenStructuredCacheClient; + +////////////////////////////////////////////////////////////////////////// + +namespace detail { + struct ZenCacheSessionState; +} + +struct ZenCacheResult +{ + IoBuffer Response; + int64_t Bytes = {}; + double ElapsedSeconds = {}; + int32_t ErrorCode = {}; + std::string Reason; + bool Success = false; +}; + +struct ZenStructuredCacheClientOptions +{ + std::string_view Name; + std::string_view Url; + std::span<std::string const> Urls; + std::chrono::milliseconds ConnectTimeout{}; + std::chrono::milliseconds Timeout{}; +}; + +/** Zen Structured Cache session + * + * This provides a context in which cache queries can be performed + * + * These are currently all synchronous. Will need to be made asynchronous + */ +class ZenStructuredCacheSession +{ +public: + ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient); + ~ZenStructuredCacheSession(); + + ZenCacheResult CheckHealth(); + ZenCacheResult GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type); + ZenCacheResult GetCacheChunk(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId); + ZenCacheResult PutCacheRecord(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + IoBuffer Value, + ZenContentType Type); + ZenCacheResult PutCacheValue(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId, + IoBuffer Payload); + ZenCacheResult InvokeRpc(const CbObjectView& Request); + ZenCacheResult InvokeRpc(const CbPackage& Package); + +private: + inline spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + Ref<ZenStructuredCacheClient> m_Client; + detail::ZenCacheSessionState* m_SessionState; +}; + +/** Zen Structured Cache client + * + * This represents an endpoint to query -- actual queries should be done via + * ZenStructuredCacheSession + */ +class ZenStructuredCacheClient : public RefCounted +{ +public: + ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options); + ~ZenStructuredCacheClient(); + + std::string_view ServiceUrl() const { return m_ServiceUrl; } + + inline spdlog::logger& Log() { return m_Log; } + +private: + spdlog::logger& m_Log; + std::string m_ServiceUrl; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + + RwLock m_SessionStateLock; + std::list<detail::ZenCacheSessionState*> m_SessionStateCache; + + detail::ZenCacheSessionState* AllocSessionState(); + void FreeSessionState(detail::ZenCacheSessionState*); + + friend class ZenStructuredCacheSession; +}; + +} // namespace zen |