aboutsummaryrefslogtreecommitdiff
path: root/zenserver/cache/structuredcache.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
-rw-r--r--zenserver/cache/structuredcache.cpp569
1 files changed, 569 insertions, 0 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp
index 8ae531720..4c93e8258 100644
--- a/zenserver/cache/structuredcache.cpp
+++ b/zenserver/cache/structuredcache.cpp
@@ -162,6 +162,575 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
}
void
+HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server)
+{
+ Server.RegisterRequestHandler("GetBinaryCacheValue"sv, *this);
+ Server.RegisterRequestHandler("GetCacheValues"sv, *this);
+ Server.RegisterRequestHandler("GetCacheRecords"sv, *this);
+ Server.RegisterRequestHandler("GetCacheChunks"sv, *this);
+}
+
+class ResponseStreamWriter
+{
+public:
+ ResponseStreamWriter(WebSocketServer& Server,
+ const WebSocketMessage& Request,
+ uint32_t RequestCount,
+ uint32_t MaxBatchCount = ~uint32_t(0));
+ ~ResponseStreamWriter();
+
+ template<typename Fn>
+ void Append(Fn&& Append)
+ {
+ ZEN_ASSERT(m_ResponseCount < m_RequestCount);
+
+ if (m_CurrentBatchCount == 0)
+ {
+ m_Writer.BeginObject();
+ m_Writer.BeginArray("Result"sv);
+ }
+
+ Append(m_StreamResponse, m_Writer);
+
+ if (++m_CurrentBatchCount >= m_MaxBatchCount)
+ {
+ SendCurrentBatch();
+ }
+ }
+
+ bool Flush();
+
+private:
+ bool SendCurrentBatch();
+
+ WebSocketServer& m_Server;
+ CbWriter m_Writer;
+ CbPackage m_StreamResponse;
+ WebSocketId m_SocketId;
+ uint32_t m_CorrelationId;
+ uint32_t m_RequestCount;
+ uint32_t m_MaxBatchCount;
+ uint32_t m_CurrentBatchCount{0};
+ uint32_t m_ResponseCount{0};
+};
+
+ResponseStreamWriter::ResponseStreamWriter(WebSocketServer& Server,
+ const WebSocketMessage& Request,
+ uint32_t RequestCount,
+ uint32_t MaxBatchCount)
+: m_Server(Server)
+, m_SocketId(Request.SocketId())
+, m_CorrelationId(Request.CorrelationId())
+, m_RequestCount(RequestCount)
+, m_MaxBatchCount(MaxBatchCount)
+{
+}
+
+ResponseStreamWriter::~ResponseStreamWriter()
+{
+ Flush();
+}
+
+bool
+ResponseStreamWriter::SendCurrentBatch()
+{
+ if (m_CurrentBatchCount > 0)
+ {
+ m_Writer.EndArray();
+ m_Writer.EndObject();
+
+ m_ResponseCount += m_CurrentBatchCount;
+ m_CurrentBatchCount = 0;
+
+ CbPackage StreamResponse = std::move(m_StreamResponse);
+ StreamResponse.SetObject(m_Writer.Save().AsObject());
+
+ m_StreamResponse.Reset();
+ m_Writer.Reset();
+
+ WebSocketMessage Message;
+ Message.SetMessageType(m_ResponseCount == m_RequestCount ? WebSocketMessageType::kStreamCompleteResponse
+ : WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(m_CorrelationId);
+ Message.SetSocketId(m_SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ m_Server.SendResponse(std::move(Message));
+ }
+
+ return m_ResponseCount == m_RequestCount;
+}
+
+bool
+ResponseStreamWriter::Flush()
+{
+ return SendCurrentBatch();
+}
+
+bool
+HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage)
+{
+ const uint32_t kMaxBatchCount = 16;
+
+ CbObjectView Request = RequestMessage.Body().GetObject();
+
+ const auto Method = Request["Method"].AsString();
+ CbObjectView Params = Request["Params"sv].AsObjectView();
+
+ if (Method == "GetBinaryCacheValue"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetBinaryCacheValue");
+
+ // CachePolicy Policy;
+ CbObjectView KeyObject = Params["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+
+ ZenCacheValue CacheValue;
+ const bool InLocalCache = m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue);
+
+ CbPackage Response;
+
+ if (InLocalCache)
+ {
+ m_CacheStats.HitCount++;
+
+ CbAttachment Attachment(SharedBuffer(CacheValue.Value));
+
+ CbObjectWriter ResponseObject;
+ ResponseObject.AddAttachment("Result", Attachment);
+ Response.AddAttachment(std::move(Attachment));
+ Response.SetObject(ResponseObject.Save());
+
+ ZenContentType ContentType = CacheValue.Value.GetContentType();
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", Key.Bucket, Key.Hash, NiceBytes(CacheValue.Value.Size()), ToString(ContentType));
+ }
+ else
+ {
+ m_CacheStats.MissCount++;
+
+ CbObjectWriter ResponseObject;
+ ResponseObject << "Error"sv
+ << "Not Found"sv;
+ Response.SetObject(ResponseObject.Save());
+
+ ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kBinary));
+ }
+
+ WebSocketMessage ResponseMessage;
+ ResponseMessage.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMessage.SetCorrelationId(RequestMessage.CorrelationId());
+ ResponseMessage.SetSocketId(RequestMessage.SocketId());
+ ResponseMessage.SetBody(std::move(Response));
+
+ SocketServer().SendResponse(std::move(ResponseMessage));
+
+ return true;
+ }
+
+ if (Method == "GetCacheValues"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetCacheValues");
+
+ const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText);
+
+ CbArrayView Requests = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = Requests.Num();
+
+ ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount);
+
+ for (int32_t Idx = 0; CbFieldView RequestField : Params["Requests"sv])
+ {
+ const int32_t RequestIndex = Idx++;
+
+ CbObjectView RequestObject = RequestField.AsObjectView();
+
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+ std::string_view PolicyText = RequestObject["Policy"sv].AsString();
+ CachePolicy Policy = PolicyText.empty() ? DefaultPolicy : ParseCachePolicy(PolicyText);
+
+ CompressedBuffer Compressed;
+ bool InLocalCache = false;
+
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue CacheValue;
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue))
+ {
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer(CacheValue.Value));
+ InLocalCache = true;
+ }
+ }
+
+ if (Compressed.IsNull() && EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ if (auto UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kCompressedBinary);
+ UpstreamResult.Success)
+ {
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value));
+
+ if (Compressed)
+ {
+ UpstreamResult.Value.SetContentType(ZenContentType::kCompressedBinary);
+ m_CacheStore.Put(Key.Bucket, Key.Hash, ZenCacheValue{UpstreamResult.Value});
+ }
+ }
+ }
+
+ const uint64_t RawSize = Compressed.IsNull() ? 0 : Compressed.GetRawSize();
+
+ ResponseStream.Append([&RequestIndex, &Policy, &Compressed, &RawSize](CbPackage& Response, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+
+ if (Compressed)
+ {
+ const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash());
+ ResponseObject.AddHash("RawHash"sv, RawHash);
+
+ if (EnumHasAllFlags(Policy, CachePolicy::SkipData))
+ {
+ ResponseObject.AddInteger("RawSize"sv, RawSize);
+ }
+ else
+ {
+ Response.AddAttachment(CbAttachment(std::move(Compressed)));
+ }
+ }
+
+ ResponseObject.EndObject();
+ });
+
+ if (RawSize > 0)
+ {
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}'", Key.Bucket, Key.Hash, NiceBytes(RawSize), ToString(ZenContentType::kCompressedBinary));
+ }
+ else
+ {
+ ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash);
+ }
+ }
+
+ const bool IsStreamComplete = ResponseStream.Flush();
+ ZEN_ASSERT(IsStreamComplete);
+
+ return true;
+ }
+
+ if (Method == "GetCacheRecords"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetCacheRecords");
+
+ const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText);
+ std::vector<CacheKeyRequest> UpstreamRequests;
+
+ CbArrayView Requests = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = Requests.Num();
+
+ ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount);
+
+ for (int32_t Idx = 0; CbFieldView RequestField : Requests)
+ {
+ const int32_t RequestIndex = Idx++;
+
+ CbObjectView RequestObject = RequestField.AsObjectView();
+
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+ CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal))
+ {
+ ZenCacheValue RecordCacheValue;
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ CbObjectView RecordObject = CbObjectView(RecordCacheValue.Value.GetData());
+ CbArrayView RecordValuesView = RecordObject["Values"sv].AsArrayView();
+ uint64_t TotalSize = RecordCacheValue.Value.GetSize();
+
+ std::vector<IoBuffer> Values;
+ Values.reserve(RecordValuesView.Num());
+
+ for (CbFieldView ValueField : RecordValuesView)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ IoHash RawHash = ValueObject["RawHash"sv].AsHash();
+ CachePolicy ValuePolicy = Policy.GetValuePolicy(ValueId);
+
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal))
+ {
+ if (IoBuffer Value = m_CidStore.FindChunkByCid(RawHash))
+ {
+ Values.push_back(Value);
+ TotalSize += Value.GetSize();
+ }
+ }
+ }
+
+ const bool IsComplete = Values.size() == RecordValuesView.Num();
+
+ if (IsComplete)
+ {
+ ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject));
+ ResponseObject.EndObject();
+
+ for (const IoBuffer& Value : Values)
+ {
+ Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value))));
+ }
+ });
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)",
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(TotalSize),
+ ToString(ZenContentType::kCbObject));
+
+ continue;
+ }
+ }
+ }
+
+ // if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false)
+ {
+ ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.EndObject();
+ });
+
+ ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash);
+
+ continue;
+ }
+
+ // UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)});
+ }
+
+ bool IsStreamComplete = ResponseStream.Flush();
+
+ if (UpstreamRequests.empty())
+ {
+ ZEN_ASSERT(IsStreamComplete);
+ return true;
+ }
+
+ auto OnCacheRecordGetComplete = [this, &ResponseStream](CacheRecordGetCompleteParams&& Params) mutable {
+ if (Params.Record)
+ {
+ CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView();
+ uint32_t AttachmentCount{};
+ uint64_t TotalSize = Params.Record.GetSize();
+ std::vector<CbAttachment> Attachments;
+
+ for (CbFieldView ValueField : RecordValuesView)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ IoHash RawHash = ValueObject["RawHash"sv].AsHash();
+ CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId);
+
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash))
+ {
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
+ {
+ Attachments.emplace_back(Compressed);
+ // Response.AddAttachment(CbAttachment(Compressed));
+ AttachmentCount++;
+ TotalSize += Compressed.GetCompressedSize();
+
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Value.SetContentType(ZenContentType::kCompressedBinary);
+ m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value});
+ }
+ }
+ }
+ else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal))
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash))
+ {
+ // Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk))));
+ Attachments.emplace_back(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)));
+ AttachmentCount++;
+ }
+ }
+ }
+
+ const bool IsComplete = AttachmentCount == RecordValuesView.Num();
+ const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord);
+
+ if (IsComplete || AllowPartial)
+ {
+ ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) {
+ for (CbAttachment& Attachment : Attachments)
+ {
+ Response.AddAttachment(std::move(Attachment));
+ }
+
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData));
+ ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record));
+ ResponseObject.EndObject();
+ });
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)",
+ Params.Request.Key.Bucket,
+ Params.Request.Key.Hash,
+ NiceBytes(TotalSize),
+ ToString(ZenContentType::kCbObject));
+
+ return;
+ }
+ }
+
+ ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData));
+ ResponseObject.EndObject();
+ });
+
+ ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash);
+ };
+
+ // TODO: Fix this
+ std::vector<CacheKeyRequest*> RequestPtrs;
+ RequestPtrs.reserve(UpstreamRequests.size());
+
+ for (CacheKeyRequest& Req : UpstreamRequests)
+ {
+ RequestPtrs.push_back(&Req);
+ }
+
+ m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete));
+
+ IsStreamComplete = ResponseStream.Flush();
+ ZEN_ASSERT(IsStreamComplete);
+
+ return true;
+ }
+
+ if (Method == "GetCacheChunks"sv)
+ {
+ ZEN_TRACE_CPU("Z$::WS_GetCacheChunks");
+
+ auto GetCidFromValueId = [](const Oid& ValueId, CbObjectView Record, uint64_t& OutRawSize) -> IoHash {
+ CbArrayView Values = Record["Values"sv].AsArrayView();
+
+ for (CbFieldView Value : Values)
+ {
+ CbObjectView ValueObject = Value.AsObjectView();
+ if (ValueObject["Id"sv].AsObjectId() == ValueId)
+ {
+ OutRawSize = ValueObject["RawSize"sv].AsUInt64();
+ return ValueObject["RawHash"sv].AsHash();
+ }
+ }
+
+ return IoHash::Zero;
+ };
+
+ CacheKey CurrentKey;
+ IoBuffer CurrentRecordValue;
+
+ CbArrayView Requests = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = Requests.Num();
+
+ ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount);
+
+ for (int32_t Idx = 0; CbFieldView RequestField : Requests)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ const int32_t RequestIndex = Idx++;
+
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
+ const IoHash RawHash = RequestObject["ChunkId"sv].AsHash();
+ const Oid ValueId = RequestObject["ValueId"sv].AsObjectId();
+ const uint64_t RequestedRawOffset = RequestObject["RawOffset"sv].AsUInt64();
+ const uint64_t RequestedRawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX);
+
+ IoHash Cid = RawHash;
+ uint64_t RawSize = 0;
+
+ if (RawHash == IoHash::Zero)
+ {
+ if (CurrentKey != Key || CurrentRecordValue.GetSize() == 0)
+ {
+ ZenCacheValue RecordCacheValue;
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ CurrentRecordValue = RecordCacheValue.Value;
+ CurrentKey = Key;
+ }
+ }
+
+ if (CurrentRecordValue)
+ {
+ Cid = GetCidFromValueId(ValueId, CbObjectView(CurrentRecordValue.GetData()), RawSize);
+ }
+ }
+
+ CompressedBuffer Compressed;
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Cid))
+ {
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk));
+ RawSize = Compressed.GetRawSize();
+ }
+
+ if (Compressed || RawSize > 0)
+ {
+ ResponseStream.Append([&Compressed, RequestIndex, &Cid, RawSize](CbPackage& Response, CbWriter& ResponseObject) {
+ if (Compressed)
+ {
+ Response.AddAttachment(CbAttachment(std::move(Compressed)));
+ }
+
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.AddHash("RawHash"sv, Cid);
+ ResponseObject.AddInteger("RawSize"sv, RawSize);
+ ResponseObject.EndObject();
+ });
+
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)",
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(RawSize),
+ ToString(ZenContentType::kCompressedBinary));
+ }
+ else
+ {
+ ResponseStream.Append([RequestIndex](CbPackage&, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.EndObject();
+ });
+
+ ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kCompressedBinary));
+ }
+ }
+
+ const bool Complete = ResponseStream.Flush();
+ ZEN_ASSERT(Complete);
+
+ return true;
+ }
+
+ return false;
+}
+
+void
HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Bucket)
{
switch (Request.RequestVerb())