diff options
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
| -rw-r--r-- | zenserver/cache/structuredcache.cpp | 199 |
1 files changed, 113 insertions, 86 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index d469e3c68..e3204ac9d 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -236,8 +236,14 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); - for (uint32_t RequestIdx = 0; CbFieldView RequestField : Params["Requests"sv]) + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + uint64_t ResponseCount = 0; + + for (int32_t Idx = 0; CbFieldView RequestField : Params["Requests"sv]) { + const int32_t RequestIndex = Idx++; + CbObjectView RequestObject = RequestField.AsObjectView(); CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); @@ -277,7 +283,7 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage CbObjectWriter ResponseObject; ResponseObject.BeginObject("Result"sv); - ResponseObject.AddInteger("RequestIndex"sv, RequestIdx++); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash()); const uint64_t RawSize = Compressed.GetRawSize(); @@ -299,7 +305,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage ResponseObject.EndObject(); Response.SetObject(ResponseObject.Save()); - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + SendStreamResponse(RequestMessage.SocketId(), + RequestMessage.CorrelationId(), + std::move(Response), + ++ResponseCount == RequestCount); if (RawSize > 0) { @@ -311,7 +320,7 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage } } - SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId()); + ZEN_ASSERT(ResponseCount == RequestCount); return true; } @@ -324,9 +333,11 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); std::vector<CacheKeyRequest> UpstreamRequests; - CbArrayView RequestsView = Params["Requests"sv].AsArrayView(); + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + uint64_t ResponseCount = 0; - for (int32_t Idx = 0; CbFieldView RequestField : RequestsView) + for (int32_t Idx = 0; CbFieldView RequestField : Requests) { const int32_t RequestIndex = Idx++; @@ -385,7 +396,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value)))); } - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + SendStreamResponse(RequestMessage.SocketId(), + RequestMessage.CorrelationId(), + std::move(Response), + ++ResponseCount == RequestCount); ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, @@ -410,7 +424,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage Response.SetObject(ResponseObject.Save()); - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + SendStreamResponse(RequestMessage.SocketId(), + RequestMessage.CorrelationId(), + std::move(Response), + ++ResponseCount == RequestCount); ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash); @@ -422,106 +439,111 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage if (UpstreamRequests.empty()) { - SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId()); - return true; } - // TODO: Fix this - std::vector<CacheKeyRequest*> RequestPtrs; - RequestPtrs.reserve(UpstreamRequests.size()); - - for (CacheKeyRequest& Req : UpstreamRequests) - { - RequestPtrs.push_back(&Req); - } - - auto OnCacheRecordGetComplete = [this, &RequestMessage](CacheRecordGetCompleteParams&& Params) { - if (Params.Record) - { - CbPackage Response; - CbObjectWriter ResponseObject; - - CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView(); - uint32_t AttachmentCount{}; - uint64_t TotalSize = Params.Record.GetSize(); - - for (CbFieldView ValueField : RecordValuesView) + auto OnCacheRecordGetComplete = + [this, RequestCount, &ResponseCount, &RequestMessage](CacheRecordGetCompleteParams&& Params) mutable { + if (Params.Record) { - CbObjectView ValueObject = ValueField.AsObjectView(); + CbPackage Response; + CbObjectWriter ResponseObject; - Oid ValueId = ValueObject["Id"sv].AsObjectId(); - IoHash RawHash = ValueObject["RawHash"sv].AsHash(); - CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId); + CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView(); + uint32_t AttachmentCount{}; + uint64_t TotalSize = Params.Record.GetSize(); - if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash)) + for (CbFieldView ValueField : RecordValuesView) { - if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) - { - Response.AddAttachment(CbAttachment(Compressed)); - AttachmentCount++; - TotalSize += Compressed.GetCompressedSize(); + CbObjectView ValueObject = ValueField.AsObjectView(); - if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + IoHash RawHash = ValueObject["RawHash"sv].AsHash(); + CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId); + + if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) { - IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer(); - Value.SetContentType(ZenContentType::kCompressedBinary); - m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value}); + Response.AddAttachment(CbAttachment(Compressed)); + AttachmentCount++; + TotalSize += Compressed.GetCompressedSize(); + + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value}); + } } } - } - else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) - { - if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash)) + else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) { - Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); - AttachmentCount++; + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash)) + { + Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); + AttachmentCount++; + } } } - } - const bool IsComplete = AttachmentCount == RecordValuesView.Num(); - const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord); + const bool IsComplete = AttachmentCount == RecordValuesView.Num(); + const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord); - if (IsComplete || AllowPartial) - { - ResponseObject.BeginObject("Result"sv); - ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); - ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record)); - ResponseObject.EndObject(); + if (IsComplete || AllowPartial) + { + ResponseObject.BeginObject("Result"sv); + ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); + ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record)); + ResponseObject.EndObject(); - Response.SetObject(ResponseObject.Save()); + Response.SetObject(ResponseObject.Save()); - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + SendStreamResponse(RequestMessage.SocketId(), + RequestMessage.CorrelationId(), + std::move(Response), + ++ResponseCount == RequestCount); - ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)", - Params.Request.Key.Bucket, - Params.Request.Key.Hash, - NiceBytes(TotalSize), - ToString(ZenContentType::kCbObject)); + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)", + Params.Request.Key.Bucket, + Params.Request.Key.Hash, + NiceBytes(TotalSize), + ToString(ZenContentType::kCbObject)); - return; + return; + } } - } - CbPackage Response; - CbObjectWriter ResponseObject; + CbPackage Response; + CbObjectWriter ResponseObject; - ResponseObject.BeginObject("Error"sv); - ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); - ResponseObject.AddString("Reason"sv, "Not Found"sv); - ResponseObject.EndObject(); + ResponseObject.BeginObject("Error"sv); + ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); + ResponseObject.AddString("Reason"sv, "Not Found"sv); + ResponseObject.EndObject(); - Response.SetObject(ResponseObject.Save()); + Response.SetObject(ResponseObject.Save()); - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + SendStreamResponse(RequestMessage.SocketId(), + RequestMessage.CorrelationId(), + std::move(Response), + ++ResponseCount == RequestCount); - ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash); - }; + ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash); + }; + + // TODO: Fix this + std::vector<CacheKeyRequest*> RequestPtrs; + RequestPtrs.reserve(UpstreamRequests.size()); + + for (CacheKeyRequest& Req : UpstreamRequests) + { + RequestPtrs.push_back(&Req); + } m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete)); - SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId()); + ZEN_ASSERT(ResponseCount == RequestCount); return true; } @@ -550,12 +572,14 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage IoBuffer CurrentRecordValue; CompressedBuffer Compressed; - CbArrayView RequestsView = Params["Requests"sv].AsArrayView(); + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + uint64_t ResponseCount = 0; - for (int32_t Idx = 0; CbFieldView RequestField : RequestsView) + for (int32_t Idx = 0; CbFieldView RequestField : Requests) { - CbObjectView RequestObject = RequestField.AsObjectView(); - const int32_t RequestIndex = Idx++; + CbObjectView RequestObject = RequestField.AsObjectView(); + const int32_t RequestIndex = Idx++; CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); @@ -609,7 +633,7 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage Response.AddAttachment(CbAttachment(std::move(Compressed))); } - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response)); + SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response), ++ResponseCount == RequestCount); ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, @@ -625,13 +649,16 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); ResponseObject.EndObject(); - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(ResponseObject.Save())); + SendStreamResponse(RequestMessage.SocketId(), + RequestMessage.CorrelationId(), + std::move(ResponseObject.Save()), + ++ResponseCount == RequestCount); ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kCompressedBinary)); } } - SendStreamCompleteResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId()); + ZEN_ASSERT(ResponseCount == RequestCount); return true; } |