aboutsummaryrefslogtreecommitdiff
path: root/src/zenstore/cache/cacherpc.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-12-20 16:03:35 +0100
committerGitHub <[email protected]>2023-12-20 16:03:35 +0100
commit086558fd15f884cd29d1e6941a8576190c0b650d (patch)
tree71e5e729be82d1825a228931d9c03376c659b5ca /src/zenstore/cache/cacherpc.cpp
parentmove cachedisklayer and structuredcachestore into zenstore (#624) (diff)
downloadzen-086558fd15f884cd29d1e6941a8576190c0b650d.tar.xz
zen-086558fd15f884cd29d1e6941a8576190c0b650d.zip
separate RPC processing from HTTP processing (#626)
* moved all RPC processing from HttpStructuredCacheService into separate CacheRpcHandler class in zenstore * move package marshaling to zenutil. was previously in zenhttp/httpshared but it's useful in other contexts as well where we don't want to depend on zenhttp * introduced UpstreamCacheClient, this provides a subset of functions on UpstreamCache and lives in zenstore
Diffstat (limited to 'src/zenstore/cache/cacherpc.cpp')
-rw-r--r--src/zenstore/cache/cacherpc.cpp1640
1 files changed, 1640 insertions, 0 deletions
diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp
new file mode 100644
index 000000000..96b344ee9
--- /dev/null
+++ b/src/zenstore/cache/cacherpc.cpp
@@ -0,0 +1,1640 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenstore/cache/cacherpc.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/scopeguard.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenstore/cache/cacheshared.h>
+#include <zenstore/cache/structuredcachestore.h>
+#include <zenstore/cache/upstreamcacheclient.h>
+#include <zenstore/cidstore.h>
+#include <zenutil/packageformat.h>
+
+namespace zen { namespace {
+
+ constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+
+ std::optional<std::string> GetValidNamespaceName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Namespace is invalid, empty namespace is not allowed");
+ return {};
+ }
+
+ if (Name.length() > 64)
+ {
+ ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name);
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet))
+ {
+ ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<std::string> GetValidBucketName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed");
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet))
+ {
+ ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+}} // namespace zen::
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+using namespace std::literals;
+
+std::optional<std::string>
+GetRpcRequestNamespace(const CbObjectView Params)
+{
+ CbFieldView NamespaceField = Params["Namespace"sv];
+ if (!NamespaceField)
+ {
+ return std::string(ZenCacheStore::DefaultNamespace);
+ }
+
+ if (NamespaceField.HasError())
+ {
+ return {};
+ }
+ if (!NamespaceField.IsString())
+ {
+ return {};
+ }
+ return GetValidNamespaceName(NamespaceField.AsString());
+}
+
+bool
+GetRpcRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key)
+{
+ CbFieldView BucketField = KeyView["Bucket"sv];
+ if (BucketField.HasError())
+ {
+ return false;
+ }
+ if (!BucketField.IsString())
+ {
+ return false;
+ }
+ std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString());
+ if (!Bucket.has_value())
+ {
+ return false;
+ }
+ CbFieldView HashField = KeyView["Hash"sv];
+ if (HashField.HasError())
+ {
+ return false;
+ }
+ if (!HashField.IsHash())
+ {
+ return false;
+ }
+ IoHash Hash = HashField.AsHash();
+ Key = CacheKey::Create(*Bucket, Hash);
+ return true;
+}
+
+namespace cache::detail {
+
+ struct RecordValue
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ uint64_t RawSize;
+ };
+
+ struct RecordBody
+ {
+ IoBuffer CacheValue;
+ std::vector<RecordValue> Values;
+ const UpstreamEndpointInfo* Source = nullptr;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool HasRequest = false;
+ bool ValuesRead = false;
+ };
+
+ struct ChunkRequest
+ {
+ CacheChunkRequest* Key = nullptr;
+ RecordBody* Record = nullptr;
+ CompressedBuffer Value;
+ const UpstreamEndpointInfo* Source = nullptr;
+ uint64_t RawSize = 0;
+ uint64_t RequestedSize = 0;
+ uint64_t RequestedOffset = 0;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool RawSizeKnown = false;
+ bool IsRecordRequest = false;
+ uint64_t ElapsedTimeUs = 0;
+ };
+
+} // namespace cache::detail
+
+struct PutRequestData
+{
+ std::string Namespace;
+ CacheKey Key;
+ CbObjectView RecordObject;
+ CacheRecordPolicy Policy;
+ CacheRequestContext Context;
+};
+
+CacheRecordPolicy
+LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default)
+{
+ OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object);
+ return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy);
+}
+
+CacheRpcHandler::CacheRpcHandler(LoggerRef InLog,
+ CacheStats& InCacheStats,
+ UpstreamCacheClient& InUpstreamCache,
+ ZenCacheStore& InCacheStore,
+ CidStore& InCidStore,
+ const DiskWriteBlocker* InDiskWriteBlocker)
+: m_Log(InLog)
+, m_CacheStats(InCacheStats)
+, m_UpstreamCache(InUpstreamCache)
+, m_CacheStore(InCacheStore)
+, m_CidStore(InCidStore)
+, m_DiskWriteBlocker(InDiskWriteBlocker)
+{
+}
+
+CacheRpcHandler::~CacheRpcHandler()
+{
+}
+
+bool
+CacheRpcHandler::AreDiskWritesAllowed() const
+{
+ return (m_DiskWriteBlocker == nullptr || m_DiskWriteBlocker->AreDiskWritesAllowed());
+}
+
+CacheRpcHandler::RpcResponseCode
+CacheRpcHandler::HandleRpcRequest(const CacheRequestContext& Context,
+ const ZenContentType ContentType,
+ IoBuffer&& Body,
+ uint32_t& OutAcceptMagic,
+ RpcAcceptOptions& OutAcceptFlags,
+ int& OutTargetProcessId,
+ CbPackage& OutResultPackage)
+{
+ ZEN_TRACE_CPU("Z$::HandleRpcRequest");
+
+ m_CacheStats.RpcRequests.fetch_add(1);
+
+ CbPackage Package;
+ CbObjectView Object;
+ CbObject ObjectBuffer;
+ if (ContentType == ZenContentType::kCbObject)
+ {
+ ObjectBuffer = LoadCompactBinaryObject(std::move(Body));
+ Object = ObjectBuffer;
+ }
+ else
+ {
+ Package = ParsePackageMessage(Body);
+ Object = Package.GetObject();
+ }
+ OutAcceptMagic = Object["Accept"sv].AsUInt32();
+ OutAcceptFlags = static_cast<RpcAcceptOptions>(Object["AcceptFlags"sv].AsUInt16(0u));
+ OutTargetProcessId = Object["Pid"sv].AsInt32(0);
+
+ const std::string_view Method = Object["Method"sv].AsString();
+
+ if (Method == "PutCacheRecords"sv)
+ {
+ if (!AreDiskWritesAllowed())
+ {
+ return RpcResponseCode::InsufficientStorage;
+ }
+ OutResultPackage = HandleRpcPutCacheRecords(Context, Package);
+ }
+ else if (Method == "GetCacheRecords"sv)
+ {
+ OutResultPackage = HandleRpcGetCacheRecords(Context, Object);
+ }
+ else if (Method == "PutCacheValues"sv)
+ {
+ if (!AreDiskWritesAllowed())
+ {
+ return RpcResponseCode::InsufficientStorage;
+ }
+ OutResultPackage = HandleRpcPutCacheValues(Context, Package);
+ }
+ else if (Method == "GetCacheValues"sv)
+ {
+ OutResultPackage = HandleRpcGetCacheValues(Context, Object);
+ }
+ else if (Method == "GetCacheChunks"sv)
+ {
+ OutResultPackage = HandleRpcGetCacheChunks(Context, Object);
+ }
+ else
+ {
+ m_CacheStats.BadRequestCount++;
+ return RpcResponseCode::BadRequest;
+ }
+ return RpcResponseCode::OK;
+}
+
+CbPackage
+CacheRpcHandler::HandleRpcPutCacheRecords(const CacheRequestContext& Context, const CbPackage& BatchRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcPutCacheRecords");
+
+ CbObjectView BatchObject = BatchRequest.GetObject();
+ ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv);
+
+ CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+ CachePolicy DefaultPolicy;
+
+ m_CacheStats.RpcRecordRequests.fetch_add(1);
+
+ std::string_view PolicyText = Params["DefaultPolicy"].AsString();
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+ DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+
+ std::vector<bool> Results;
+
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ m_CacheStats.RpcRecordBatchRequests.fetch_add(1);
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView();
+ CbObjectView KeyView = RecordObject["Key"sv].AsObjectView();
+
+ CacheKey Key;
+ if (!GetRpcRequestCacheKey(KeyView, Key))
+ {
+ return CbPackage{};
+ }
+ CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ PutRequestData PutRequest{.Namespace = *Namespace,
+ .Key = std::move(Key),
+ .RecordObject = RecordObject,
+ .Policy = std::move(Policy),
+ .Context = Context};
+
+ PutResult Result = PutCacheRecord(PutRequest, &BatchRequest);
+
+ if (Result == PutResult::Invalid)
+ {
+ return CbPackage{};
+ }
+ Results.push_back(Result == PutResult::Success);
+ }
+ if (Results.empty())
+ {
+ return CbPackage{};
+ }
+
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (bool Value : Results)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ CbPackage RpcResponse;
+ RpcResponse.SetObject(ResponseObject.Save());
+ return RpcResponse;
+}
+
+PutResult
+CacheRpcHandler::PutCacheRecord(PutRequestData& Request, const CbPackage* Package)
+{
+ CbObjectView Record = Request.RecordObject;
+ uint64_t RecordObjectSize = Record.GetSize();
+ uint64_t TransferredSize = RecordObjectSize;
+
+ AttachmentCount Count;
+ size_t NumAttachments = Package->GetAttachments().size();
+ std::vector<IoHash> ValidAttachments;
+ std::vector<IoHash> ReferencedAttachments;
+ std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+ ValidAttachments.reserve(NumAttachments);
+ AttachmentsToStoreLocally.reserve(NumAttachments);
+
+ const bool HasUpstream = m_UpstreamCache.IsActive();
+
+ Stopwatch Timer;
+
+ Request.RecordObject.IterateAttachments(
+ [this, &Request, Package, &AttachmentsToStoreLocally, &ValidAttachments, &ReferencedAttachments, &Count, &TransferredSize](
+ CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ ReferencedAttachments.push_back(ValueHash);
+ if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr)
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ AttachmentsToStoreLocally.emplace_back(Attachment);
+ ValidAttachments.emplace_back(ValueHash);
+ Count.Valid++;
+ }
+ else
+ {
+ ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed",
+ Request.Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ ToString(ZenContentType::kCbPackage),
+ ValueHash);
+ Count.Invalid++;
+ }
+ }
+ else if (m_CidStore.ContainsChunk(ValueHash))
+ {
+ ValidAttachments.emplace_back(ValueHash);
+ Count.Valid++;
+ }
+ Count.Total++;
+ });
+
+ if (Count.Invalid > 0)
+ {
+ return PutResult::Invalid;
+ }
+
+ ZenCacheValue CacheValue;
+ CacheValue.Value = IoBuffer(Record.GetSize());
+ Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize()));
+ CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+ m_CacheStore.Put(Request.Context, Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue, ReferencedAttachments);
+ m_CacheStats.WriteCount++;
+
+ for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ if (InsertResult.New)
+ {
+ Count.New++;
+ }
+ TransferredSize += Chunk.GetCompressedSize();
+ }
+
+ ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {}, attachments '{}/{}/{}' (new/valid/total) in {}",
+ Request.Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ NiceBytes(TransferredSize),
+ Count.New,
+ Count.Valid,
+ Count.Total,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ const bool IsPartialRecord = Count.Valid != Count.Total;
+
+ if (HasUpstream && EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord)
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage,
+ .Namespace = Request.Namespace,
+ .Key = Request.Key,
+ .ValueContentIds = std::move(ValidAttachments)});
+ }
+ return PutResult::Success;
+}
+
+CbPackage
+CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, CbObjectView RpcRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcGetCacheRecords");
+
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ m_CacheStats.RpcRecordRequests.fetch_add(1);
+
+ struct ValueRequestData
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ CompressedBuffer Payload;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool ReadFromUpstream = false;
+ };
+ struct RecordRequestData
+ {
+ CacheKeyRequest Upstream;
+ CbObjectView RecordObject;
+ IoBuffer RecordCacheValue;
+ CacheRecordPolicy DownstreamPolicy;
+ std::vector<ValueRequestData> Values;
+ bool Complete = false;
+ const UpstreamEndpointInfo* Source = nullptr;
+ uint64_t ElapsedTimeUs;
+ };
+
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+
+ const bool HasUpstream = m_UpstreamCache.IsActive();
+
+ std::vector<RecordRequestData> Requests;
+ std::vector<size_t> UpstreamIndexes;
+
+ auto ParseValues = [](RecordRequestData& Request) {
+ CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView();
+ Request.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ Request.Values.push_back({ValueId, RawHash});
+ Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId);
+ }
+ }
+ };
+
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ ZEN_TRACE_CPU("Z$::RpcGetCacheRecords::Request");
+
+ m_CacheStats.RpcRecordBatchRequests.fetch_add(1);
+
+ Stopwatch Timer;
+ RecordRequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ CacheKey& Key = Request.Upstream.Key;
+ if (!GetRpcRequestCacheKey(KeyObject, Key))
+ {
+ return CbPackage{};
+ }
+
+ Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ const CacheRecordPolicy& Policy = Request.DownstreamPolicy;
+
+ ZenCacheValue CacheValue;
+ bool NeedUpstreamAttachment = false;
+ bool FoundLocalInvalid = false;
+ ZenCacheValue RecordCacheValue;
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) &&
+ m_CacheStore.Get(Context, *Namespace, Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ Request.RecordCacheValue = std::move(RecordCacheValue.Value);
+ if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject)
+ {
+ FoundLocalInvalid = true;
+ }
+ else
+ {
+ Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData());
+ ParseValues(Request);
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal))
+ {
+ // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we
+ // didn't ask for it and thus the record is complete in its absence.
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ Request.Complete = false;
+ }
+ }
+ else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ if (m_CidStore.ContainsChunk(Value.ContentId))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
+ }
+ else
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId))
+ {
+ if (Chunk.GetSize() > 0)
+ {
+ Value.Payload = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+ Value.Exists = true;
+ continue;
+ }
+ else
+ {
+ ZEN_WARN("Skipping invalid chunk in local cache '{}'", Value.ContentId);
+ }
+ }
+
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
+ }
+ }
+ }
+ if (!Request.Complete)
+ {
+ bool NeedUpstreamRecord = HasUpstream && !Request.RecordObject && !FoundLocalInvalid &&
+ EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote);
+ if (NeedUpstreamRecord || NeedUpstreamAttachment)
+ {
+ UpstreamIndexes.push_back(Requests.size() - 1);
+ }
+ }
+ Request.ElapsedTimeUs = Timer.GetElapsedTimeUs();
+ }
+ if (Requests.empty())
+ {
+ return CbPackage{};
+ }
+
+ if (!UpstreamIndexes.empty())
+ {
+ std::vector<CacheKeyRequest*> UpstreamRequests;
+ UpstreamRequests.reserve(UpstreamIndexes.size());
+ for (size_t Index : UpstreamIndexes)
+ {
+ RecordRequestData& Request = Requests[Index];
+ UpstreamRequests.push_back(&Request.Upstream);
+
+ if (Request.Values.size())
+ {
+ // We will be returning the local object and know all the value Ids that exist in it
+ // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have.
+ CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta;
+ CacheRecordPolicyBuilder Builder(UpstreamBasePolicy);
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy);
+ UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None;
+ Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy);
+ }
+ Request.Upstream.Policy = Builder.Build();
+ }
+ else
+ {
+ // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants,
+ // and convert the CacheRecordPolicy to an upstream policy
+ Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream();
+ }
+ }
+
+ const auto OnCacheRecordGetComplete = [this, Namespace, &ParseValues, Context](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
+
+ RecordRequestData& Request =
+ *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream));
+ Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ const CacheKey& Key = Request.Upstream.Key;
+ Stopwatch Timer;
+ auto TimeGuard = MakeGuard([&Timer, &Request]() { Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); });
+ if (!Request.RecordObject)
+ {
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject);
+ Request.RecordObject = ObjectBuffer;
+ bool StoreLocal =
+ EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ if (StoreLocal)
+ {
+ std::vector<IoHash> ReferencedAttachments;
+ ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ ReferencedAttachments.push_back(ValueHash);
+ });
+ m_CacheStore
+ .Put(Context, *Namespace, Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}}, ReferencedAttachments);
+ m_CacheStats.WriteCount++;
+ }
+ ParseValues(Request);
+ Request.Source = Params.Source;
+ }
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (Value.Exists)
+ {
+ continue;
+ }
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Request.Complete = false;
+ continue;
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ bool StoreLocal = EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId))
+ {
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
+ {
+ Request.Source = Params.Source;
+ Value.Exists = true;
+ if (StoreLocal)
+ {
+ m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Value.Payload = Compressed;
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}/{}'",
+ Value.ContentId,
+ *Namespace,
+ Key.Bucket,
+ Key.Hash);
+ }
+ }
+ if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Request.Complete = false;
+ }
+ // Request.Complete does not need to be set to false for upstream SkipData attachments.
+ // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment
+ // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of
+ // any missing SkipData attachments.
+ }
+ Request.ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+ };
+
+ m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete));
+ }
+
+ {
+ ZEN_TRACE_CPU("Z$::RpcGetCacheRecords::Response");
+ CbPackage ResponsePackage;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginArray("Result"sv);
+ for (RecordRequestData& Request : Requests)
+ {
+ const CacheKey& Key = Request.Upstream.Key;
+ if (Request.Complete ||
+ (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)))
+ {
+ ResponseObject << Request.RecordObject;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload)
+ {
+ ResponsePackage.AddAttachment(CbAttachment(Value.Payload, Value.ContentId));
+ }
+ }
+
+ ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {}{} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(Request.RecordCacheValue.Size()),
+ Request.Complete ? ""sv : " (PARTIAL)"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount += Request.Source ? 1 : 0;
+ }
+ else
+ {
+ ResponseObject.AddNull();
+
+ if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("GETCACHERECORD DISABLEDQUERY - '{}/{}/{}' in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}'{} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ Request.RecordObject ? ""sv : " (PARTIAL)"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+ }
+ ResponseObject.EndArray();
+ ResponsePackage.SetObject(ResponseObject.Save());
+ return ResponsePackage;
+ }
+}
+
+CbPackage
+CacheRpcHandler::HandleRpcPutCacheValues(const CacheRequestContext& Context, const CbPackage& BatchRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcPutCacheValues");
+ CbObjectView BatchObject = BatchRequest.GetObject();
+ CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+
+ m_CacheStats.RpcValueRequests.fetch_add(1);
+
+ std::string_view PolicyText = Params["DefaultPolicy"].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+ const bool HasUpstream = m_UpstreamCache.IsActive();
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+
+ std::vector<bool> Results;
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ ZEN_TRACE_CPU("Z$::RpcPutCacheValues::Request");
+
+ m_CacheStats.RpcValueBatchRequests.fetch_add(1);
+
+ Stopwatch Timer;
+
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyView = RequestObject["Key"sv].AsObjectView();
+
+ CacheKey Key;
+ if (!GetRpcRequestCacheKey(KeyView, Key))
+ {
+ return CbPackage{};
+ }
+
+ PolicyText = RequestObject["Policy"sv].AsString();
+ CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+ IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment();
+ uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64();
+ bool Succeeded = false;
+ uint64_t TransferredSize = 0;
+
+ if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ // TODO: Implement upstream puts of CacheValues with StoreLocal == false.
+ // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream.
+ Policy |= CachePolicy::StoreLocal;
+ }
+
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
+ {
+ IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ Value.SetContentType(ZenContentType::kCompressedBinary);
+ if (RawSize == 0)
+ {
+ RawSize = Chunk.DecodeRawSize();
+ }
+ m_CacheStore
+ .Put(Context, *Namespace, Key.Bucket, Key.Hash, {.Value = Value, .RawSize = RawSize, .RawHash = RawHash}, {});
+ m_CacheStats.WriteCount++;
+ TransferredSize = Chunk.GetCompressedSize();
+ }
+ Succeeded = true;
+ }
+ else
+ {
+ ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}/{}' FAILED, value is not compressed", *Namespace, Key.Bucket, Key.Hash, RawHash);
+ return CbPackage{};
+ }
+ }
+ else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue ExistingValue;
+ if (m_CacheStore.Get(Context, *Namespace, Key.Bucket, Key.Hash, ExistingValue) &&
+ IsCompressedBinary(ExistingValue.Value.GetContentType()))
+ {
+ Succeeded = true;
+ }
+ }
+ // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put.
+ // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream.
+
+ if (HasUpstream && Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCompressedBinary, .Namespace = *Namespace, .Key = Key});
+ }
+ Results.push_back(Succeeded);
+ ZEN_DEBUG("PUTCACHEVALUES - '{}/{}/{}' {}, '{}' in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(TransferredSize),
+ Succeeded ? "Added"sv : "Invalid",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ }
+ if (Results.empty())
+ {
+ return CbPackage{};
+ }
+
+ {
+ ZEN_TRACE_CPU("Z$::RpcPutCacheValues::Response");
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (bool Value : Results)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ CbPackage RpcResponse;
+ RpcResponse.SetObject(ResponseObject.Save());
+
+ return RpcResponse;
+ }
+}
+
+CbPackage
+CacheRpcHandler::HandleRpcGetCacheValues(const CacheRequestContext& Context, CbObjectView RpcRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcGetCacheValues");
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ m_CacheStats.RpcValueRequests.fetch_add(1);
+
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+
+ struct RequestData
+ {
+ CacheKey Key;
+ CachePolicy Policy;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+ CompressedBuffer Result;
+ };
+ std::vector<RequestData> Requests;
+
+ std::vector<size_t> RemoteRequestIndexes;
+
+ const bool HasUpstream = m_UpstreamCache.IsActive();
+
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ ZEN_TRACE_CPU("Z$::RpcGetCacheValues::Request");
+
+ m_CacheStats.RpcValueBatchRequests.fetch_add(1);
+
+ Stopwatch Timer;
+
+ RequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ if (!GetRpcRequestCacheKey(KeyObject, Request.Key))
+ {
+ return CbPackage{};
+ }
+
+ PolicyText = RequestObject["Policy"sv].AsString();
+ Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+
+ CacheKey& Key = Request.Key;
+ CachePolicy Policy = Request.Policy;
+
+ ZenCacheValue CacheValue;
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ if (m_CacheStore.Get(Context, *Namespace, Key.Bucket, Key.Hash, CacheValue) &&
+ IsCompressedBinary(CacheValue.Value.GetContentType()))
+ {
+ Request.RawHash = CacheValue.RawHash;
+ Request.RawSize = CacheValue.RawSize;
+ Request.Result = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value));
+ }
+ }
+ if (Request.Result)
+ {
+ ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(Request.Result.GetCompressed().GetSize()),
+ "LOCAL"sv,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.HitCount++;
+ }
+ else if (HasUpstream && EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ RemoteRequestIndexes.push_back(Requests.size() - 1);
+ }
+ else if (!EnumHasAnyFlags(Policy, CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}/{}'", *Namespace, Key.Bucket, Key.Hash);
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ "LOCAL"sv,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+
+ if (!RemoteRequestIndexes.empty())
+ {
+ std::vector<CacheValueRequest> RequestedRecordsData;
+ std::vector<CacheValueRequest*> CacheValueRequests;
+ RequestedRecordsData.reserve(RemoteRequestIndexes.size());
+ CacheValueRequests.reserve(RemoteRequestIndexes.size());
+ for (size_t Index : RemoteRequestIndexes)
+ {
+ RequestData& Request = Requests[Index];
+ RequestedRecordsData.push_back({.Key = {Request.Key.Bucket, Request.Key.Hash}, .Policy = ConvertToUpstream(Request.Policy)});
+ CacheValueRequests.push_back(&RequestedRecordsData.back());
+ }
+ Stopwatch Timer;
+ m_UpstreamCache.GetCacheValues(
+ *Namespace,
+ CacheValueRequests,
+ [this, Namespace, &RequestedRecordsData, &Requests, &RemoteRequestIndexes, &Timer, Context](
+ CacheValueGetCompleteParams&& Params) {
+ CacheValueRequest& ChunkRequest = Params.Request;
+ if (Params.RawHash != IoHash::Zero)
+ {
+ size_t RequestOffset = std::distance(RequestedRecordsData.data(), &ChunkRequest);
+ size_t RequestIndex = RemoteRequestIndexes[RequestOffset];
+ RequestData& Request = Requests[RequestIndex];
+ Request.RawHash = Params.RawHash;
+ Request.RawSize = Params.RawSize;
+ const bool HasData = IsCompressedBinary(Params.Value.GetContentType());
+ const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData);
+ const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ const bool IsHit = SkipData || HasData;
+ if (IsHit)
+ {
+ if (HasData && !SkipData)
+ {
+ Request.Result = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value));
+ }
+
+ if (HasData && StoreData)
+ {
+ m_CacheStore.Put(Context,
+ *Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash},
+ {});
+ m_CacheStats.WriteCount++;
+ }
+
+ ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
+ *Namespace,
+ ChunkRequest.Key.Bucket,
+ ChunkRequest.Key.Hash,
+ NiceBytes(Request.Result.GetCompressed().GetSize()),
+ Params.Source ? Params.Source->Url : "UPSTREAM",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount++;
+ return;
+ }
+ }
+ ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}",
+ *Namespace,
+ ChunkRequest.Key.Bucket,
+ ChunkRequest.Key.Hash,
+ Params.Source ? Params.Source->Url : "UPSTREAM",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.MissCount++;
+ });
+ }
+
+ if (Requests.empty())
+ {
+ return CbPackage{};
+ }
+
+ {
+ ZEN_TRACE_CPU("Z$::RpcGetCacheValues::Response");
+ CbPackage RpcResponse;
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (const RequestData& Request : Requests)
+ {
+ ResponseObject.BeginObject();
+ {
+ const CompressedBuffer& Result = Request.Result;
+ if (Result)
+ {
+ ResponseObject.AddHash("RawHash"sv, Request.RawHash);
+ if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData))
+ {
+ RpcResponse.AddAttachment(CbAttachment(Result, Request.RawHash));
+ }
+ else
+ {
+ ResponseObject.AddInteger("RawSize"sv, Request.RawSize);
+ }
+ }
+ else if (Request.RawHash != IoHash::Zero)
+ {
+ ResponseObject.AddHash("RawHash"sv, Request.RawHash);
+ ResponseObject.AddInteger("RawSize"sv, Request.RawSize);
+ }
+ }
+ ResponseObject.EndObject();
+ }
+ ResponseObject.EndArray();
+
+ RpcResponse.SetObject(ResponseObject.Save());
+ return RpcResponse;
+ }
+}
+
+CbPackage
+CacheRpcHandler::HandleRpcGetCacheChunks(const CacheRequestContext& Context, CbObjectView RpcRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcGetCacheChunks");
+ using namespace cache::detail;
+
+ std::string Namespace;
+ std::vector<CacheKeyRequest> RecordKeys; // Data about a Record necessary to identify it to the upstream
+ std::vector<RecordBody> Records; // Scratch-space data about a Record when fulfilling RecordRequests
+ std::vector<CacheChunkRequest> RequestKeys; // Data about a ChunkRequest necessary to identify it to the upstream
+ std::vector<ChunkRequest> Requests; // Intermediate and result data about a ChunkRequest
+ std::vector<ChunkRequest*> RecordRequests; // The ChunkRequests that are requesting a subvalue from a Record Key
+ std::vector<ChunkRequest*> ValueRequests; // The ChunkRequests that are requesting a Value Key
+ std::vector<CacheChunkRequest*> UpstreamChunks; // ChunkRequests that we need to send to the upstream
+
+ // Parse requests from the CompactBinary body of the RpcRequest and divide it into RecordRequests and ValueRequests
+ if (!ParseGetCacheChunksRequest(Namespace, RecordKeys, Records, RequestKeys, Requests, RecordRequests, ValueRequests, RpcRequest))
+ {
+ return CbPackage{};
+ }
+
+ // For each Record request, load the Record if necessary to find the Chunk's ContentId, load its Payloads if we
+ // have it locally, and otherwise append a request for the payload to UpstreamChunks
+ GetLocalCacheRecords(Context, Namespace, RecordKeys, Records, RecordRequests, UpstreamChunks);
+
+ // For each Value request, load the Value if we have it locally and otherwise append a request for the payload to UpstreamChunks
+ GetLocalCacheValues(Context, Namespace, ValueRequests, UpstreamChunks);
+
+ // Call GetCacheChunks on the upstream for any payloads we do not have locally
+ GetUpstreamCacheChunks(Context, Namespace, UpstreamChunks, RequestKeys, Requests);
+
+ // Send the payload and descriptive data about each chunk to the client
+ return WriteGetCacheChunksResponse(Context, Namespace, Requests);
+}
+
+bool
+CacheRpcHandler::ParseGetCacheChunksRequest(std::string& Namespace,
+ std::vector<CacheKeyRequest>& RecordKeys,
+ std::vector<cache::detail::RecordBody>& Records,
+ std::vector<CacheChunkRequest>& RequestKeys,
+ std::vector<cache::detail::ChunkRequest>& Requests,
+ std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+ std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+ CbObjectView RpcRequest)
+{
+ ZEN_TRACE_CPU("Z$::ParseGetCacheChunksRequest");
+
+ using namespace cache::detail;
+
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ m_CacheStats.RpcChunkRequests.fetch_add(1);
+
+ std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default;
+
+ std::optional<std::string> NamespaceText = GetRpcRequestNamespace(Params);
+ if (!NamespaceText)
+ {
+ ZEN_WARN("GetCacheChunks: Invalid namespace in ChunkRequest.");
+ return false;
+ }
+ Namespace = *NamespaceText;
+
+ CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView();
+ size_t NumRequests = static_cast<size_t>(ChunkRequestsArray.Num());
+ // Note that these reservations allow us to take pointers to the elements while populating them. If the reservation is removed,
+ // we will need to change the pointers to indexes to handle reallocations.
+ RecordKeys.reserve(NumRequests);
+ Records.reserve(NumRequests);
+ RequestKeys.reserve(NumRequests);
+ Requests.reserve(NumRequests);
+ RecordRequests.reserve(NumRequests);
+ ValueRequests.reserve(NumRequests);
+
+ CacheKeyRequest* PreviousRecordKey = nullptr;
+ RecordBody* PreviousRecord = nullptr;
+
+ for (CbFieldView RequestView : ChunkRequestsArray)
+ {
+ ZEN_TRACE_CPU("Z$::ParseGetCacheChunksRequest::Request");
+
+ m_CacheStats.RpcChunkBatchRequests.fetch_add(1);
+
+ CbObjectView RequestObject = RequestView.AsObjectView();
+ CacheChunkRequest& RequestKey = RequestKeys.emplace_back();
+ ChunkRequest& Request = Requests.emplace_back();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ Request.Key = &RequestKey;
+ if (!GetRpcRequestCacheKey(KeyObject, Request.Key->Key))
+ {
+ ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest.");
+ return false;
+ }
+
+ RequestKey.ChunkId = RequestObject["ChunkId"sv].AsHash();
+ RequestKey.ValueId = RequestObject["ValueId"sv].AsObjectId();
+ RequestKey.RawOffset = RequestObject["RawOffset"sv].AsUInt64();
+ RequestKey.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX);
+ Request.RequestedSize = RequestKey.RawSize;
+ Request.RequestedOffset = RequestKey.RawOffset;
+ std::string_view PolicyText = RequestObject["Policy"sv].AsString();
+ Request.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+ Request.IsRecordRequest = (bool)RequestKey.ValueId;
+
+ if (!Request.IsRecordRequest)
+ {
+ ValueRequests.push_back(&Request);
+ }
+ else
+ {
+ RecordRequests.push_back(&Request);
+ CacheKeyRequest* RecordKey = nullptr;
+ RecordBody* Record = nullptr;
+
+ if (!PreviousRecordKey || PreviousRecordKey->Key < RequestKey.Key)
+ {
+ RecordKey = &RecordKeys.emplace_back();
+ PreviousRecordKey = RecordKey;
+ Record = &Records.emplace_back();
+ PreviousRecord = Record;
+ RecordKey->Key = RequestKey.Key;
+ }
+ else if (RequestKey.Key == PreviousRecordKey->Key)
+ {
+ RecordKey = PreviousRecordKey;
+ Record = PreviousRecord;
+ }
+ else
+ {
+ ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.",
+ RequestKey.Key.Bucket,
+ RequestKey.Key.Hash,
+ PreviousRecordKey->Key.Bucket,
+ PreviousRecordKey->Key.Hash);
+ return false;
+ }
+ Request.Record = Record;
+ if (RequestKey.ChunkId == RequestKey.ChunkId.Zero)
+ {
+ Record->DownstreamPolicy =
+ Record->HasRequest ? Union(Record->DownstreamPolicy, Request.DownstreamPolicy) : Request.DownstreamPolicy;
+ Record->HasRequest = true;
+ }
+ }
+ }
+ if (Requests.empty())
+ {
+ return false;
+ }
+ return true;
+}
+
+void
+CacheRpcHandler::GetLocalCacheRecords(const CacheRequestContext& Context,
+ std::string_view Namespace,
+ std::vector<CacheKeyRequest>& RecordKeys,
+ std::vector<cache::detail::RecordBody>& Records,
+ std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+ std::vector<CacheChunkRequest*>& OutUpstreamChunks)
+{
+ ZEN_TRACE_CPU("Z$::GetLocalCacheRecords");
+
+ using namespace cache::detail;
+ const bool HasUpstream = m_UpstreamCache.IsActive();
+
+ std::vector<CacheKeyRequest*> UpstreamRecordRequests;
+ for (size_t RecordIndex = 0; RecordIndex < Records.size(); ++RecordIndex)
+ {
+ ZEN_TRACE_CPU("Z$::GetLocalCacheRecords::Record");
+
+ Stopwatch Timer;
+ CacheKeyRequest& RecordKey = RecordKeys[RecordIndex];
+ RecordBody& Record = Records[RecordIndex];
+ if (Record.HasRequest)
+ {
+ Record.DownstreamPolicy |= CachePolicy::SkipData | CachePolicy::SkipMeta;
+
+ if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue CacheValue;
+ if (m_CacheStore.Get(Context, Namespace, RecordKey.Key.Bucket, RecordKey.Key.Hash, CacheValue))
+ {
+ Record.Exists = true;
+ Record.CacheValue = std::move(CacheValue.Value);
+ }
+ }
+ if (HasUpstream && !Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ RecordKey.Policy = CacheRecordPolicy(ConvertToUpstream(Record.DownstreamPolicy));
+ UpstreamRecordRequests.push_back(&RecordKey);
+ }
+ RecordRequests[RecordIndex]->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+ }
+
+ if (!UpstreamRecordRequests.empty())
+ {
+ const auto OnCacheRecordGetComplete =
+ [this, Namespace, &RecordKeys, &Records, &RecordRequests, Context](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
+ CacheKeyRequest& RecordKey = Params.Request;
+ size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey);
+ RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ RecordBody& Record = Records[RecordIndex];
+
+ const CacheKey& Key = RecordKey.Key;
+ Record.Exists = true;
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Record.CacheValue.SetContentType(ZenContentType::kCbObject);
+ Record.Source = Params.Source;
+
+ bool StoreLocal = EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ if (StoreLocal)
+ {
+ std::vector<IoHash> ReferencedAttachments;
+ ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ ReferencedAttachments.push_back(ValueHash);
+ });
+ m_CacheStore.Put(Context, Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue}, ReferencedAttachments);
+ m_CacheStats.WriteCount++;
+ }
+ };
+ m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete));
+ }
+
+ for (ChunkRequest* Request : RecordRequests)
+ {
+ ZEN_TRACE_CPU("Z$::GetLocalCacheRecords::Chunk");
+
+ Stopwatch Timer;
+ if (Request->Key->ChunkId == IoHash::Zero)
+ {
+ // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId)
+ // is missing, parse the cache record and try to find the raw hash from the ValueId.
+ RecordBody& Record = *Request->Record;
+ if (!Record.ValuesRead)
+ {
+ Record.ValuesRead = true;
+ if (Record.CacheValue && Record.CacheValue.GetContentType() == ZenContentType::kCbObject)
+ {
+ CbObjectView RecordObject = CbObjectView(Record.CacheValue.GetData());
+ CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView();
+ Record.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ Record.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()});
+ }
+ }
+ }
+ }
+
+ for (const RecordValue& Value : Record.Values)
+ {
+ if (Value.ValueId == Request->Key->ValueId)
+ {
+ Request->Key->ChunkId = Value.ContentId;
+ Request->RawSize = Value.RawSize;
+ Request->RawSizeKnown = true;
+ break;
+ }
+ }
+ }
+
+ // Now load the ContentId from the local ContentIdStore or from the upstream
+ if (Request->Key->ChunkId != IoHash::Zero)
+ {
+ if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData) && Request->RawSizeKnown)
+ {
+ if (m_CidStore.ContainsChunk(Request->Key->ChunkId))
+ {
+ Request->Exists = true;
+ }
+ }
+ else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Request->Key->ChunkId))
+ {
+ if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(Payload));
+ if (Request->Value)
+ {
+ Request->Exists = true;
+ Request->RawSizeKnown = false;
+ }
+ }
+ else
+ {
+ IoHash _;
+ if (CompressedBuffer::ValidateCompressedHeader(Payload, _, Request->RawSize))
+ {
+ Request->Exists = true;
+ Request->RawSizeKnown = true;
+ }
+ }
+ }
+ }
+ if (HasUpstream && !Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ Request->Key->Policy = ConvertToUpstream(Request->DownstreamPolicy);
+ OutUpstreamChunks.push_back(Request->Key);
+ }
+ }
+ Request->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+}
+
+void
+CacheRpcHandler::GetLocalCacheValues(const CacheRequestContext& Context,
+ std::string_view Namespace,
+ std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+ std::vector<CacheChunkRequest*>& OutUpstreamChunks)
+{
+ ZEN_TRACE_CPU("Z$::GetLocalCacheValues");
+
+ using namespace cache::detail;
+ const bool HasUpstream = m_UpstreamCache.IsActive();
+
+ for (ChunkRequest* Request : ValueRequests)
+ {
+ ZEN_TRACE_CPU("Z$::GetLocalCacheValues::Value");
+
+ Stopwatch Timer;
+ if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue CacheValue;
+ if (m_CacheStore.Get(Context, Namespace, Request->Key->Key.Bucket, Request->Key->Key.Hash, CacheValue))
+ {
+ if (IsCompressedBinary(CacheValue.Value.GetContentType()))
+ {
+ Request->Key->ChunkId = CacheValue.RawHash;
+ Request->Exists = true;
+ Request->RawSize = CacheValue.RawSize;
+ Request->RawSizeKnown = true;
+ if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value));
+ }
+ }
+ }
+ }
+ if (HasUpstream && !Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::StoreLocal))
+ {
+ // Convert the Offset,Size request into a request for the entire value; we will need it all to be able to store it locally
+ Request->Key->RawOffset = 0;
+ Request->Key->RawSize = UINT64_MAX;
+ }
+ OutUpstreamChunks.push_back(Request->Key);
+ }
+ Request->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+}
+
+void
+CacheRpcHandler::GetUpstreamCacheChunks(const CacheRequestContext& Context,
+ std::string_view Namespace,
+ std::vector<CacheChunkRequest*>& UpstreamChunks,
+ std::vector<CacheChunkRequest>& RequestKeys,
+ std::vector<cache::detail::ChunkRequest>& Requests)
+{
+ if (UpstreamChunks.empty())
+ {
+ return;
+ }
+ ZEN_TRACE_CPU("Z$::GetUpstreamCacheChunks");
+
+ using namespace cache::detail;
+
+ const auto OnCacheChunksGetComplete = [this, Namespace, &RequestKeys, &Requests, Context](CacheChunkGetCompleteParams&& Params) {
+ if (Params.RawHash == Params.RawHash.Zero)
+ {
+ return;
+ }
+
+ CacheChunkRequest& Key = Params.Request;
+ size_t RequestIndex = std::distance(RequestKeys.data(), &Key);
+ ChunkRequest& Request = Requests[RequestIndex];
+ Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) ||
+ !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value));
+ if (!Compressed)
+ {
+ return;
+ }
+
+ bool StoreLocal = EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ if (StoreLocal)
+ {
+ if (Request.IsRecordRequest)
+ {
+ m_CidStore.AddChunk(Params.Value, Params.RawHash);
+ }
+ else
+ {
+ m_CacheStore.Put(Context,
+ Namespace,
+ Key.Key.Bucket,
+ Key.Key.Hash,
+ {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash},
+ {});
+ m_CacheStats.WriteCount++;
+ }
+ }
+ if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Request.Value = std::move(Compressed);
+ }
+ }
+ Key.ChunkId = Params.RawHash;
+ Request.Exists = true;
+ Request.RawSize = Params.RawSize;
+ Request.RawSizeKnown = true;
+ Request.Source = Params.Source;
+
+ m_CacheStats.UpstreamHitCount++;
+ };
+
+ m_UpstreamCache.GetCacheChunks(Namespace, UpstreamChunks, std::move(OnCacheChunksGetComplete));
+}
+
+CbPackage
+CacheRpcHandler::WriteGetCacheChunksResponse([[maybe_unused]] const CacheRequestContext& Context,
+ std::string_view Namespace,
+ std::vector<cache::detail::ChunkRequest>& Requests)
+{
+ ZEN_TRACE_CPU("Z$::WriteGetCacheChunksResponse");
+
+ using namespace cache::detail;
+
+ CbPackage RpcResponse;
+ CbObjectWriter Writer;
+
+ Writer.BeginArray("Result"sv);
+ for (ChunkRequest& Request : Requests)
+ {
+ ZEN_TRACE_CPU("Z$::WriteGetCacheChunksResponse::Request");
+
+ Writer.BeginObject();
+ {
+ if (Request.Exists)
+ {
+ Writer.AddHash("RawHash"sv, Request.Key->ChunkId);
+ if (Request.Value && !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ RpcResponse.AddAttachment(CbAttachment(Request.Value, Request.Key->ChunkId));
+ }
+ else
+ {
+ Writer.AddInteger("RawSize"sv, Request.RawSize);
+ }
+
+ ZEN_DEBUG("GETCACHECHUNKS HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+ Namespace,
+ Request.Key->Key.Bucket,
+ Request.Key->Key.Hash,
+ Request.Key->ValueId,
+ NiceBytes(Request.RawSize),
+ Request.IsRecordRequest ? "Record"sv : "Value"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.HitCount++;
+ }
+ else if (!EnumHasAnyFlags(Request.DownstreamPolicy, CachePolicy::Query))
+ {
+ ZEN_DEBUG("GETCACHECHUNKS DISABLEDQUERY - '{}/{}/{}/{}' in {}",
+ Namespace,
+ Request.Key->Key.Bucket,
+ Request.Key->Key.Hash,
+ Request.Key->ValueId,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHECHUNKS MISS - '{}/{}/{}/{}' in {}",
+ Namespace,
+ Request.Key->Key.Bucket,
+ Request.Key->Key.Hash,
+ Request.Key->ValueId,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ RpcResponse.SetObject(Writer.Save());
+ return RpcResponse;
+}
+
+} // namespace zen \ No newline at end of file