aboutsummaryrefslogtreecommitdiff
path: root/src/zenutil/chunkrequests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenutil/chunkrequests.cpp')
-rw-r--r--src/zenutil/chunkrequests.cpp147
1 files changed, 147 insertions, 0 deletions
diff --git a/src/zenutil/chunkrequests.cpp b/src/zenutil/chunkrequests.cpp
new file mode 100644
index 000000000..745363668
--- /dev/null
+++ b/src/zenutil/chunkrequests.cpp
@@ -0,0 +1,147 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/chunkrequests.h>
+
+#include <zencore/blake3.h>
+#include <zencore/iobuffer.h>
+#include <zencore/sharedbuffer.h>
+#include <zencore/stream.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+namespace {
+ struct RequestHeader
+ {
+ enum
+ {
+ kMagic = 0xAAAA'77AC
+ };
+ uint32_t Magic;
+ uint32_t ChunkCount;
+ uint32_t Reserved1;
+ uint32_t Reserved2;
+ };
+
+ struct ResponseHeader
+ {
+ uint32_t Magic = 0xbada'b00f;
+ uint32_t ChunkCount;
+ uint32_t Reserved1 = 0;
+ uint32_t Reserved2 = 0;
+ };
+
+ struct ResponseChunkEntry
+ {
+ uint32_t CorrelationId;
+ uint32_t Flags = 0;
+ uint64_t ChunkSize;
+ };
+} // namespace
+
+IoBuffer
+BuildChunkBatchRequest(const std::vector<RequestChunkEntry>& Entries)
+{
+ RequestHeader RequestHdr;
+ RequestHdr.Magic = (uint32_t)RequestHeader::kMagic;
+ RequestHdr.ChunkCount = gsl::narrow<uint32_t>(Entries.size());
+ UniqueBuffer Buffer = UniqueBuffer::Alloc(sizeof(RequestHeader) + sizeof(RequestChunkEntry) * RequestHdr.ChunkCount);
+ MutableMemoryView WriteBuffer = Buffer.GetMutableView();
+ WriteBuffer = WriteBuffer.CopyFrom(MemoryView(&RequestHdr, sizeof(RequestHeader)));
+ WriteBuffer.CopyFrom(MemoryView(Entries.data(), sizeof(RequestChunkEntry) * RequestHdr.ChunkCount));
+ return Buffer.MoveToShared().AsIoBuffer();
+}
+
+std::optional<std::vector<RequestChunkEntry>>
+ParseChunkBatchRequest(const IoBuffer& Payload)
+{
+ if (Payload.Size() <= sizeof(RequestHeader))
+ {
+ return {};
+ }
+
+ BinaryReader Reader(Payload);
+
+ RequestHeader RequestHdr;
+ Reader.Read(&RequestHdr, sizeof RequestHdr);
+
+ if (RequestHdr.Magic != RequestHeader::kMagic)
+ {
+ return {};
+ }
+
+ std::vector<RequestChunkEntry> RequestedChunks;
+ RequestedChunks.resize(RequestHdr.ChunkCount);
+ Reader.Read(RequestedChunks.data(), sizeof(RequestChunkEntry) * RequestHdr.ChunkCount);
+ return RequestedChunks;
+}
+
+std::vector<IoBuffer>
+BuildChunkBatchResponse(const std::vector<RequestChunkEntry>& Requests, std::span<IoBuffer> Chunks)
+{
+ ZEN_ASSERT(Requests.size() == Chunks.size());
+ size_t ChunkCount = Requests.size();
+
+ std::vector<IoBuffer> OutBlobs;
+ OutBlobs.reserve(1 + ChunkCount);
+ OutBlobs.emplace_back(sizeof(ResponseHeader) + ChunkCount * sizeof(ResponseChunkEntry));
+
+ uint8_t* ResponsePtr = reinterpret_cast<uint8_t*>(OutBlobs[0].MutableData());
+ ResponseHeader ResponseHdr;
+ ResponseHdr.ChunkCount = gsl::narrow<uint32_t>(Requests.size());
+ memcpy(ResponsePtr, &ResponseHdr, sizeof(ResponseHdr));
+ ResponsePtr += sizeof(ResponseHdr);
+ for (uint32_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ const IoBuffer& FoundChunk(Chunks[ChunkIndex]);
+ ResponseChunkEntry ResponseChunk;
+ ResponseChunk.CorrelationId = Requests[ChunkIndex].CorrelationId;
+ if (FoundChunk)
+ {
+ ResponseChunk.ChunkSize = FoundChunk.Size();
+ }
+ else
+ {
+ ResponseChunk.ChunkSize = uint64_t(-1);
+ }
+ memcpy(ResponsePtr, &ResponseChunk, sizeof(ResponseChunk));
+ ResponsePtr += sizeof(ResponseChunk);
+ }
+ OutBlobs.insert(OutBlobs.end(), Chunks.begin(), Chunks.end());
+ auto It = std::remove_if(OutBlobs.begin() + 1, OutBlobs.end(), [](const IoBuffer& B) { return B.GetSize() == 0; });
+ OutBlobs.erase(It, OutBlobs.end());
+ return OutBlobs;
+}
+
+std::vector<IoBuffer>
+ParseChunkBatchResponse(const IoBuffer& Buffer)
+{
+ MemoryView View = Buffer.GetView();
+ const ResponseHeader* Header = (const ResponseHeader*)View.GetData();
+ if (Header->Magic != 0xbada'b00f)
+ {
+ return {};
+ }
+ View.MidInline(sizeof(ResponseHeader));
+ const ResponseChunkEntry* Entries = (const ResponseChunkEntry*)View.GetData();
+ View.MidInline(sizeof(ResponseChunkEntry) * Header->ChunkCount);
+ std::vector<IoBuffer> Result(Header->ChunkCount);
+ for (uint32_t Index = 0; Index < Header->ChunkCount; Index++)
+ {
+ const ResponseChunkEntry& Entry = Entries[Index];
+ if (Result.size() < Entry.CorrelationId + 1)
+ {
+ Result.resize(Entry.CorrelationId + 1);
+ }
+ if (Entry.ChunkSize != uint64_t(-1))
+ {
+ Result[Entry.CorrelationId] = IoBuffer(IoBuffer::Wrap, View.GetData(), Entry.ChunkSize);
+ View.MidInline(Entry.ChunkSize);
+ }
+ }
+ return Result;
+}
+
+} // namespace zen