diff options
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
| -rw-r--r-- | zenserver/cache/structuredcache.cpp | 212 |
1 files changed, 212 insertions, 0 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index 499329e94..bb4b67797 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -166,6 +166,7 @@ HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server) { Server.RegisterRequestHandler("GetBinaryCacheValue"sv, *this); Server.RegisterRequestHandler("GetCacheValues"sv, *this); + Server.RegisterRequestHandler("GetCacheRecords"sv, *this); } bool @@ -314,6 +315,217 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage return true; } + if (Method == "GetCacheRecords"sv) + { + ZEN_TRACE_CPU("Z$::WS_GetCacheRecords"); + + const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); + std::vector<CacheKeyRequest> UpstreamRequests; + + CbArrayView RequestsView = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = RequestsView.Num(); + + for (int32_t Idx = 0; CbFieldView RequestField : RequestsView) + { + const int32_t RequestIndex = Idx++; + + CbObjectView RequestObject = RequestField.AsObjectView(); + + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); + CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) + { + ZenCacheValue RecordCacheValue; + if (m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue)) + { + CbObjectView RecordObject = CbObjectView(RecordCacheValue.Value.GetData()); + CbArrayView RecordValuesView = RecordObject["Values"sv].AsArrayView(); + uint64_t TotalSize = RecordCacheValue.Value.GetSize(); + + std::vector<IoBuffer> Values; + Values.reserve(RecordValuesView.Num()); + + for (CbFieldView ValueField : RecordValuesView) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + IoHash RawHash = ValueObject["RawHash"sv].AsHash(); + CachePolicy ValuePolicy = Policy.GetValuePolicy(ValueId); + + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal)) + { + if (IoBuffer Value = m_CidStore.FindChunkByCid(RawHash)) + { + Values.push_back(Value); + TotalSize += Value.GetSize(); + } + } + } + + const bool IsComplete = Values.size() == RecordValuesView.Num(); + + if (IsComplete) + { + CbPackage Response; + CbObjectWriter ResponseObject; + + ResponseObject.BeginObject("Result"sv); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject)); + ResponseObject.EndObject(); + + Response.SetObject(ResponseObject.Save()); + + for (const IoBuffer& Value : Values) + { + Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value)))); + } + + SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", + Key.Bucket, + Key.Hash, + NiceBytes(TotalSize), + ToString(ZenContentType::kCbObject)); + + continue; + } + } + } + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false) + { + CbPackage Response; + CbObjectWriter ResponseObject; + + ResponseObject.BeginObject("Error"sv); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.AddString("Reason"sv, "Not Found"sv); + ResponseObject.EndObject(); + + Response.SetObject(ResponseObject.Save()); + + SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + + ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash); + + continue; + } + + UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)}); + } + + if (UpstreamRequests.empty()) + { + SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId()); + + return true; + } + + // TODO: Fix this + std::vector<CacheKeyRequest*> RequestPtrs; + RequestPtrs.reserve(UpstreamRequests.size()); + + for (CacheKeyRequest& Req : UpstreamRequests) + { + RequestPtrs.push_back(&Req); + } + + auto OnCacheRecordGetComplete = [this, &RequestMessage](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + CbPackage Response; + CbObjectWriter ResponseObject; + + CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView(); + uint32_t AttachmentCount{}; + uint64_t TotalSize = Params.Record.GetSize(); + + for (CbFieldView ValueField : RecordValuesView) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + IoHash RawHash = ValueObject["RawHash"sv].AsHash(); + CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId); + + if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + Response.AddAttachment(CbAttachment(Compressed)); + AttachmentCount++; + TotalSize += Compressed.GetCompressedSize(); + + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value}); + } + } + } + else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash)) + { + Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); + AttachmentCount++; + } + } + } + + const bool IsComplete = AttachmentCount == RecordValuesView.Num(); + const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord); + + if (IsComplete || AllowPartial) + { + ResponseObject.BeginObject("Result"sv); + ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); + ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record)); + ResponseObject.EndObject(); + + Response.SetObject(ResponseObject.Save()); + + SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)", + Params.Request.Key.Bucket, + Params.Request.Key.Hash, + NiceBytes(TotalSize), + ToString(ZenContentType::kCbObject)); + + return; + } + } + + CbPackage Response; + CbObjectWriter ResponseObject; + + ResponseObject.BeginObject("Error"sv); + ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); + ResponseObject.AddString("Reason"sv, "Not Found"sv); + ResponseObject.EndObject(); + + Response.SetObject(ResponseObject.Save()); + + SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + + ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash); + }; + + m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete)); + + SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId()); + + return true; + } + return false; } |