aboutsummaryrefslogtreecommitdiff
path: root/zenserver/cache/structuredcache.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
-rw-r--r--zenserver/cache/structuredcache.cpp212
1 files changed, 212 insertions, 0 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp
index 499329e94..bb4b67797 100644
--- a/zenserver/cache/structuredcache.cpp
+++ b/zenserver/cache/structuredcache.cpp
@@ -166,6 +166,7 @@ HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server)
{
Server.RegisterRequestHandler("GetBinaryCacheValue"sv, *this);
Server.RegisterRequestHandler("GetCacheValues"sv, *this);
+ Server.RegisterRequestHandler("GetCacheRecords"sv, *this);
}
bool
@@ -314,6 +315,217 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
return true;
}
+ if (Method == "GetCacheRecords"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetCacheRecords");
+
+ const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText);
+ std::vector<CacheKeyRequest> UpstreamRequests;
+
+ CbArrayView RequestsView = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = RequestsView.Num();
+
+ for (int32_t Idx = 0; CbFieldView RequestField : RequestsView)
+ {
+ const int32_t RequestIndex = Idx++;
+
+ CbObjectView RequestObject = RequestField.AsObjectView();
+
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+ CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal))
+ {
+ ZenCacheValue RecordCacheValue;
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ CbObjectView RecordObject = CbObjectView(RecordCacheValue.Value.GetData());
+ CbArrayView RecordValuesView = RecordObject["Values"sv].AsArrayView();
+ uint64_t TotalSize = RecordCacheValue.Value.GetSize();
+
+ std::vector<IoBuffer> Values;
+ Values.reserve(RecordValuesView.Num());
+
+ for (CbFieldView ValueField : RecordValuesView)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ IoHash RawHash = ValueObject["RawHash"sv].AsHash();
+ CachePolicy ValuePolicy = Policy.GetValuePolicy(ValueId);
+
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal))
+ {
+ if (IoBuffer Value = m_CidStore.FindChunkByCid(RawHash))
+ {
+ Values.push_back(Value);
+ TotalSize += Value.GetSize();
+ }
+ }
+ }
+
+ const bool IsComplete = Values.size() == RecordValuesView.Num();
+
+ if (IsComplete)
+ {
+ CbPackage Response;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginObject("Result"sv);
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject));
+ ResponseObject.EndObject();
+
+ Response.SetObject(ResponseObject.Save());
+
+ for (const IoBuffer& Value : Values)
+ {
+ Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value))));
+ }
+
+ SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response));
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)",
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(TotalSize),
+ ToString(ZenContentType::kCbObject));
+
+ continue;
+ }
+ }
+ }
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false)
+ {
+ CbPackage Response;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginObject("Error"sv);
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.AddString("Reason"sv, "Not Found"sv);
+ ResponseObject.EndObject();
+
+ Response.SetObject(ResponseObject.Save());
+
+ SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response));
+
+ ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash);
+
+ continue;
+ }
+
+ UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)});
+ }
+
+ if (UpstreamRequests.empty())
+ {
+ SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId());
+
+ return true;
+ }
+
+ // TODO: Fix this
+ std::vector<CacheKeyRequest*> RequestPtrs;
+ RequestPtrs.reserve(UpstreamRequests.size());
+
+ for (CacheKeyRequest& Req : UpstreamRequests)
+ {
+ RequestPtrs.push_back(&Req);
+ }
+
+ auto OnCacheRecordGetComplete = [this, &RequestMessage](CacheRecordGetCompleteParams&& Params) {
+ if (Params.Record)
+ {
+ CbPackage Response;
+ CbObjectWriter ResponseObject;
+
+ CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView();
+ uint32_t AttachmentCount{};
+ uint64_t TotalSize = Params.Record.GetSize();
+
+ for (CbFieldView ValueField : RecordValuesView)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ IoHash RawHash = ValueObject["RawHash"sv].AsHash();
+ CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId);
+
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash))
+ {
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
+ {
+ Response.AddAttachment(CbAttachment(Compressed));
+ AttachmentCount++;
+ TotalSize += Compressed.GetCompressedSize();
+
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Value.SetContentType(ZenContentType::kCompressedBinary);
+ m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value});
+ }
+ }
+ }
+ else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal))
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash))
+ {
+ Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk))));
+ AttachmentCount++;
+ }
+ }
+ }
+
+ const bool IsComplete = AttachmentCount == RecordValuesView.Num();
+ const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord);
+
+ if (IsComplete || AllowPartial)
+ {
+ ResponseObject.BeginObject("Result"sv);
+ ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData));
+ ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record));
+ ResponseObject.EndObject();
+
+ Response.SetObject(ResponseObject.Save());
+
+ SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response));
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)",
+ Params.Request.Key.Bucket,
+ Params.Request.Key.Hash,
+ NiceBytes(TotalSize),
+ ToString(ZenContentType::kCbObject));
+
+ return;
+ }
+ }
+
+ CbPackage Response;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginObject("Error"sv);
+ ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData));
+ ResponseObject.AddString("Reason"sv, "Not Found"sv);
+ ResponseObject.EndObject();
+
+ Response.SetObject(ResponseObject.Save());
+
+ SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response));
+
+ ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash);
+ };
+
+ m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete));
+
+ SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId());
+
+ return true;
+ }
+
return false;
}