diff options
| author | Stefan Boberg <[email protected]> | 2023-12-20 16:03:35 +0100 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-12-20 16:03:35 +0100 |
| commit | 086558fd15f884cd29d1e6941a8576190c0b650d (patch) | |
| tree | 71e5e729be82d1825a228931d9c03376c659b5ca /src/zenstore/cache/cacherpc.cpp | |
| parent | move cachedisklayer and structuredcachestore into zenstore (#624) (diff) | |
| download | zen-086558fd15f884cd29d1e6941a8576190c0b650d.tar.xz zen-086558fd15f884cd29d1e6941a8576190c0b650d.zip | |
separate RPC processing from HTTP processing (#626)
* moved all RPC processing from HttpStructuredCacheService into separate CacheRpcHandler class in zenstore
* move package marshaling to zenutil. was previously in zenhttp/httpshared but it's useful in other contexts as well where we don't want to depend on zenhttp
* introduced UpstreamCacheClient, this provides a subset of functions on UpstreamCache and lives in zenstore
Diffstat (limited to 'src/zenstore/cache/cacherpc.cpp')
| -rw-r--r-- | src/zenstore/cache/cacherpc.cpp | 1640 |
1 files changed, 1640 insertions, 0 deletions
diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp new file mode 100644 index 000000000..96b344ee9 --- /dev/null +++ b/src/zenstore/cache/cacherpc.cpp @@ -0,0 +1,1640 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenstore/cache/cacherpc.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenstore/cache/cacheshared.h> +#include <zenstore/cache/structuredcachestore.h> +#include <zenstore/cache/upstreamcacheclient.h> +#include <zenstore/cidstore.h> +#include <zenutil/packageformat.h> + +namespace zen { namespace { + + constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + + std::optional<std::string> GetValidNamespaceName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Namespace is invalid, empty namespace is not allowed"); + return {}; + } + + if (Name.length() > 64) + { + ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet)) + { + ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<std::string> GetValidBucketName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed"); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet)) + { + ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + +}} // namespace zen:: + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +using namespace std::literals; + +std::optional<std::string> +GetRpcRequestNamespace(const CbObjectView Params) +{ + CbFieldView NamespaceField = Params["Namespace"sv]; + if (!NamespaceField) + { + return std::string(ZenCacheStore::DefaultNamespace); + } + + if (NamespaceField.HasError()) + { + return {}; + } + if (!NamespaceField.IsString()) + { + return {}; + } + return GetValidNamespaceName(NamespaceField.AsString()); +} + +bool +GetRpcRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key) +{ + CbFieldView BucketField = KeyView["Bucket"sv]; + if (BucketField.HasError()) + { + return false; + } + if (!BucketField.IsString()) + { + return false; + } + std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString()); + if (!Bucket.has_value()) + { + return false; + } + CbFieldView HashField = KeyView["Hash"sv]; + if (HashField.HasError()) + { + return false; + } + if (!HashField.IsHash()) + { + return false; + } + IoHash Hash = HashField.AsHash(); + Key = CacheKey::Create(*Bucket, Hash); + return true; +} + +namespace cache::detail { + + struct RecordValue + { + Oid ValueId; + IoHash ContentId; + uint64_t RawSize; + }; + + struct RecordBody + { + IoBuffer CacheValue; + std::vector<RecordValue> Values; + const UpstreamEndpointInfo* Source = nullptr; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool HasRequest = false; + bool ValuesRead = false; + }; + + struct ChunkRequest + { + CacheChunkRequest* Key = nullptr; + RecordBody* Record = nullptr; + CompressedBuffer Value; + const UpstreamEndpointInfo* Source = nullptr; + uint64_t RawSize = 0; + uint64_t RequestedSize = 0; + uint64_t RequestedOffset = 0; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool RawSizeKnown = false; + bool IsRecordRequest = false; + uint64_t ElapsedTimeUs = 0; + }; + +} // namespace cache::detail + +struct PutRequestData +{ + std::string Namespace; + CacheKey Key; + CbObjectView RecordObject; + CacheRecordPolicy Policy; + CacheRequestContext Context; +}; + +CacheRecordPolicy +LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default) +{ + OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object); + return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy); +} + +CacheRpcHandler::CacheRpcHandler(LoggerRef InLog, + CacheStats& InCacheStats, + UpstreamCacheClient& InUpstreamCache, + ZenCacheStore& InCacheStore, + CidStore& InCidStore, + const DiskWriteBlocker* InDiskWriteBlocker) +: m_Log(InLog) +, m_CacheStats(InCacheStats) +, m_UpstreamCache(InUpstreamCache) +, m_CacheStore(InCacheStore) +, m_CidStore(InCidStore) +, m_DiskWriteBlocker(InDiskWriteBlocker) +{ +} + +CacheRpcHandler::~CacheRpcHandler() +{ +} + +bool +CacheRpcHandler::AreDiskWritesAllowed() const +{ + return (m_DiskWriteBlocker == nullptr || m_DiskWriteBlocker->AreDiskWritesAllowed()); +} + +CacheRpcHandler::RpcResponseCode +CacheRpcHandler::HandleRpcRequest(const CacheRequestContext& Context, + const ZenContentType ContentType, + IoBuffer&& Body, + uint32_t& OutAcceptMagic, + RpcAcceptOptions& OutAcceptFlags, + int& OutTargetProcessId, + CbPackage& OutResultPackage) +{ + ZEN_TRACE_CPU("Z$::HandleRpcRequest"); + + m_CacheStats.RpcRequests.fetch_add(1); + + CbPackage Package; + CbObjectView Object; + CbObject ObjectBuffer; + if (ContentType == ZenContentType::kCbObject) + { + ObjectBuffer = LoadCompactBinaryObject(std::move(Body)); + Object = ObjectBuffer; + } + else + { + Package = ParsePackageMessage(Body); + Object = Package.GetObject(); + } + OutAcceptMagic = Object["Accept"sv].AsUInt32(); + OutAcceptFlags = static_cast<RpcAcceptOptions>(Object["AcceptFlags"sv].AsUInt16(0u)); + OutTargetProcessId = Object["Pid"sv].AsInt32(0); + + const std::string_view Method = Object["Method"sv].AsString(); + + if (Method == "PutCacheRecords"sv) + { + if (!AreDiskWritesAllowed()) + { + return RpcResponseCode::InsufficientStorage; + } + OutResultPackage = HandleRpcPutCacheRecords(Context, Package); + } + else if (Method == "GetCacheRecords"sv) + { + OutResultPackage = HandleRpcGetCacheRecords(Context, Object); + } + else if (Method == "PutCacheValues"sv) + { + if (!AreDiskWritesAllowed()) + { + return RpcResponseCode::InsufficientStorage; + } + OutResultPackage = HandleRpcPutCacheValues(Context, Package); + } + else if (Method == "GetCacheValues"sv) + { + OutResultPackage = HandleRpcGetCacheValues(Context, Object); + } + else if (Method == "GetCacheChunks"sv) + { + OutResultPackage = HandleRpcGetCacheChunks(Context, Object); + } + else + { + m_CacheStats.BadRequestCount++; + return RpcResponseCode::BadRequest; + } + return RpcResponseCode::OK; +} + +CbPackage +CacheRpcHandler::HandleRpcPutCacheRecords(const CacheRequestContext& Context, const CbPackage& BatchRequest) +{ + ZEN_TRACE_CPU("Z$::RpcPutCacheRecords"); + + CbObjectView BatchObject = BatchRequest.GetObject(); + ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv); + + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + CachePolicy DefaultPolicy; + + m_CacheStats.RpcRecordRequests.fetch_add(1); + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + + std::vector<bool> Results; + + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + for (CbFieldView RequestField : RequestsArray) + { + m_CacheStats.RpcRecordBatchRequests.fetch_add(1); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView(); + CbObjectView KeyView = RecordObject["Key"sv].AsObjectView(); + + CacheKey Key; + if (!GetRpcRequestCacheKey(KeyView, Key)) + { + return CbPackage{}; + } + CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + PutRequestData PutRequest{.Namespace = *Namespace, + .Key = std::move(Key), + .RecordObject = RecordObject, + .Policy = std::move(Policy), + .Context = Context}; + + PutResult Result = PutCacheRecord(PutRequest, &BatchRequest); + + if (Result == PutResult::Invalid) + { + return CbPackage{}; + } + Results.push_back(Result == PutResult::Success); + } + if (Results.empty()) + { + return CbPackage{}; + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + return RpcResponse; +} + +PutResult +CacheRpcHandler::PutCacheRecord(PutRequestData& Request, const CbPackage* Package) +{ + CbObjectView Record = Request.RecordObject; + uint64_t RecordObjectSize = Record.GetSize(); + uint64_t TransferredSize = RecordObjectSize; + + AttachmentCount Count; + size_t NumAttachments = Package->GetAttachments().size(); + std::vector<IoHash> ValidAttachments; + std::vector<IoHash> ReferencedAttachments; + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + ValidAttachments.reserve(NumAttachments); + AttachmentsToStoreLocally.reserve(NumAttachments); + + const bool HasUpstream = m_UpstreamCache.IsActive(); + + Stopwatch Timer; + + Request.RecordObject.IterateAttachments( + [this, &Request, Package, &AttachmentsToStoreLocally, &ValidAttachments, &ReferencedAttachments, &Count, &TransferredSize]( + CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + ReferencedAttachments.push_back(ValueHash); + if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr) + { + if (Attachment->IsCompressedBinary()) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + else + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed", + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ToString(ZenContentType::kCbPackage), + ValueHash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(ValueHash)) + { + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + Count.Total++; + }); + + if (Count.Invalid > 0) + { + return PutResult::Invalid; + } + + ZenCacheValue CacheValue; + CacheValue.Value = IoBuffer(Record.GetSize()); + Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize())); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Request.Context, Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue, ReferencedAttachments); + m_CacheStats.WriteCount++; + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } + TransferredSize += Chunk.GetCompressedSize(); + } + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {}, attachments '{}/{}/{}' (new/valid/total) in {}", + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + NiceBytes(TransferredSize), + Count.New, + Count.Valid, + Count.Total, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (HasUpstream && EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage, + .Namespace = Request.Namespace, + .Key = Request.Key, + .ValueContentIds = std::move(ValidAttachments)}); + } + return PutResult::Success; +} + +CbPackage +CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, CbObjectView RpcRequest) +{ + ZEN_TRACE_CPU("Z$::RpcGetCacheRecords"); + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + m_CacheStats.RpcRecordRequests.fetch_add(1); + + struct ValueRequestData + { + Oid ValueId; + IoHash ContentId; + CompressedBuffer Payload; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool ReadFromUpstream = false; + }; + struct RecordRequestData + { + CacheKeyRequest Upstream; + CbObjectView RecordObject; + IoBuffer RecordCacheValue; + CacheRecordPolicy DownstreamPolicy; + std::vector<ValueRequestData> Values; + bool Complete = false; + const UpstreamEndpointInfo* Source = nullptr; + uint64_t ElapsedTimeUs; + }; + + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + + const bool HasUpstream = m_UpstreamCache.IsActive(); + + std::vector<RecordRequestData> Requests; + std::vector<size_t> UpstreamIndexes; + + auto ParseValues = [](RecordRequestData& Request) { + CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView(); + Request.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Request.Values.push_back({ValueId, RawHash}); + Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId); + } + } + }; + + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + for (CbFieldView RequestField : RequestsArray) + { + ZEN_TRACE_CPU("Z$::RpcGetCacheRecords::Request"); + + m_CacheStats.RpcRecordBatchRequests.fetch_add(1); + + Stopwatch Timer; + RecordRequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + CacheKey& Key = Request.Upstream.Key; + if (!GetRpcRequestCacheKey(KeyObject, Key)) + { + return CbPackage{}; + } + + Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + const CacheRecordPolicy& Policy = Request.DownstreamPolicy; + + ZenCacheValue CacheValue; + bool NeedUpstreamAttachment = false; + bool FoundLocalInvalid = false; + ZenCacheValue RecordCacheValue; + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) && + m_CacheStore.Get(Context, *Namespace, Key.Bucket, Key.Hash, RecordCacheValue)) + { + Request.RecordCacheValue = std::move(RecordCacheValue.Value); + if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject) + { + FoundLocalInvalid = true; + } + else + { + Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData()); + ParseValues(Request); + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal)) + { + // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we + // didn't ask for it and thus the record is complete in its absence. + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Value.Exists = true; + } + else + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + Request.Complete = false; + } + } + else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + if (m_CidStore.ContainsChunk(Value.ContentId)) + { + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } + } + else + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId)) + { + if (Chunk.GetSize() > 0) + { + Value.Payload = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + Value.Exists = true; + continue; + } + else + { + ZEN_WARN("Skipping invalid chunk in local cache '{}'", Value.ContentId); + } + } + + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } + } + } + } + if (!Request.Complete) + { + bool NeedUpstreamRecord = HasUpstream && !Request.RecordObject && !FoundLocalInvalid && + EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote); + if (NeedUpstreamRecord || NeedUpstreamAttachment) + { + UpstreamIndexes.push_back(Requests.size() - 1); + } + } + Request.ElapsedTimeUs = Timer.GetElapsedTimeUs(); + } + if (Requests.empty()) + { + return CbPackage{}; + } + + if (!UpstreamIndexes.empty()) + { + std::vector<CacheKeyRequest*> UpstreamRequests; + UpstreamRequests.reserve(UpstreamIndexes.size()); + for (size_t Index : UpstreamIndexes) + { + RecordRequestData& Request = Requests[Index]; + UpstreamRequests.push_back(&Request.Upstream); + + if (Request.Values.size()) + { + // We will be returning the local object and know all the value Ids that exist in it + // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have. + CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta; + CacheRecordPolicyBuilder Builder(UpstreamBasePolicy); + for (ValueRequestData& Value : Request.Values) + { + CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy); + UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None; + Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy); + } + Request.Upstream.Policy = Builder.Build(); + } + else + { + // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants, + // and convert the CacheRecordPolicy to an upstream policy + Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream(); + } + } + + const auto OnCacheRecordGetComplete = [this, Namespace, &ParseValues, Context](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + + RecordRequestData& Request = + *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream)); + Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + const CacheKey& Key = Request.Upstream.Key; + Stopwatch Timer; + auto TimeGuard = MakeGuard([&Timer, &Request]() { Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); }); + if (!Request.RecordObject) + { + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject); + Request.RecordObject = ObjectBuffer; + bool StoreLocal = + EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + if (StoreLocal) + { + std::vector<IoHash> ReferencedAttachments; + ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + ReferencedAttachments.push_back(ValueHash); + }); + m_CacheStore + .Put(Context, *Namespace, Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}}, ReferencedAttachments); + m_CacheStats.WriteCount++; + } + ParseValues(Request); + Request.Source = Params.Source; + } + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + if (Value.Exists) + { + continue; + } + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Request.Complete = false; + continue; + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + bool StoreLocal = EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + Request.Source = Params.Source; + Value.Exists = true; + if (StoreLocal) + { + m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Value.Payload = Compressed; + } + } + else + { + ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}/{}'", + Value.ContentId, + *Namespace, + Key.Bucket, + Key.Hash); + } + } + if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Request.Complete = false; + } + // Request.Complete does not need to be set to false for upstream SkipData attachments. + // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment + // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of + // any missing SkipData attachments. + } + Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } + }; + + m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete)); + } + + { + ZEN_TRACE_CPU("Z$::RpcGetCacheRecords::Response"); + CbPackage ResponsePackage; + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"sv); + for (RecordRequestData& Request : Requests) + { + const CacheKey& Key = Request.Upstream.Key; + if (Request.Complete || + (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))) + { + ResponseObject << Request.RecordObject; + for (ValueRequestData& Value : Request.Values) + { + if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload) + { + ResponsePackage.AddAttachment(CbAttachment(Value.Payload, Value.ContentId)); + } + } + + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {}{} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(Request.RecordCacheValue.Size()), + Request.Complete ? ""sv : " (PARTIAL)"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount += Request.Source ? 1 : 0; + } + else + { + ResponseObject.AddNull(); + + if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHERECORD DISABLEDQUERY - '{}/{}/{}' in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + } + else + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}'{} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + Request.RecordObject ? ""sv : " (PARTIAL)"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.MissCount++; + } + } + } + ResponseObject.EndArray(); + ResponsePackage.SetObject(ResponseObject.Save()); + return ResponsePackage; + } +} + +CbPackage +CacheRpcHandler::HandleRpcPutCacheValues(const CacheRequestContext& Context, const CbPackage& BatchRequest) +{ + ZEN_TRACE_CPU("Z$::RpcPutCacheValues"); + CbObjectView BatchObject = BatchRequest.GetObject(); + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + + m_CacheStats.RpcValueRequests.fetch_add(1); + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + const bool HasUpstream = m_UpstreamCache.IsActive(); + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + + std::vector<bool> Results; + for (CbFieldView RequestField : RequestsArray) + { + ZEN_TRACE_CPU("Z$::RpcPutCacheValues::Request"); + + m_CacheStats.RpcValueBatchRequests.fetch_add(1); + + Stopwatch Timer; + + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyView = RequestObject["Key"sv].AsObjectView(); + + CacheKey Key; + if (!GetRpcRequestCacheKey(KeyView, Key)) + { + return CbPackage{}; + } + + PolicyText = RequestObject["Policy"sv].AsString(); + CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment(); + uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64(); + bool Succeeded = false; + uint64_t TransferredSize = 0; + + if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash)) + { + if (Attachment->IsCompressedBinary()) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + // TODO: Implement upstream puts of CacheValues with StoreLocal == false. + // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream. + Policy |= CachePolicy::StoreLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + if (RawSize == 0) + { + RawSize = Chunk.DecodeRawSize(); + } + m_CacheStore + .Put(Context, *Namespace, Key.Bucket, Key.Hash, {.Value = Value, .RawSize = RawSize, .RawHash = RawHash}, {}); + m_CacheStats.WriteCount++; + TransferredSize = Chunk.GetCompressedSize(); + } + Succeeded = true; + } + else + { + ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}/{}' FAILED, value is not compressed", *Namespace, Key.Bucket, Key.Hash, RawHash); + return CbPackage{}; + } + } + else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + ZenCacheValue ExistingValue; + if (m_CacheStore.Get(Context, *Namespace, Key.Bucket, Key.Hash, ExistingValue) && + IsCompressedBinary(ExistingValue.Value.GetContentType())) + { + Succeeded = true; + } + } + // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put. + // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream. + + if (HasUpstream && Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCompressedBinary, .Namespace = *Namespace, .Key = Key}); + } + Results.push_back(Succeeded); + ZEN_DEBUG("PUTCACHEVALUES - '{}/{}/{}' {}, '{}' in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(TransferredSize), + Succeeded ? "Added"sv : "Invalid", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + } + if (Results.empty()) + { + return CbPackage{}; + } + + { + ZEN_TRACE_CPU("Z$::RpcPutCacheValues::Response"); + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + + return RpcResponse; + } +} + +CbPackage +CacheRpcHandler::HandleRpcGetCacheValues(const CacheRequestContext& Context, CbObjectView RpcRequest) +{ + ZEN_TRACE_CPU("Z$::RpcGetCacheValues"); + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + m_CacheStats.RpcValueRequests.fetch_add(1); + + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + + struct RequestData + { + CacheKey Key; + CachePolicy Policy; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + CompressedBuffer Result; + }; + std::vector<RequestData> Requests; + + std::vector<size_t> RemoteRequestIndexes; + + const bool HasUpstream = m_UpstreamCache.IsActive(); + + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + for (CbFieldView RequestField : RequestsArray) + { + ZEN_TRACE_CPU("Z$::RpcGetCacheValues::Request"); + + m_CacheStats.RpcValueBatchRequests.fetch_add(1); + + Stopwatch Timer; + + RequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + if (!GetRpcRequestCacheKey(KeyObject, Request.Key)) + { + return CbPackage{}; + } + + PolicyText = RequestObject["Policy"sv].AsString(); + Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + + CacheKey& Key = Request.Key; + CachePolicy Policy = Request.Policy; + + ZenCacheValue CacheValue; + if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + if (m_CacheStore.Get(Context, *Namespace, Key.Bucket, Key.Hash, CacheValue) && + IsCompressedBinary(CacheValue.Value.GetContentType())) + { + Request.RawHash = CacheValue.RawHash; + Request.RawSize = CacheValue.RawSize; + Request.Result = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value)); + } + } + if (Request.Result) + { + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(Request.Result.GetCompressed().GetSize()), + "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.HitCount++; + } + else if (HasUpstream && EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + RemoteRequestIndexes.push_back(Requests.size() - 1); + } + else if (!EnumHasAnyFlags(Policy, CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}/{}'", *Namespace, Key.Bucket, Key.Hash); + } + else + { + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + } + } + + if (!RemoteRequestIndexes.empty()) + { + std::vector<CacheValueRequest> RequestedRecordsData; + std::vector<CacheValueRequest*> CacheValueRequests; + RequestedRecordsData.reserve(RemoteRequestIndexes.size()); + CacheValueRequests.reserve(RemoteRequestIndexes.size()); + for (size_t Index : RemoteRequestIndexes) + { + RequestData& Request = Requests[Index]; + RequestedRecordsData.push_back({.Key = {Request.Key.Bucket, Request.Key.Hash}, .Policy = ConvertToUpstream(Request.Policy)}); + CacheValueRequests.push_back(&RequestedRecordsData.back()); + } + Stopwatch Timer; + m_UpstreamCache.GetCacheValues( + *Namespace, + CacheValueRequests, + [this, Namespace, &RequestedRecordsData, &Requests, &RemoteRequestIndexes, &Timer, Context]( + CacheValueGetCompleteParams&& Params) { + CacheValueRequest& ChunkRequest = Params.Request; + if (Params.RawHash != IoHash::Zero) + { + size_t RequestOffset = std::distance(RequestedRecordsData.data(), &ChunkRequest); + size_t RequestIndex = RemoteRequestIndexes[RequestOffset]; + RequestData& Request = Requests[RequestIndex]; + Request.RawHash = Params.RawHash; + Request.RawSize = Params.RawSize; + const bool HasData = IsCompressedBinary(Params.Value.GetContentType()); + const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData); + const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + const bool IsHit = SkipData || HasData; + if (IsHit) + { + if (HasData && !SkipData) + { + Request.Result = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value)); + } + + if (HasData && StoreData) + { + m_CacheStore.Put(Context, + *Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash}, + {}); + m_CacheStats.WriteCount++; + } + + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", + *Namespace, + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + NiceBytes(Request.Result.GetCompressed().GetSize()), + Params.Source ? Params.Source->Url : "UPSTREAM", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + return; + } + } + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}", + *Namespace, + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + Params.Source ? Params.Source->Url : "UPSTREAM", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + }); + } + + if (Requests.empty()) + { + return CbPackage{}; + } + + { + ZEN_TRACE_CPU("Z$::RpcGetCacheValues::Response"); + CbPackage RpcResponse; + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (const RequestData& Request : Requests) + { + ResponseObject.BeginObject(); + { + const CompressedBuffer& Result = Request.Result; + if (Result) + { + ResponseObject.AddHash("RawHash"sv, Request.RawHash); + if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Result, Request.RawHash)); + } + else + { + ResponseObject.AddInteger("RawSize"sv, Request.RawSize); + } + } + else if (Request.RawHash != IoHash::Zero) + { + ResponseObject.AddHash("RawHash"sv, Request.RawHash); + ResponseObject.AddInteger("RawSize"sv, Request.RawSize); + } + } + ResponseObject.EndObject(); + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + return RpcResponse; + } +} + +CbPackage +CacheRpcHandler::HandleRpcGetCacheChunks(const CacheRequestContext& Context, CbObjectView RpcRequest) +{ + ZEN_TRACE_CPU("Z$::RpcGetCacheChunks"); + using namespace cache::detail; + + std::string Namespace; + std::vector<CacheKeyRequest> RecordKeys; // Data about a Record necessary to identify it to the upstream + std::vector<RecordBody> Records; // Scratch-space data about a Record when fulfilling RecordRequests + std::vector<CacheChunkRequest> RequestKeys; // Data about a ChunkRequest necessary to identify it to the upstream + std::vector<ChunkRequest> Requests; // Intermediate and result data about a ChunkRequest + std::vector<ChunkRequest*> RecordRequests; // The ChunkRequests that are requesting a subvalue from a Record Key + std::vector<ChunkRequest*> ValueRequests; // The ChunkRequests that are requesting a Value Key + std::vector<CacheChunkRequest*> UpstreamChunks; // ChunkRequests that we need to send to the upstream + + // Parse requests from the CompactBinary body of the RpcRequest and divide it into RecordRequests and ValueRequests + if (!ParseGetCacheChunksRequest(Namespace, RecordKeys, Records, RequestKeys, Requests, RecordRequests, ValueRequests, RpcRequest)) + { + return CbPackage{}; + } + + // For each Record request, load the Record if necessary to find the Chunk's ContentId, load its Payloads if we + // have it locally, and otherwise append a request for the payload to UpstreamChunks + GetLocalCacheRecords(Context, Namespace, RecordKeys, Records, RecordRequests, UpstreamChunks); + + // For each Value request, load the Value if we have it locally and otherwise append a request for the payload to UpstreamChunks + GetLocalCacheValues(Context, Namespace, ValueRequests, UpstreamChunks); + + // Call GetCacheChunks on the upstream for any payloads we do not have locally + GetUpstreamCacheChunks(Context, Namespace, UpstreamChunks, RequestKeys, Requests); + + // Send the payload and descriptive data about each chunk to the client + return WriteGetCacheChunksResponse(Context, Namespace, Requests); +} + +bool +CacheRpcHandler::ParseGetCacheChunksRequest(std::string& Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + CbObjectView RpcRequest) +{ + ZEN_TRACE_CPU("Z$::ParseGetCacheChunksRequest"); + + using namespace cache::detail; + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + m_CacheStats.RpcChunkRequests.fetch_add(1); + + std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default; + + std::optional<std::string> NamespaceText = GetRpcRequestNamespace(Params); + if (!NamespaceText) + { + ZEN_WARN("GetCacheChunks: Invalid namespace in ChunkRequest."); + return false; + } + Namespace = *NamespaceText; + + CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView(); + size_t NumRequests = static_cast<size_t>(ChunkRequestsArray.Num()); + // Note that these reservations allow us to take pointers to the elements while populating them. If the reservation is removed, + // we will need to change the pointers to indexes to handle reallocations. + RecordKeys.reserve(NumRequests); + Records.reserve(NumRequests); + RequestKeys.reserve(NumRequests); + Requests.reserve(NumRequests); + RecordRequests.reserve(NumRequests); + ValueRequests.reserve(NumRequests); + + CacheKeyRequest* PreviousRecordKey = nullptr; + RecordBody* PreviousRecord = nullptr; + + for (CbFieldView RequestView : ChunkRequestsArray) + { + ZEN_TRACE_CPU("Z$::ParseGetCacheChunksRequest::Request"); + + m_CacheStats.RpcChunkBatchRequests.fetch_add(1); + + CbObjectView RequestObject = RequestView.AsObjectView(); + CacheChunkRequest& RequestKey = RequestKeys.emplace_back(); + ChunkRequest& Request = Requests.emplace_back(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + Request.Key = &RequestKey; + if (!GetRpcRequestCacheKey(KeyObject, Request.Key->Key)) + { + ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest."); + return false; + } + + RequestKey.ChunkId = RequestObject["ChunkId"sv].AsHash(); + RequestKey.ValueId = RequestObject["ValueId"sv].AsObjectId(); + RequestKey.RawOffset = RequestObject["RawOffset"sv].AsUInt64(); + RequestKey.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX); + Request.RequestedSize = RequestKey.RawSize; + Request.RequestedOffset = RequestKey.RawOffset; + std::string_view PolicyText = RequestObject["Policy"sv].AsString(); + Request.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + Request.IsRecordRequest = (bool)RequestKey.ValueId; + + if (!Request.IsRecordRequest) + { + ValueRequests.push_back(&Request); + } + else + { + RecordRequests.push_back(&Request); + CacheKeyRequest* RecordKey = nullptr; + RecordBody* Record = nullptr; + + if (!PreviousRecordKey || PreviousRecordKey->Key < RequestKey.Key) + { + RecordKey = &RecordKeys.emplace_back(); + PreviousRecordKey = RecordKey; + Record = &Records.emplace_back(); + PreviousRecord = Record; + RecordKey->Key = RequestKey.Key; + } + else if (RequestKey.Key == PreviousRecordKey->Key) + { + RecordKey = PreviousRecordKey; + Record = PreviousRecord; + } + else + { + ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.", + RequestKey.Key.Bucket, + RequestKey.Key.Hash, + PreviousRecordKey->Key.Bucket, + PreviousRecordKey->Key.Hash); + return false; + } + Request.Record = Record; + if (RequestKey.ChunkId == RequestKey.ChunkId.Zero) + { + Record->DownstreamPolicy = + Record->HasRequest ? Union(Record->DownstreamPolicy, Request.DownstreamPolicy) : Request.DownstreamPolicy; + Record->HasRequest = true; + } + } + } + if (Requests.empty()) + { + return false; + } + return true; +} + +void +CacheRpcHandler::GetLocalCacheRecords(const CacheRequestContext& Context, + std::string_view Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks) +{ + ZEN_TRACE_CPU("Z$::GetLocalCacheRecords"); + + using namespace cache::detail; + const bool HasUpstream = m_UpstreamCache.IsActive(); + + std::vector<CacheKeyRequest*> UpstreamRecordRequests; + for (size_t RecordIndex = 0; RecordIndex < Records.size(); ++RecordIndex) + { + ZEN_TRACE_CPU("Z$::GetLocalCacheRecords::Record"); + + Stopwatch Timer; + CacheKeyRequest& RecordKey = RecordKeys[RecordIndex]; + RecordBody& Record = Records[RecordIndex]; + if (Record.HasRequest) + { + Record.DownstreamPolicy |= CachePolicy::SkipData | CachePolicy::SkipMeta; + + if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Context, Namespace, RecordKey.Key.Bucket, RecordKey.Key.Hash, CacheValue)) + { + Record.Exists = true; + Record.CacheValue = std::move(CacheValue.Value); + } + } + if (HasUpstream && !Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryRemote)) + { + RecordKey.Policy = CacheRecordPolicy(ConvertToUpstream(Record.DownstreamPolicy)); + UpstreamRecordRequests.push_back(&RecordKey); + } + RecordRequests[RecordIndex]->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } + } + + if (!UpstreamRecordRequests.empty()) + { + const auto OnCacheRecordGetComplete = + [this, Namespace, &RecordKeys, &Records, &RecordRequests, Context](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + CacheKeyRequest& RecordKey = Params.Request; + size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey); + RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + RecordBody& Record = Records[RecordIndex]; + + const CacheKey& Key = RecordKey.Key; + Record.Exists = true; + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Record.CacheValue.SetContentType(ZenContentType::kCbObject); + Record.Source = Params.Source; + + bool StoreLocal = EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + if (StoreLocal) + { + std::vector<IoHash> ReferencedAttachments; + ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + ReferencedAttachments.push_back(ValueHash); + }); + m_CacheStore.Put(Context, Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue}, ReferencedAttachments); + m_CacheStats.WriteCount++; + } + }; + m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete)); + } + + for (ChunkRequest* Request : RecordRequests) + { + ZEN_TRACE_CPU("Z$::GetLocalCacheRecords::Chunk"); + + Stopwatch Timer; + if (Request->Key->ChunkId == IoHash::Zero) + { + // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId) + // is missing, parse the cache record and try to find the raw hash from the ValueId. + RecordBody& Record = *Request->Record; + if (!Record.ValuesRead) + { + Record.ValuesRead = true; + if (Record.CacheValue && Record.CacheValue.GetContentType() == ZenContentType::kCbObject) + { + CbObjectView RecordObject = CbObjectView(Record.CacheValue.GetData()); + CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView(); + Record.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Record.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()}); + } + } + } + } + + for (const RecordValue& Value : Record.Values) + { + if (Value.ValueId == Request->Key->ValueId) + { + Request->Key->ChunkId = Value.ContentId; + Request->RawSize = Value.RawSize; + Request->RawSizeKnown = true; + break; + } + } + } + + // Now load the ContentId from the local ContentIdStore or from the upstream + if (Request->Key->ChunkId != IoHash::Zero) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal)) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData) && Request->RawSizeKnown) + { + if (m_CidStore.ContainsChunk(Request->Key->ChunkId)) + { + Request->Exists = true; + } + } + else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Request->Key->ChunkId)) + { + if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData)) + { + Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(Payload)); + if (Request->Value) + { + Request->Exists = true; + Request->RawSizeKnown = false; + } + } + else + { + IoHash _; + if (CompressedBuffer::ValidateCompressedHeader(Payload, _, Request->RawSize)) + { + Request->Exists = true; + Request->RawSizeKnown = true; + } + } + } + } + if (HasUpstream && !Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote)) + { + Request->Key->Policy = ConvertToUpstream(Request->DownstreamPolicy); + OutUpstreamChunks.push_back(Request->Key); + } + } + Request->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } +} + +void +CacheRpcHandler::GetLocalCacheValues(const CacheRequestContext& Context, + std::string_view Namespace, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks) +{ + ZEN_TRACE_CPU("Z$::GetLocalCacheValues"); + + using namespace cache::detail; + const bool HasUpstream = m_UpstreamCache.IsActive(); + + for (ChunkRequest* Request : ValueRequests) + { + ZEN_TRACE_CPU("Z$::GetLocalCacheValues::Value"); + + Stopwatch Timer; + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Context, Namespace, Request->Key->Key.Bucket, Request->Key->Key.Hash, CacheValue)) + { + if (IsCompressedBinary(CacheValue.Value.GetContentType())) + { + Request->Key->ChunkId = CacheValue.RawHash; + Request->Exists = true; + Request->RawSize = CacheValue.RawSize; + Request->RawSizeKnown = true; + if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData)) + { + Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value)); + } + } + } + } + if (HasUpstream && !Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote)) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::StoreLocal)) + { + // Convert the Offset,Size request into a request for the entire value; we will need it all to be able to store it locally + Request->Key->RawOffset = 0; + Request->Key->RawSize = UINT64_MAX; + } + OutUpstreamChunks.push_back(Request->Key); + } + Request->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } +} + +void +CacheRpcHandler::GetUpstreamCacheChunks(const CacheRequestContext& Context, + std::string_view Namespace, + std::vector<CacheChunkRequest*>& UpstreamChunks, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests) +{ + if (UpstreamChunks.empty()) + { + return; + } + ZEN_TRACE_CPU("Z$::GetUpstreamCacheChunks"); + + using namespace cache::detail; + + const auto OnCacheChunksGetComplete = [this, Namespace, &RequestKeys, &Requests, Context](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash == Params.RawHash.Zero) + { + return; + } + + CacheChunkRequest& Key = Params.Request; + size_t RequestIndex = std::distance(RequestKeys.data(), &Key); + ChunkRequest& Request = Requests[RequestIndex]; + Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) || + !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value)); + if (!Compressed) + { + return; + } + + bool StoreLocal = EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + if (StoreLocal) + { + if (Request.IsRecordRequest) + { + m_CidStore.AddChunk(Params.Value, Params.RawHash); + } + else + { + m_CacheStore.Put(Context, + Namespace, + Key.Key.Bucket, + Key.Key.Hash, + {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash}, + {}); + m_CacheStats.WriteCount++; + } + } + if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + Request.Value = std::move(Compressed); + } + } + Key.ChunkId = Params.RawHash; + Request.Exists = true; + Request.RawSize = Params.RawSize; + Request.RawSizeKnown = true; + Request.Source = Params.Source; + + m_CacheStats.UpstreamHitCount++; + }; + + m_UpstreamCache.GetCacheChunks(Namespace, UpstreamChunks, std::move(OnCacheChunksGetComplete)); +} + +CbPackage +CacheRpcHandler::WriteGetCacheChunksResponse([[maybe_unused]] const CacheRequestContext& Context, + std::string_view Namespace, + std::vector<cache::detail::ChunkRequest>& Requests) +{ + ZEN_TRACE_CPU("Z$::WriteGetCacheChunksResponse"); + + using namespace cache::detail; + + CbPackage RpcResponse; + CbObjectWriter Writer; + + Writer.BeginArray("Result"sv); + for (ChunkRequest& Request : Requests) + { + ZEN_TRACE_CPU("Z$::WriteGetCacheChunksResponse::Request"); + + Writer.BeginObject(); + { + if (Request.Exists) + { + Writer.AddHash("RawHash"sv, Request.Key->ChunkId); + if (Request.Value && !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Request.Value, Request.Key->ChunkId)); + } + else + { + Writer.AddInteger("RawSize"sv, Request.RawSize); + } + + ZEN_DEBUG("GETCACHECHUNKS HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceBytes(Request.RawSize), + Request.IsRecordRequest ? "Record"sv : "Value"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.HitCount++; + } + else if (!EnumHasAnyFlags(Request.DownstreamPolicy, CachePolicy::Query)) + { + ZEN_DEBUG("GETCACHECHUNKS DISABLEDQUERY - '{}/{}/{}/{}' in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + } + else + { + ZEN_DEBUG("GETCACHECHUNKS MISS - '{}/{}/{}/{}' in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.MissCount++; + } + } + Writer.EndObject(); + } + Writer.EndArray(); + + RpcResponse.SetObject(Writer.Save()); + return RpcResponse; +} + +} // namespace zen
\ No newline at end of file |