diff options
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
| -rw-r--r-- | zenserver/cache/structuredcache.cpp | 569 |
1 files changed, 569 insertions, 0 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index 8ae531720..4c93e8258 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -162,6 +162,575 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) } void +HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server) +{ + Server.RegisterRequestHandler("GetBinaryCacheValue"sv, *this); + Server.RegisterRequestHandler("GetCacheValues"sv, *this); + Server.RegisterRequestHandler("GetCacheRecords"sv, *this); + Server.RegisterRequestHandler("GetCacheChunks"sv, *this); +} + +class ResponseStreamWriter +{ +public: + ResponseStreamWriter(WebSocketServer& Server, + const WebSocketMessage& Request, + uint32_t RequestCount, + uint32_t MaxBatchCount = ~uint32_t(0)); + ~ResponseStreamWriter(); + + template<typename Fn> + void Append(Fn&& Append) + { + ZEN_ASSERT(m_ResponseCount < m_RequestCount); + + if (m_CurrentBatchCount == 0) + { + m_Writer.BeginObject(); + m_Writer.BeginArray("Result"sv); + } + + Append(m_StreamResponse, m_Writer); + + if (++m_CurrentBatchCount >= m_MaxBatchCount) + { + SendCurrentBatch(); + } + } + + bool Flush(); + +private: + bool SendCurrentBatch(); + + WebSocketServer& m_Server; + CbWriter m_Writer; + CbPackage m_StreamResponse; + WebSocketId m_SocketId; + uint32_t m_CorrelationId; + uint32_t m_RequestCount; + uint32_t m_MaxBatchCount; + uint32_t m_CurrentBatchCount{0}; + uint32_t m_ResponseCount{0}; +}; + +ResponseStreamWriter::ResponseStreamWriter(WebSocketServer& Server, + const WebSocketMessage& Request, + uint32_t RequestCount, + uint32_t MaxBatchCount) +: m_Server(Server) +, m_SocketId(Request.SocketId()) +, m_CorrelationId(Request.CorrelationId()) +, m_RequestCount(RequestCount) +, m_MaxBatchCount(MaxBatchCount) +{ +} + +ResponseStreamWriter::~ResponseStreamWriter() +{ + Flush(); +} + +bool +ResponseStreamWriter::SendCurrentBatch() +{ + if (m_CurrentBatchCount > 0) + { + m_Writer.EndArray(); + m_Writer.EndObject(); + + m_ResponseCount += m_CurrentBatchCount; + m_CurrentBatchCount = 0; + + CbPackage StreamResponse = std::move(m_StreamResponse); + StreamResponse.SetObject(m_Writer.Save().AsObject()); + + m_StreamResponse.Reset(); + m_Writer.Reset(); + + WebSocketMessage Message; + Message.SetMessageType(m_ResponseCount == m_RequestCount ? WebSocketMessageType::kStreamCompleteResponse + : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(m_CorrelationId); + Message.SetSocketId(m_SocketId); + Message.SetBody(std::move(StreamResponse)); + + m_Server.SendResponse(std::move(Message)); + } + + return m_ResponseCount == m_RequestCount; +} + +bool +ResponseStreamWriter::Flush() +{ + return SendCurrentBatch(); +} + +bool +HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage) +{ + const uint32_t kMaxBatchCount = 16; + + CbObjectView Request = RequestMessage.Body().GetObject(); + + const auto Method = Request["Method"].AsString(); + CbObjectView Params = Request["Params"sv].AsObjectView(); + + if (Method == "GetBinaryCacheValue"sv) + { + ZEN_TRACE_CPU("Z$::WS_GetBinaryCacheValue"); + + // CachePolicy Policy; + CbObjectView KeyObject = Params["Key"sv].AsObjectView(); + CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); + + ZenCacheValue CacheValue; + const bool InLocalCache = m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue); + + CbPackage Response; + + if (InLocalCache) + { + m_CacheStats.HitCount++; + + CbAttachment Attachment(SharedBuffer(CacheValue.Value)); + + CbObjectWriter ResponseObject; + ResponseObject.AddAttachment("Result", Attachment); + Response.AddAttachment(std::move(Attachment)); + Response.SetObject(ResponseObject.Save()); + + ZenContentType ContentType = CacheValue.Value.GetContentType(); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, Key.Hash, NiceBytes(CacheValue.Value.Size()), ToString(ContentType)); + } + else + { + m_CacheStats.MissCount++; + + CbObjectWriter ResponseObject; + ResponseObject << "Error"sv + << "Not Found"sv; + Response.SetObject(ResponseObject.Save()); + + ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kBinary)); + } + + WebSocketMessage ResponseMessage; + ResponseMessage.SetMessageType(WebSocketMessageType::kResponse); + ResponseMessage.SetCorrelationId(RequestMessage.CorrelationId()); + ResponseMessage.SetSocketId(RequestMessage.SocketId()); + ResponseMessage.SetBody(std::move(Response)); + + SocketServer().SendResponse(std::move(ResponseMessage)); + + return true; + } + + if (Method == "GetCacheValues"sv) + { + ZEN_TRACE_CPU("Z$::WS_GetCacheValues"); + + const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); + + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + + ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount); + + for (int32_t Idx = 0; CbFieldView RequestField : Params["Requests"sv]) + { + const int32_t RequestIndex = Idx++; + + CbObjectView RequestObject = RequestField.AsObjectView(); + + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); + std::string_view PolicyText = RequestObject["Policy"sv].AsString(); + CachePolicy Policy = PolicyText.empty() ? DefaultPolicy : ParseCachePolicy(PolicyText); + + CompressedBuffer Compressed; + bool InLocalCache = false; + + if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue)) + { + Compressed = CompressedBuffer::FromCompressed(SharedBuffer(CacheValue.Value)); + InLocalCache = true; + } + } + + if (Compressed.IsNull() && EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + if (auto UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kCompressedBinary); + UpstreamResult.Success) + { + Compressed = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value)); + + if (Compressed) + { + UpstreamResult.Value.SetContentType(ZenContentType::kCompressedBinary); + m_CacheStore.Put(Key.Bucket, Key.Hash, ZenCacheValue{UpstreamResult.Value}); + } + } + } + + const uint64_t RawSize = Compressed.IsNull() ? 0 : Compressed.GetRawSize(); + + ResponseStream.Append([&RequestIndex, &Policy, &Compressed, &RawSize](CbPackage& Response, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + + if (Compressed) + { + const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash()); + ResponseObject.AddHash("RawHash"sv, RawHash); + + if (EnumHasAllFlags(Policy, CachePolicy::SkipData)) + { + ResponseObject.AddInteger("RawSize"sv, RawSize); + } + else + { + Response.AddAttachment(CbAttachment(std::move(Compressed))); + } + } + + ResponseObject.EndObject(); + }); + + if (RawSize > 0) + { + ZEN_DEBUG("HIT - '{}/{}' {} '{}'", Key.Bucket, Key.Hash, NiceBytes(RawSize), ToString(ZenContentType::kCompressedBinary)); + } + else + { + ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash); + } + } + + const bool IsStreamComplete = ResponseStream.Flush(); + ZEN_ASSERT(IsStreamComplete); + + return true; + } + + if (Method == "GetCacheRecords"sv) + { + ZEN_TRACE_CPU("Z$::WS_GetCacheRecords"); + + const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText); + std::vector<CacheKeyRequest> UpstreamRequests; + + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + + ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount); + + for (int32_t Idx = 0; CbFieldView RequestField : Requests) + { + const int32_t RequestIndex = Idx++; + + CbObjectView RequestObject = RequestField.AsObjectView(); + + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); + CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) + { + ZenCacheValue RecordCacheValue; + if (m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue)) + { + CbObjectView RecordObject = CbObjectView(RecordCacheValue.Value.GetData()); + CbArrayView RecordValuesView = RecordObject["Values"sv].AsArrayView(); + uint64_t TotalSize = RecordCacheValue.Value.GetSize(); + + std::vector<IoBuffer> Values; + Values.reserve(RecordValuesView.Num()); + + for (CbFieldView ValueField : RecordValuesView) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + IoHash RawHash = ValueObject["RawHash"sv].AsHash(); + CachePolicy ValuePolicy = Policy.GetValuePolicy(ValueId); + + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal)) + { + if (IoBuffer Value = m_CidStore.FindChunkByCid(RawHash)) + { + Values.push_back(Value); + TotalSize += Value.GetSize(); + } + } + } + + const bool IsComplete = Values.size() == RecordValuesView.Num(); + + if (IsComplete) + { + ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject)); + ResponseObject.EndObject(); + + for (const IoBuffer& Value : Values) + { + Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value)))); + } + }); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", + Key.Bucket, + Key.Hash, + NiceBytes(TotalSize), + ToString(ZenContentType::kCbObject)); + + continue; + } + } + } + + // if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false) + { + ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.EndObject(); + }); + + ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash); + + continue; + } + + // UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)}); + } + + bool IsStreamComplete = ResponseStream.Flush(); + + if (UpstreamRequests.empty()) + { + ZEN_ASSERT(IsStreamComplete); + return true; + } + + auto OnCacheRecordGetComplete = [this, &ResponseStream](CacheRecordGetCompleteParams&& Params) mutable { + if (Params.Record) + { + CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView(); + uint32_t AttachmentCount{}; + uint64_t TotalSize = Params.Record.GetSize(); + std::vector<CbAttachment> Attachments; + + for (CbFieldView ValueField : RecordValuesView) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + IoHash RawHash = ValueObject["RawHash"sv].AsHash(); + CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId); + + if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + Attachments.emplace_back(Compressed); + // Response.AddAttachment(CbAttachment(Compressed)); + AttachmentCount++; + TotalSize += Compressed.GetCompressedSize(); + + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value}); + } + } + } + else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal)) + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash)) + { + // Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); + Attachments.emplace_back(CompressedBuffer::FromCompressed(SharedBuffer(Chunk))); + AttachmentCount++; + } + } + } + + const bool IsComplete = AttachmentCount == RecordValuesView.Num(); + const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord); + + if (IsComplete || AllowPartial) + { + ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) { + for (CbAttachment& Attachment : Attachments) + { + Response.AddAttachment(std::move(Attachment)); + } + + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); + ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record)); + ResponseObject.EndObject(); + }); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)", + Params.Request.Key.Bucket, + Params.Request.Key.Hash, + NiceBytes(TotalSize), + ToString(ZenContentType::kCbObject)); + + return; + } + } + + ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData)); + ResponseObject.EndObject(); + }); + + ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash); + }; + + // TODO: Fix this + std::vector<CacheKeyRequest*> RequestPtrs; + RequestPtrs.reserve(UpstreamRequests.size()); + + for (CacheKeyRequest& Req : UpstreamRequests) + { + RequestPtrs.push_back(&Req); + } + + m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete)); + + IsStreamComplete = ResponseStream.Flush(); + ZEN_ASSERT(IsStreamComplete); + + return true; + } + + if (Method == "GetCacheChunks"sv) + { + ZEN_TRACE_CPU("Z$::WS_GetCacheChunks"); + + auto GetCidFromValueId = [](const Oid& ValueId, CbObjectView Record, uint64_t& OutRawSize) -> IoHash { + CbArrayView Values = Record["Values"sv].AsArrayView(); + + for (CbFieldView Value : Values) + { + CbObjectView ValueObject = Value.AsObjectView(); + if (ValueObject["Id"sv].AsObjectId() == ValueId) + { + OutRawSize = ValueObject["RawSize"sv].AsUInt64(); + return ValueObject["RawHash"sv].AsHash(); + } + } + + return IoHash::Zero; + }; + + CacheKey CurrentKey; + IoBuffer CurrentRecordValue; + + CbArrayView Requests = Params["Requests"sv].AsArrayView(); + const uint64_t RequestCount = Requests.Num(); + + ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount); + + for (int32_t Idx = 0; CbFieldView RequestField : Requests) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + const int32_t RequestIndex = Idx++; + + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); + const IoHash RawHash = RequestObject["ChunkId"sv].AsHash(); + const Oid ValueId = RequestObject["ValueId"sv].AsObjectId(); + const uint64_t RequestedRawOffset = RequestObject["RawOffset"sv].AsUInt64(); + const uint64_t RequestedRawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX); + + IoHash Cid = RawHash; + uint64_t RawSize = 0; + + if (RawHash == IoHash::Zero) + { + if (CurrentKey != Key || CurrentRecordValue.GetSize() == 0) + { + ZenCacheValue RecordCacheValue; + if (m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue)) + { + CurrentRecordValue = RecordCacheValue.Value; + CurrentKey = Key; + } + } + + if (CurrentRecordValue) + { + Cid = GetCidFromValueId(ValueId, CbObjectView(CurrentRecordValue.GetData()), RawSize); + } + } + + CompressedBuffer Compressed; + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Cid)) + { + Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk)); + RawSize = Compressed.GetRawSize(); + } + + if (Compressed || RawSize > 0) + { + ResponseStream.Append([&Compressed, RequestIndex, &Cid, RawSize](CbPackage& Response, CbWriter& ResponseObject) { + if (Compressed) + { + Response.AddAttachment(CbAttachment(std::move(Compressed))); + } + + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.AddHash("RawHash"sv, Cid); + ResponseObject.AddInteger("RawSize"sv, RawSize); + ResponseObject.EndObject(); + }); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", + Key.Bucket, + Key.Hash, + NiceBytes(RawSize), + ToString(ZenContentType::kCompressedBinary)); + } + else + { + ResponseStream.Append([RequestIndex](CbPackage&, CbWriter& ResponseObject) { + ResponseObject.BeginObject(); + ResponseObject.AddInteger("RequestIndex"sv, RequestIndex); + ResponseObject.EndObject(); + }); + + ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kCompressedBinary)); + } + } + + const bool Complete = ResponseStream.Flush(); + ZEN_ASSERT(Complete); + + return true; + } + + return false; +} + +void HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Bucket) { switch (Request.RequestVerb()) |