aboutsummaryrefslogtreecommitdiff
path: root/zenserver/cache/structuredcache.cpp
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-03-15 16:03:21 +0100
committerPer Larsson <[email protected]>2022-03-15 16:03:21 +0100
commitf81bdabcfcefa04a7e9b9e9ab7289a7bce41938e (patch)
tree6ec1001bcc3e62748affb4ef38391baa9b302e05 /zenserver/cache/structuredcache.cpp
parentCombine last stream response with stream complete message. (diff)
downloadzen-f81bdabcfcefa04a7e9b9e9ab7289a7bce41938e.tar.xz
zen-f81bdabcfcefa04a7e9b9e9ab7289a7bce41938e.zip
Stream response in batches.
Diffstat (limited to 'zenserver/cache/structuredcache.cpp')
-rw-r--r--zenserver/cache/structuredcache.cpp394
1 files changed, 229 insertions, 165 deletions
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp
index e3204ac9d..8df6074b7 100644
--- a/zenserver/cache/structuredcache.cpp
+++ b/zenserver/cache/structuredcache.cpp
@@ -170,9 +170,108 @@ HttpStructuredCacheService::RegisterHandlers(WebSocketServer& Server)
Server.RegisterRequestHandler("GetCacheChunks"sv, *this);
}
+class ResponseStreamWriter
+{
+public:
+ ResponseStreamWriter(WebSocketServer& Server,
+ const WebSocketMessage& Request,
+ uint32_t RequestCount,
+ uint32_t MaxBatchCount = ~uint32_t(0));
+ ~ResponseStreamWriter();
+
+ template<typename Fn>
+ void Append(Fn&& Append)
+ {
+ ZEN_ASSERT(m_ResponseCount < m_RequestCount);
+
+ if (m_CurrentBatchCount == 0)
+ {
+ m_Writer.BeginObject();
+ m_Writer.BeginArray("Result"sv);
+ }
+
+ Append(m_StreamResponse, m_Writer);
+
+ if (++m_CurrentBatchCount >= m_MaxBatchCount)
+ {
+ SendCurrentBatch();
+ }
+ }
+
+ bool Flush();
+
+private:
+ bool SendCurrentBatch();
+
+ WebSocketServer& m_Server;
+ CbWriter m_Writer;
+ CbPackage m_StreamResponse;
+ WebSocketId m_SocketId;
+ uint32_t m_CorrelationId;
+ uint32_t m_RequestCount;
+ uint32_t m_MaxBatchCount;
+ uint32_t m_CurrentBatchCount{0};
+ uint32_t m_ResponseCount{0};
+};
+
+ResponseStreamWriter::ResponseStreamWriter(WebSocketServer& Server,
+ const WebSocketMessage& Request,
+ uint32_t RequestCount,
+ uint32_t MaxBatchCount)
+: m_Server(Server)
+, m_SocketId(Request.SocketId())
+, m_CorrelationId(Request.CorrelationId())
+, m_RequestCount(RequestCount)
+, m_MaxBatchCount(MaxBatchCount)
+{
+}
+
+ResponseStreamWriter::~ResponseStreamWriter()
+{
+ Flush();
+}
+
+bool
+ResponseStreamWriter::SendCurrentBatch()
+{
+ if (m_CurrentBatchCount > 0)
+ {
+ m_Writer.EndArray();
+ m_Writer.EndObject();
+
+ m_ResponseCount += m_CurrentBatchCount;
+ m_CurrentBatchCount = 0;
+
+ CbPackage StreamResponse = std::move(m_StreamResponse);
+ StreamResponse.SetObject(m_Writer.Save().AsObject());
+
+ m_StreamResponse.Reset();
+ m_Writer.Reset();
+
+ WebSocketMessage Message;
+ Message.SetMessageType(m_ResponseCount == m_RequestCount ? WebSocketMessageType::kStreamCompleteResponse
+ : WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(m_CorrelationId);
+ Message.SetSocketId(m_SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ m_Server.SendResponse(std::move(Message));
+ }
+
+ return m_ResponseCount == m_RequestCount;
+}
+
+bool
+ResponseStreamWriter::Flush()
+{
+ return SendCurrentBatch();
+}
+
bool
HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage)
{
+ const uint32_t kMaxBatchCount = 16;
+
CbObjectView Request = RequestMessage.Body().GetObject();
const auto Method = Request["Method"].AsString();
@@ -236,9 +335,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
const std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText);
- CbArrayView Requests = Params["Requests"sv].AsArrayView();
- const uint64_t RequestCount = Requests.Num();
- uint64_t ResponseCount = 0;
+ CbArrayView Requests = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = Requests.Num();
+
+ ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount);
for (int32_t Idx = 0; CbFieldView RequestField : Params["Requests"sv])
{
@@ -279,36 +379,29 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
}
}
- CbPackage Response;
- CbObjectWriter ResponseObject;
-
- ResponseObject.BeginObject("Result"sv);
- ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ const uint64_t RawSize = Compressed.IsNull() ? 0 : Compressed.GetRawSize();
- const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash());
- const uint64_t RawSize = Compressed.GetRawSize();
-
- if (Compressed)
- {
- ResponseObject.AddHash("RawHash"sv, RawHash);
+ ResponseStream.Append([&RequestIndex, &Policy, &Compressed, &RawSize](CbPackage& Response, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
- if (EnumHasAllFlags(Policy, CachePolicy::SkipData))
- {
- ResponseObject.AddInteger("RawSize"sv, RawSize);
- }
- else
+ if (Compressed)
{
- Response.AddAttachment(CbAttachment(std::move(Compressed)));
- }
- }
+ const IoHash RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash());
+ ResponseObject.AddHash("RawHash"sv, RawHash);
- ResponseObject.EndObject();
- Response.SetObject(ResponseObject.Save());
+ if (EnumHasAllFlags(Policy, CachePolicy::SkipData))
+ {
+ ResponseObject.AddInteger("RawSize"sv, RawSize);
+ }
+ else
+ {
+ Response.AddAttachment(CbAttachment(std::move(Compressed)));
+ }
+ }
- SendStreamResponse(RequestMessage.SocketId(),
- RequestMessage.CorrelationId(),
- std::move(Response),
- ++ResponseCount == RequestCount);
+ ResponseObject.EndObject();
+ });
if (RawSize > 0)
{
@@ -320,7 +413,8 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
}
}
- ZEN_ASSERT(ResponseCount == RequestCount);
+ const bool IsStreamComplete = ResponseStream.Flush();
+ ZEN_ASSERT(IsStreamComplete);
return true;
}
@@ -333,9 +427,10 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
CachePolicy DefaultPolicy = DefaultPolicyText.empty() ? CachePolicy::Default : ParseCachePolicy(DefaultPolicyText);
std::vector<CacheKeyRequest> UpstreamRequests;
- CbArrayView Requests = Params["Requests"sv].AsArrayView();
- const uint64_t RequestCount = Requests.Num();
- uint64_t ResponseCount = 0;
+ CbArrayView Requests = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = Requests.Num();
+
+ ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount);
for (int32_t Idx = 0; CbFieldView RequestField : Requests)
{
@@ -381,25 +476,17 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
if (IsComplete)
{
- CbPackage Response;
- CbObjectWriter ResponseObject;
-
- ResponseObject.BeginObject("Result"sv);
- ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
- ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject));
- ResponseObject.EndObject();
-
- Response.SetObject(ResponseObject.Save());
-
- for (const IoBuffer& Value : Values)
- {
- Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value))));
- }
+ ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.AddObject("Record"sv, CbObject::Clone(RecordObject));
+ ResponseObject.EndObject();
- SendStreamResponse(RequestMessage.SocketId(),
- RequestMessage.CorrelationId(),
- std::move(Response),
- ++ResponseCount == RequestCount);
+ for (const IoBuffer& Value : Values)
+ {
+ Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Value))));
+ }
+ });
ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)",
Key.Bucket,
@@ -412,125 +499,109 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
}
}
- if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false)
+ // if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote) == false)
{
- CbPackage Response;
- CbObjectWriter ResponseObject;
-
- ResponseObject.BeginObject("Error"sv);
- ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
- ResponseObject.AddString("Reason"sv, "Not Found"sv);
- ResponseObject.EndObject();
-
- Response.SetObject(ResponseObject.Save());
-
- SendStreamResponse(RequestMessage.SocketId(),
- RequestMessage.CorrelationId(),
- std::move(Response),
- ++ResponseCount == RequestCount);
+ ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.EndObject();
+ });
ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash);
continue;
}
- UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)});
+ // UpstreamRequests.push_back({.Key = Key, .Policy = Policy, .UserData = uint64_t(RequestIndex)});
}
+ bool IsStreamComplete = ResponseStream.Flush();
+
if (UpstreamRequests.empty())
{
+ ZEN_ASSERT(IsStreamComplete);
return true;
}
- auto OnCacheRecordGetComplete =
- [this, RequestCount, &ResponseCount, &RequestMessage](CacheRecordGetCompleteParams&& Params) mutable {
- if (Params.Record)
+ auto OnCacheRecordGetComplete = [this, &ResponseStream](CacheRecordGetCompleteParams&& Params) mutable {
+ if (Params.Record)
+ {
+ CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView();
+ uint32_t AttachmentCount{};
+ uint64_t TotalSize = Params.Record.GetSize();
+ std::vector<CbAttachment> Attachments;
+
+ for (CbFieldView ValueField : RecordValuesView)
{
- CbPackage Response;
- CbObjectWriter ResponseObject;
+ CbObjectView ValueObject = ValueField.AsObjectView();
- CbArrayView RecordValuesView = Params.Record["Values"sv].AsArrayView();
- uint32_t AttachmentCount{};
- uint64_t TotalSize = Params.Record.GetSize();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ IoHash RawHash = ValueObject["RawHash"sv].AsHash();
+ CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId);
- for (CbFieldView ValueField : RecordValuesView)
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash))
{
- CbObjectView ValueObject = ValueField.AsObjectView();
-
- Oid ValueId = ValueObject["Id"sv].AsObjectId();
- IoHash RawHash = ValueObject["RawHash"sv].AsHash();
- CachePolicy ValuePolicy = Params.Request.Policy.GetValuePolicy(ValueId);
-
- if (const CbAttachment* Attachment = Params.Package.FindAttachment(RawHash))
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
{
- if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
- {
- Response.AddAttachment(CbAttachment(Compressed));
- AttachmentCount++;
- TotalSize += Compressed.GetCompressedSize();
+ Attachments.emplace_back(Compressed);
+ // Response.AddAttachment(CbAttachment(Compressed));
+ AttachmentCount++;
+ TotalSize += Compressed.GetCompressedSize();
- if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
- {
- IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer();
- Value.SetContentType(ZenContentType::kCompressedBinary);
- m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value});
- }
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ IoBuffer Value = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Value.SetContentType(ZenContentType::kCompressedBinary);
+ m_CacheStore.Put(Params.Request.Key.Bucket, Params.Request.Key.Hash, {.Value = Value});
}
}
- else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal))
+ }
+ else if (EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal))
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash))
{
- if (IoBuffer Chunk = m_CidStore.FindChunkByCid(RawHash))
- {
- Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk))));
- AttachmentCount++;
- }
+ // Response.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk))));
+ Attachments.emplace_back(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)));
+ AttachmentCount++;
}
}
+ }
- const bool IsComplete = AttachmentCount == RecordValuesView.Num();
- const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord);
+ const bool IsComplete = AttachmentCount == RecordValuesView.Num();
+ const bool AllowPartial = EnumHasAllFlags(Params.Request.Policy.GetRecordPolicy(), CachePolicy::PartialRecord);
- if (IsComplete || AllowPartial)
- {
- ResponseObject.BeginObject("Result"sv);
+ if (IsComplete || AllowPartial)
+ {
+ ResponseStream.Append([&](CbPackage& Response, CbWriter& ResponseObject) {
+ for (CbAttachment& Attachment : Attachments)
+ {
+ Response.AddAttachment(std::move(Attachment));
+ }
+
+ ResponseObject.BeginObject();
ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData));
ResponseObject.AddObject("Record"sv, CbObject::Clone(Params.Record));
ResponseObject.EndObject();
+ });
- Response.SetObject(ResponseObject.Save());
-
- SendStreamResponse(RequestMessage.SocketId(),
- RequestMessage.CorrelationId(),
- std::move(Response),
- ++ResponseCount == RequestCount);
+ ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)",
+ Params.Request.Key.Bucket,
+ Params.Request.Key.Hash,
+ NiceBytes(TotalSize),
+ ToString(ZenContentType::kCbObject));
- ZEN_DEBUG("HIT - '{}/{}' {} '{}' (UPSTREAM)",
- Params.Request.Key.Bucket,
- Params.Request.Key.Hash,
- NiceBytes(TotalSize),
- ToString(ZenContentType::kCbObject));
-
- return;
- }
+ return;
}
+ }
- CbPackage Response;
- CbObjectWriter ResponseObject;
-
- ResponseObject.BeginObject("Error"sv);
+ ResponseStream.Append([&](CbPackage&, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
ResponseObject.AddInteger("RequestIndex"sv, int32_t(Params.Request.UserData));
- ResponseObject.AddString("Reason"sv, "Not Found"sv);
ResponseObject.EndObject();
+ });
- Response.SetObject(ResponseObject.Save());
-
- SendStreamResponse(RequestMessage.SocketId(),
- RequestMessage.CorrelationId(),
- std::move(Response),
- ++ResponseCount == RequestCount);
-
- ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash);
- };
+ ZEN_DEBUG("MISS - '{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash);
+ };
// TODO: Fix this
std::vector<CacheKeyRequest*> RequestPtrs;
@@ -543,7 +614,8 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
m_UpstreamCache.GetCacheRecords(RequestPtrs, std::move(OnCacheRecordGetComplete));
- ZEN_ASSERT(ResponseCount == RequestCount);
+ IsStreamComplete = ResponseStream.Flush();
+ ZEN_ASSERT(IsStreamComplete);
return true;
}
@@ -572,14 +644,15 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
IoBuffer CurrentRecordValue;
CompressedBuffer Compressed;
- CbArrayView Requests = Params["Requests"sv].AsArrayView();
- const uint64_t RequestCount = Requests.Num();
- uint64_t ResponseCount = 0;
+ CbArrayView Requests = Params["Requests"sv].AsArrayView();
+ const uint64_t RequestCount = Requests.Num();
+
+ ResponseStreamWriter ResponseStream(SocketServer(), RequestMessage, uint32_t(RequestCount), kMaxBatchCount);
for (int32_t Idx = 0; CbFieldView RequestField : Requests)
{
- CbObjectView RequestObject = RequestField.AsObjectView();
- const int32_t RequestIndex = Idx++;
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ const int32_t RequestIndex = Idx++;
CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
@@ -617,23 +690,18 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
if (Compressed || RawSize > 0)
{
- CbPackage Response;
- CbObjectWriter ResponseObject;
-
- ResponseObject.BeginObject("Result"sv);
- ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
- ResponseObject.AddHash("RawHash"sv, Cid);
- ResponseObject.AddInteger("RawSize"sv, RawSize);
- ResponseObject.EndObject();
-
- Response.SetObject(ResponseObject.Save());
-
- if (Compressed)
- {
- Response.AddAttachment(CbAttachment(std::move(Compressed)));
- }
+ ResponseStream.Append([&Compressed, RequestIndex, &Cid, RawSize](CbPackage& Response, CbWriter& ResponseObject) {
+ if (Compressed)
+ {
+ Response.AddAttachment(CbAttachment(std::move(Compressed)));
+ }
- SendStreamResponse(RequestMessage.SocketId(), RequestMessage.CorrelationId(), std::move(Response), ++ResponseCount == RequestCount);
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.AddHash("RawHash"sv, Cid);
+ ResponseObject.AddInteger("RawSize"sv, RawSize);
+ ResponseObject.EndObject();
+ });
ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)",
Key.Bucket,
@@ -643,22 +711,18 @@ HttpStructuredCacheService::HandleRequest(const WebSocketMessage& RequestMessage
}
else
{
- CbObjectWriter ResponseObject;
-
- ResponseObject.BeginObject("Error"sv);
- ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
- ResponseObject.EndObject();
-
- SendStreamResponse(RequestMessage.SocketId(),
- RequestMessage.CorrelationId(),
- std::move(ResponseObject.Save()),
- ++ResponseCount == RequestCount);
+ ResponseStream.Append([RequestIndex](CbPackage&, CbWriter& ResponseObject) {
+ ResponseObject.BeginObject();
+ ResponseObject.AddInteger("RequestIndex"sv, RequestIndex);
+ ResponseObject.EndObject();
+ });
ZEN_DEBUG("MISS - '{}/{}' '{}'", Key.Bucket, Key.Hash, ToString(ZenContentType::kCompressedBinary));
}
}
- ZEN_ASSERT(ResponseCount == RequestCount);
+ const bool Complete = ResponseStream.Flush();
+ ZEN_ASSERT(Complete);
return true;
}