diff options
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
| -rw-r--r-- | zenserver/cache/structuredcache.cpp | 394 |
1 files changed, 229 insertions, 165 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index e3204ac9d..8df6074b7 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -170,9 +170,108 @@ HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server) Server.RegisterRequestHandler("GetCacheChunks"sv, *this); } +class ResponseStreamWriter +{ +public: + ResponseStreamWriter(WebSocketServer& Server, + const WebSocketMessage& Request, + uint32_t RequestCount, + uint32_t MaxBatchCount = ~uint32_t(0)); + ~ResponseStreamWriter(); + + template<typename Fn> + void Append(Fn&& Append) + { + ZEN_ASSERT(m_ResponseCount < m_RequestCount); + + if (m_CurrentBatchCount == 0) + { + m_Writer.BeginObject(); + m_Writer.BeginArray("Result"sv); + } + + Append(m_StreamResponse, m_Writer); + + if (++m_CurrentBatchCount >= m_MaxBatchCount) + { + SendCurrentBatch(); + } + } + + bool Flush(); + +private: + bool SendCurrentBatch(); + + WebSocketServer& m_Server; + CbWriter m_Writer; + CbPackage m_StreamResponse; + WebSocketId m_SocketId; + uint32_t m_CorrelationId; + uint32_t m_RequestCount; + uint32_t m_MaxBatchCount; + uint32_t m_CurrentBatchCount{0}; + uint32_t m_ResponseCount{0}; +}; + +ResponseStreamWriter::ResponseStreamWriter(WebSocketServer& Server, + const WebSocketMessage& Request, + uint32_t RequestCount, + uint32_t MaxBatchCount) +: m_Server(Server) +, m_SocketId(Request.SocketId()) +, m_CorrelationId(Request.CorrelationId()) +, m_RequestCount(RequestCount) +, m_MaxBatchCount(MaxBatchCount) +{ +} + +ResponseStreamWriter::~ResponseStreamWriter() +{ + Flush(); +} + +bool +ResponseStreamWriter::SendCurrentBatch() +{ + if (m_CurrentBatchCount > 0) + { + m_Writer.EndArray(); + m_Writer.EndObject(); + + m_ResponseCount += m_CurrentBatchCount; + m_CurrentBatchCount = 0; + + CbPackage StreamResponse = std::move(m_StreamResponse); + StreamResponse.SetObject(m_Writer.Save().AsObject()); + + m_StreamResponse.Reset(); + m_Writer.Reset(); + + WebSocketMessage Message; + Message.SetMessageType(m_ResponseCount == m_RequestCount ? WebSocketMessageType::kStreamCompleteResponse + : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(m_CorrelationId); + Message.SetSocketId(m_SocketId); + Message.SetBody(std::move(StreamResponse)); + + m_Server.SendResponse(std::move(Message)); + } + + return m_ResponseCount == m_RequestCount; +} + +bool +ResponseStreamWriter::Flush() +{ + return SendCurrentBatch(); +} + bool HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage) { + const uint32_t kMaxBatchCount = 16; + CbObjectView Request = RequestMessage.Body().GetObject(); const auto Method = Request["Method"].AsString(); @@ -236,9 +335,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); - CbArrayView Requests = Params["Requests"sv].AsArrayView(); - const uint64_t RequestCount = Requests.Num(); - uint64_t ResponseCount = 0; + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + + ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount); for (int32_t Idx = 0; CbFieldView RequestField : Params["Requests"sv]) { @@ -279,36 +379,29 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage } } - CbPackage Response; - CbObjectWriter ResponseObject; - - ResponseObject.BeginObject("Result"sv); - ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + const uint64_t RawSize = Compressed.IsNull() ? 0 : Compressed.GetRawSize(); - const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash()); - const uint64_t RawSize = Compressed.GetRawSize(); - - if (Compressed) - { - ResponseObject.AddHash("RawHash"sv, RawHash); + ResponseStream.Append([&RequestIndex, &Policy, &Compressed, &RawSize](CbPackage& Response, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); - if (EnumHasAllFlags(Policy, CachePolicy::SkipData)) - { - ResponseObject.AddInteger("RawSize"sv, RawSize); - } - else + if (Compressed) { - Response.AddAttachment(CbAttachment(std::move(Compressed))); - } - } + const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash()); + ResponseObject.AddHash("RawHash"sv, RawHash); - ResponseObject.EndObject(); - Response.SetObject(ResponseObject.Save()); + if (EnumHasAllFlags(Policy, CachePolicy::SkipData)) + { + ResponseObject.AddInteger("RawSize"sv, RawSize); + } + else + { + Response.AddAttachment(CbAttachment(std::move(Compressed))); + } + } - SendStreamResponse(RequestMessage.SocketId(), - RequestMessage.CorrelationId(), - std::move(Response), - ++ResponseCount == RequestCount); + ResponseObject.EndObject(); + }); if (RawSize > 0) { @@ -320,7 +413,8 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage } } - ZEN_ASSERT(ResponseCount == RequestCount); + const bool IsStreamComplete = ResponseStream.Flush(); + ZEN_ASSERT(IsStreamComplete); return true; } @@ -333,9 +427,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); std::vector<CacheKeyRequest> UpstreamRequests; - CbArrayView Requests = Params["Requests"sv].AsArrayView(); - const uint64_t RequestCount = Requests.Num(); - uint64_t ResponseCount = 0; + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + + ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount); for (int32_t Idx = 0; CbFieldView RequestField : Requests) { @@ -381,25 +476,17 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage if (IsComplete) { - CbPackage Response; - CbObjectWriter ResponseObject; - - ResponseObject.BeginObject("Result"sv); - ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); - ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject)); - ResponseObject.EndObject(); - - Response.SetObject(ResponseObject.Save()); - - for (const IoBuffer& Value : Values) - { - Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value)))); - } + ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject)); + ResponseObject.EndObject(); - SendStreamResponse(RequestMessage.SocketId(), - RequestMessage.CorrelationId(), - std::move(Response), - ++ResponseCount == RequestCount); + for (const IoBuffer& Value : Values) + { + Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value)))); + } + }); ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, @@ -412,125 +499,109 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage } } - if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false) + // if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false) { - CbPackage Response; - CbObjectWriter ResponseObject; - - ResponseObject.BeginObject("Error"sv); - ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); - ResponseObject.AddString("Reason"sv, "Not Found"sv); - ResponseObject.EndObject(); - - Response.SetObject(ResponseObject.Save()); - - SendStreamResponse(RequestMessage.SocketId(), - RequestMessage.CorrelationId(), - std::move(Response), - ++ResponseCount == RequestCount); + ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.EndObject(); + }); ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash); continue; } - UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)}); + // UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)}); } + bool IsStreamComplete = ResponseStream.Flush(); + if (UpstreamRequests.empty()) { + ZEN_ASSERT(IsStreamComplete); return true; } - auto OnCacheRecordGetComplete = - [this, RequestCount, &ResponseCount, &RequestMessage](CacheRecordGetCompleteParams&& Params) mutable { - if (Params.Record) + auto OnCacheRecordGetComplete = [this, &ResponseStream](CacheRecordGetCompleteParams&& Params) mutable { + if (Params.Record) + { + CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView(); + uint32_t AttachmentCount{}; + uint64_t TotalSize = Params.Record.GetSize(); + std::vector<CbAttachment> Attachments; + + for (CbFieldView ValueField : RecordValuesView) { - CbPackage Response; - CbObjectWriter ResponseObject; + CbObjectView ValueObject = ValueField.AsObjectView(); - CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView(); - uint32_t AttachmentCount{}; - uint64_t TotalSize = Params.Record.GetSize(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + IoHash RawHash = ValueObject["RawHash"sv].AsHash(); + CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId); - for (CbFieldView ValueField : RecordValuesView) + if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash)) { - CbObjectView ValueObject = ValueField.AsObjectView(); - - Oid ValueId = ValueObject["Id"sv].AsObjectId(); - IoHash RawHash = ValueObject["RawHash"sv].AsHash(); - CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId); - - if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash)) + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) { - if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) - { - Response.AddAttachment(CbAttachment(Compressed)); - AttachmentCount++; - TotalSize += Compressed.GetCompressedSize(); + Attachments.emplace_back(Compressed); + // Response.AddAttachment(CbAttachment(Compressed)); + AttachmentCount++; + TotalSize += Compressed.GetCompressedSize(); - if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) - { - IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer(); - Value.SetContentType(ZenContentType::kCompressedBinary); - m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value}); - } + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value}); } } - else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) + } + else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash)) { - if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash)) - { - Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); - AttachmentCount++; - } + // Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); + Attachments.emplace_back(CompressedBuffer::FromCompressed(SharedBuffer(Chunk))); + AttachmentCount++; } } + } - const bool IsComplete = AttachmentCount == RecordValuesView.Num(); - const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord); + const bool IsComplete = AttachmentCount == RecordValuesView.Num(); + const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord); - if (IsComplete || AllowPartial) - { - ResponseObject.BeginObject("Result"sv); + if (IsComplete || AllowPartial) + { + ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) { + for (CbAttachment& Attachment : Attachments) + { + Response.AddAttachment(std::move(Attachment)); + } + + ResponseObject.BeginObject(); ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record)); ResponseObject.EndObject(); + }); - Response.SetObject(ResponseObject.Save()); - - SendStreamResponse(RequestMessage.SocketId(), - RequestMessage.CorrelationId(), - std::move(Response), - ++ResponseCount == RequestCount); + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)", + Params.Request.Key.Bucket, + Params.Request.Key.Hash, + NiceBytes(TotalSize), + ToString(ZenContentType::kCbObject)); - ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)", - Params.Request.Key.Bucket, - Params.Request.Key.Hash, - NiceBytes(TotalSize), - ToString(ZenContentType::kCbObject)); - - return; - } + return; } + } - CbPackage Response; - CbObjectWriter ResponseObject; - - ResponseObject.BeginObject("Error"sv); + ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); - ResponseObject.AddString("Reason"sv, "Not Found"sv); ResponseObject.EndObject(); + }); - Response.SetObject(ResponseObject.Save()); - - SendStreamResponse(RequestMessage.SocketId(), - RequestMessage.CorrelationId(), - std::move(Response), - ++ResponseCount == RequestCount); - - ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash); - }; + ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash); + }; // TODO: Fix this std::vector<CacheKeyRequest*> RequestPtrs; @@ -543,7 +614,8 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete)); - ZEN_ASSERT(ResponseCount == RequestCount); + IsStreamComplete = ResponseStream.Flush(); + ZEN_ASSERT(IsStreamComplete); return true; } @@ -572,14 +644,15 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage IoBuffer CurrentRecordValue; CompressedBuffer Compressed; - CbArrayView Requests = Params["Requests"sv].AsArrayView(); - const uint64_t RequestCount = Requests.Num(); - uint64_t ResponseCount = 0; + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + + ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount); for (int32_t Idx = 0; CbFieldView RequestField : Requests) { - CbObjectView RequestObject = RequestField.AsObjectView(); - const int32_t RequestIndex = Idx++; + CbObjectView RequestObject = RequestField.AsObjectView(); + const int32_t RequestIndex = Idx++; CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); @@ -617,23 +690,18 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage if (Compressed || RawSize > 0) { - CbPackage Response; - CbObjectWriter ResponseObject; - - ResponseObject.BeginObject("Result"sv); - ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); - ResponseObject.AddHash("RawHash"sv, Cid); - ResponseObject.AddInteger("RawSize"sv, RawSize); - ResponseObject.EndObject(); - - Response.SetObject(ResponseObject.Save()); - - if (Compressed) - { - Response.AddAttachment(CbAttachment(std::move(Compressed))); - } + ResponseStream.Append([&Compressed, RequestIndex, &Cid, RawSize](CbPackage& Response, CbWriter& ResponseObject) { + if (Compressed) + { + Response.AddAttachment(CbAttachment(std::move(Compressed))); + } - SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response), ++ResponseCount == RequestCount); + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.AddHash("RawHash"sv, Cid); + ResponseObject.AddInteger("RawSize"sv, RawSize); + ResponseObject.EndObject(); + }); ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, @@ -643,22 +711,18 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage } else { - CbObjectWriter ResponseObject; - - ResponseObject.BeginObject("Error"sv); - ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); - ResponseObject.EndObject(); - - SendStreamResponse(RequestMessage.SocketId(), - RequestMessage.CorrelationId(), - std::move(ResponseObject.Save()), - ++ResponseCount == RequestCount); + ResponseStream.Append([RequestIndex](CbPackage&, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.EndObject(); + }); ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kCompressedBinary)); } } - ZEN_ASSERT(ResponseCount == RequestCount); + const bool Complete = ResponseStream.Flush(); + ZEN_ASSERT(Complete); return true; } |