diff options
| author | Martin Ridgers <[email protected]> | 2021-11-15 09:10:39 +0100 |
|---|---|---|
| committer | Martin Ridgers <[email protected]> | 2021-11-15 09:10:39 +0100 |
| commit | b258c117aba04c6a672fb87d07d126449d961a73 (patch) | |
| tree | 174ccc6a674a173f417debd31a11d32348f042c6 | |
| parent | Fixed up FileSystemTranersal visitor to use std::fs::path (diff) | |
| parent | Updated cache policy according to UE. (diff) | |
| download | zen-b258c117aba04c6a672fb87d07d126449d961a73.tar.xz zen-b258c117aba04c6a672fb87d07d126449d961a73.zip | |
Merged main
| -rw-r--r-- | zencore/include/zencore/iobuffer.h | 1 | ||||
| -rw-r--r-- | zenserver-test/zenserver-test.cpp | 366 | ||||
| -rw-r--r-- | zenserver/cache/structuredcache.cpp | 622 | ||||
| -rw-r--r-- | zenserver/cache/structuredcache.h | 5 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamapply.cpp | 9 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamapply.h | 3 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.cpp | 343 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.h | 63 | ||||
| -rw-r--r-- | zenserver/upstream/zen.cpp | 29 | ||||
| -rw-r--r-- | zenserver/upstream/zen.h | 3 | ||||
| -rw-r--r-- | zenutil/cache/cachekey.cpp | 9 | ||||
| -rw-r--r-- | zenutil/cache/cachepolicy.cpp | 167 | ||||
| -rw-r--r-- | zenutil/include/zenutil/cache/cache.h | 6 | ||||
| -rw-r--r-- | zenutil/include/zenutil/cache/cachekey.h | 83 | ||||
| -rw-r--r-- | zenutil/include/zenutil/cache/cachepolicy.h | 109 | ||||
| -rw-r--r-- | zenutil/zenutil.vcxproj | 5 | ||||
| -rw-r--r-- | zenutil/zenutil.vcxproj.filters | 20 |
17 files changed, 1618 insertions, 225 deletions
diff --git a/zencore/include/zencore/iobuffer.h b/zencore/include/zencore/iobuffer.h index b1d13c58f..fee89a408 100644 --- a/zencore/include/zencore/iobuffer.h +++ b/zencore/include/zencore/iobuffer.h @@ -380,6 +380,7 @@ public: ZENCORE_API static IoBuffer MakeFromFileHandle(void* FileHandle, uint64_t Offset = 0, uint64_t Size = ~0ull); ZENCORE_API static IoBuffer ReadFromFileMaybe(IoBuffer& InBuffer); inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz) { return IoBuffer(IoBuffer::Clone, Ptr, Sz); } + inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return IoBuffer(IoBuffer::Clone, Memory.GetData(), Memory.GetSize()); } }; IoHash HashBuffer(IoBuffer& Buffer); diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp index ea234aaed..2fd8bdfcb 100644 --- a/zenserver-test/zenserver-test.cpp +++ b/zenserver-test/zenserver-test.cpp @@ -20,6 +20,7 @@ #include <zenhttp/httpclient.h> #include <zenhttp/httpshared.h> #include <zenhttp/zenhttp.h> +#include <zenutil/cache/cache.h> #include <zenutil/zenserverprocess.h> #if ZEN_USE_MIMALLOC @@ -1118,6 +1119,45 @@ TEST_CASE("project.pipe") } # endif +namespace utils { + + struct ZenConfig + { + std::filesystem::path DataDir; + uint16_t Port; + std::string BaseUri; + std::string Args; + + static ZenConfig New(uint16_t Port = 13337, std::string Args = "") + { + return ZenConfig{.DataDir = TestEnv.CreateNewTestDir(), + .Port = Port, + .BaseUri = "http://localhost:{}/z$"_format(Port), + .Args = std::move(Args)}; + } + + static ZenConfig NewWithUpstream(uint16_t UpstreamPort) + { + return New(13337, "--debug --upstream-thread-count=0 --upstream-zen-url=http://localhost:{}"_format(UpstreamPort)); + } + + void Spawn(ZenServerInstance& Inst) + { + Inst.SetTestDir(DataDir); + Inst.SpawnServer(Port, Args); + Inst.WaitUntilReady(); + } + }; + + void SpawnServer(ZenServerInstance& Server, ZenConfig& Cfg) + { + Server.SetTestDir(Cfg.DataDir); + Server.SpawnServer(Cfg.Port, Cfg.Args); + Server.WaitUntilReady(); + } + +} // namespace utils + TEST_CASE("zcache.basic") { using namespace std::literals; @@ -1424,34 +1464,7 @@ TEST_CASE("zcache.cbpackage") TEST_CASE("zcache.policy") { using namespace std::literals; - - struct ZenConfig - { - std::filesystem::path DataDir; - uint16_t Port; - std::string BaseUri; - std::string Args; - - static ZenConfig New(uint16_t Port = 13337, std::string Args = "") - { - return ZenConfig{.DataDir = TestEnv.CreateNewTestDir(), - .Port = Port, - .BaseUri = "http://localhost:{}/z$"_format(Port), - .Args = std::move(Args)}; - } - - static ZenConfig NewWithUpstream(uint16_t UpstreamPort) - { - return New(13337, "--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}"_format(UpstreamPort)); - } - - void Spawn(ZenServerInstance& Inst) - { - Inst.SetTestDir(DataDir); - Inst.SpawnServer(Port, Args); - Inst.WaitUntilReady(); - } - }; + using namespace utils; auto GenerateData = [](uint64_t Size, zen::IoHash& OutHash) -> zen::UniqueBuffer { auto Buf = zen::UniqueBuffer::Alloc(Size); @@ -1901,6 +1914,303 @@ TEST_CASE("zcache.policy") } } +TEST_CASE("zcache.rpc") +{ + using namespace std::literals; + + auto CreateCacheRecord = [](const zen::CacheKey& CacheKey, size_t PayloadSize) -> zen::CbPackage { + std::vector<uint8_t> Data; + Data.resize(PayloadSize); + for (size_t Idx = 0; Idx < PayloadSize; ++Idx) + { + Data[Idx] = Idx % 255; + } + + zen::CbAttachment Attachment(zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()))); + + zen::CbObjectWriter CacheRecord; + CacheRecord.BeginObject("CacheKey"sv); + CacheRecord << "Bucket"sv << CacheKey.Bucket << "Hash"sv << CacheKey.Hash; + CacheRecord.EndObject(); + CacheRecord << "Data"sv << Attachment; + + zen::CbPackage Package; + Package.SetObject(CacheRecord.Save()); + Package.AddAttachment(Attachment); + + return Package; + }; + + auto ToIoBuffer = [](zen::CbPackage Package) -> zen::IoBuffer { + zen::BinaryWriter MemStream; + Package.Save(MemStream); + return zen::IoBuffer(zen::IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + }; + + auto PutCacheRecords = [&CreateCacheRecord, &ToIoBuffer](std::string_view BaseUri, + std::string_view Query, + std::string_view Bucket, + size_t Num, + size_t PayloadSize = 1024) -> std::vector<CacheKey> { + std::vector<zen::CacheKey> OutKeys; + + for (uint32_t Key = 1; Key <= Num; ++Key) + { + const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, zen::IoHash::HashBuffer(&Key, sizeof uint32_t)); + CbPackage CacheRecord = CreateCacheRecord(CacheKey, PayloadSize); + + OutKeys.push_back(CacheKey); + + IoBuffer Payload = ToIoBuffer(CacheRecord); + + cpr::Response Result = cpr::Put(cpr::Url{"{}/{}/{}{}"_format(BaseUri, CacheKey.Bucket, CacheKey.Hash, Query)}, + cpr::Body{(const char*)Payload.Data(), Payload.Size()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + + CHECK(Result.status_code == 201); + } + + return OutKeys; + }; + + struct GetCacheRecordResult + { + zen::CbPackage Response; + std::vector<zen::CbFieldView> Records; + bool Success; + }; + + auto GetCacheRecords = + [](std::string_view BaseUri, std::span<zen::CacheKey> Keys, const zen::CacheRecordPolicy& Policy) -> GetCacheRecordResult { + using namespace zen; + + CbObjectWriter Request; + Request << "Method"sv + << "GetCacheRecords"sv; + Request.BeginObject("Params"sv); + + Request.BeginArray("CacheKeys"sv); + for (const CacheKey& Key : Keys) + { + Request.BeginObject(); + Request << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash; + Request.EndObject(); + } + Request.EndArray(); + + Request.BeginObject("Policy"); + CacheRecordPolicy::Save(Policy, Request); + Request.EndObject(); + + Request.EndObject(); + + BinaryWriter Body; + Request.Save(Body); + + cpr::Response Result = cpr::Post(cpr::Url{"{}/$rpc"_format(BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + + GetCacheRecordResult OutResult; + + if (Result.status_code == 200) + { + CbPackage Response; + if (Response.TryLoad(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()))) + { + OutResult.Response = std::move(Response); + CbObjectView ResponseObject = OutResult.Response.GetObject(); + + for (CbFieldView RecordView : ResponseObject["Result"]) + { + ExtendableStringBuilder<256> Tmp; + auto JSON = RecordView.AsObjectView().ToJson(Tmp).ToView(); + OutResult.Records.push_back(RecordView); + } + + OutResult.Success = true; + } + } + + return OutResult; + }; + + auto LoadKey = [](zen::CbFieldView KeyView) -> zen::CacheKey { + if (zen::CbObjectView KeyObj = KeyView.AsObjectView()) + { + return CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash()); + } + return CacheKey::Empty; + }; + + SUBCASE("get cache records") + { + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = "http://localhost:{}/z$"_format(PortNumber); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber); + Inst.WaitUntilReady(); + + CacheRecordPolicy Policy; + std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, ""sv, "mastodon"sv, 128); + GetCacheRecordResult Result = GetCacheRecords(BaseUri, Keys, Policy); + + CHECK(Result.Records.size() == Keys.size()); + + for (size_t Index = 0; CbFieldView RecordView : Result.Records) + { + const CacheKey& ExpectedKey = Keys[Index++]; + + CbObjectView RecordObj = RecordView.AsObjectView(); + CbObjectView KeyObj = RecordObj["CacheKey"sv].AsObjectView(); + const CacheKey Key = CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash()); + const IoHash AttachmentHash = RecordObj["Data"sv].AsHash(); + const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash); + + CHECK(Key == ExpectedKey); + CHECK(Attachment != nullptr); + } + } + + SUBCASE("get missing cache records") + { + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = "http://localhost:{}/z$"_format(PortNumber); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber); + Inst.WaitUntilReady(); + + CacheRecordPolicy Policy; + std::vector<zen::CacheKey> ExistingKeys = PutCacheRecords(BaseUri, ""sv, "mastodon"sv, 128); + std::vector<zen::CacheKey> Keys; + + for (const zen::CacheKey& Key : ExistingKeys) + { + Keys.push_back(Key); + Keys.push_back(CacheKey::Create("missing"sv, IoHash::Zero)); + } + + GetCacheRecordResult Result = GetCacheRecords(BaseUri, Keys, Policy); + + CHECK(Result.Records.size() == Keys.size()); + + size_t KeyIndex = 0; + for (size_t Index = 0; CbFieldView RecordView : Result.Records) + { + const bool Missing = Index++ % 2 != 0; + + if (Missing) + { + CHECK(RecordView.IsNull()); + } + else + { + const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++]; + CbObjectView RecordObj = RecordView.AsObjectView(); + CbObjectView KeyObj = RecordObj["CacheKey"sv].AsObjectView(); + zen::CacheKey Key = LoadKey(RecordObj["CacheKey"sv]); + const IoHash AttachmentHash = RecordObj["Data"sv].AsHash(); + const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash); + + CHECK(Key == ExpectedKey); + CHECK(Attachment != nullptr); + } + } + } + + SUBCASE("policy - 'SkipAttachments' does not return any record attachments") + { + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = "http://localhost:{}/z$"_format(PortNumber); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber); + Inst.WaitUntilReady(); + + CacheRecordPolicy Policy(CachePolicy::QueryLocal | CachePolicy::SkipAttachments); + std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, ""sv, "mastodon"sv, 4); + GetCacheRecordResult Result = GetCacheRecords(BaseUri, Keys, Policy); + + CHECK(Result.Records.size() == Keys.size()); + + std::span<const zen::CbAttachment> Attachments = Result.Response.GetAttachments(); + CHECK(Attachments.empty()); + + for (size_t Index = 0; CbFieldView RecordView : Result.Records) + { + const CacheKey& ExpectedKey = Keys[Index++]; + + CbObjectView RecordObj = RecordView.AsObjectView(); + CbObjectView KeyObj = RecordObj["CacheKey"sv].AsObjectView(); + const CacheKey Key = CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash()); + const IoHash AttachmentHash = RecordObj["Data"sv].AsHash(); + + CHECK(Key == ExpectedKey); + } + } + + SUBCASE("policy - 'QueryLocal' does not query upstream") + { + using namespace utils; + + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamServer(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalServer(TestEnv); + + SpawnServer(UpstreamServer, UpstreamCfg); + SpawnServer(LocalServer, LocalCfg); + + std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, ""sv, "mastodon"sv, 4); + + CacheRecordPolicy Policy(CachePolicy::QueryLocal); + GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, Keys, Policy); + + CHECK(Result.Records.size() == Keys.size()); + + for (CbFieldView RecordView : Result.Records) + { + CHECK(RecordView.IsNull()); + } + } + + SUBCASE("policy - 'QueryRemote' does query upstream") + { + using namespace utils; + + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamServer(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalServer(TestEnv); + + SpawnServer(UpstreamServer, UpstreamCfg); + SpawnServer(LocalServer, LocalCfg); + + std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, ""sv, "mastodon"sv, 4); + + CacheRecordPolicy Policy(CachePolicy::QueryLocal | CachePolicy::QueryRemote); + GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, Keys, Policy); + + CHECK(Result.Records.size() == Keys.size()); + + for (size_t Index = 0; CbFieldView RecordView : Result.Response.GetObject()["Result"sv]) + { + const zen::CacheKey& ExpectedKey = Keys[Index++]; + CbObjectView RecordObj = RecordView.AsObjectView(); + zen::CacheKey Key = LoadKey(RecordObj["CacheKey"sv]); + CHECK(Key == ExpectedKey); + } + } +} + struct RemoteExecutionRequest { RemoteExecutionRequest(std::string_view Host, int Port, std::filesystem::path& TreePath) diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index 4db8baa32..e1edfd161 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -11,7 +11,9 @@ #include <zencore/timer.h> #include <zenhttp/httpserver.h> #include <zenstore/cas.h> +#include <zenutil/cache/cache.h> +//#include "cachekey.h" #include "monitoring/httpstats.h" #include "structuredcache.h" #include "structuredcachestore.h" @@ -36,115 +38,24 @@ using namespace std::literals; ////////////////////////////////////////////////////////////////////////// -namespace detail { namespace cacheopt { - constexpr std::string_view Local = "local"sv; - constexpr std::string_view Remote = "remote"sv; - constexpr std::string_view Data = "data"sv; - constexpr std::string_view Meta = "meta"sv; - constexpr std::string_view Value = "value"sv; - constexpr std::string_view Attachments = "attachments"sv; -}} // namespace detail::cacheopt - -////////////////////////////////////////////////////////////////////////// - -enum class CachePolicy : uint8_t -{ - None = 0, - QueryLocal = 1 << 0, - QueryRemote = 1 << 1, - Query = QueryLocal | QueryRemote, - StoreLocal = 1 << 2, - StoreRemote = 1 << 3, - Store = StoreLocal | StoreRemote, - SkipMeta = 1 << 4, - SkipValue = 1 << 5, - SkipAttachments = 1 << 6, - SkipData = SkipMeta | SkipValue | SkipAttachments, - SkipLocalCopy = 1 << 7, - Local = QueryLocal | StoreLocal, - Remote = QueryRemote | StoreRemote, - Default = Query | Store, - Disable = None, -}; - -gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy); - CachePolicy ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams) { - CachePolicy QueryPolicy = CachePolicy::Query; - - { - std::string_view Opts = QueryParams.GetValue("query"sv); - if (!Opts.empty()) - { - QueryPolicy = CachePolicy::None; - ForEachStrTok(Opts, ',', [&QueryPolicy](const std::string_view& Opt) { - if (Opt == detail::cacheopt::Local) - { - QueryPolicy |= CachePolicy::QueryLocal; - } - if (Opt == detail::cacheopt::Remote) - { - QueryPolicy |= CachePolicy::QueryRemote; - } - return true; - }); - } - } - - CachePolicy StorePolicy = CachePolicy::Store; - - { - std::string_view Opts = QueryParams.GetValue("store"sv); - if (!Opts.empty()) - { - StorePolicy = CachePolicy::None; - ForEachStrTok(Opts, ',', [&StorePolicy](const std::string_view& Opt) { - if (Opt == detail::cacheopt::Local) - { - StorePolicy |= CachePolicy::StoreLocal; - } - if (Opt == detail::cacheopt::Remote) - { - StorePolicy |= CachePolicy::StoreRemote; - } - return true; - }); - } - } - - CachePolicy SkipPolicy = CachePolicy::None; - - { - std::string_view Opts = QueryParams.GetValue("skip"sv); - if (!Opts.empty()) - { - ForEachStrTok(Opts, ',', [&SkipPolicy](const std::string_view& Opt) { - if (Opt == detail::cacheopt::Meta) - { - SkipPolicy |= CachePolicy::SkipMeta; - } - if (Opt == detail::cacheopt::Value) - { - SkipPolicy |= CachePolicy::SkipValue; - } - if (Opt == detail::cacheopt::Attachments) - { - SkipPolicy |= CachePolicy::SkipAttachments; - } - if (Opt == detail::cacheopt::Data) - { - SkipPolicy |= CachePolicy::SkipData; - } - return true; - }); - } - } + const CachePolicy QueryPolicy = zen::ParseQueryCachePolicy(QueryParams.GetValue("query"sv)); + const CachePolicy StorePolicy = zen::ParseStoreCachePolicy(QueryParams.GetValue("store"sv)); + const CachePolicy SkipPolicy = zen::ParseSkipCachePolicy(QueryParams.GetValue("skip"sv)); return QueryPolicy | StorePolicy | SkipPolicy; } +struct AttachmentCount +{ + uint32_t New = 0; + uint32_t Valid = 0; + uint32_t Invalid = 0; + uint32_t Total = 0; +}; + ////////////////////////////////////////////////////////////////////////// HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCacheStore, @@ -210,6 +121,11 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) { std::string_view Key = Request.RelativeUri(); + if (Key == "$rpc") + { + return HandleRpcRequest(Request); + } + if (std::all_of(begin(Key), end(Key), [](const char c) { return std::isalnum(c); })) { // Bucket reference @@ -325,8 +241,8 @@ HttpStructuredCacheService::HandleGetCacheRecord(zen::HttpServerRequest& Request if (ValidCount != AttachmentCount) { - Success = false; - ZEN_WARN("GET - '{}/{}' '{}' FAILED, found '{}' of '{}' attachments", + // Success = false; + ZEN_WARN("GET - '{}/{}' '{}' is partial, found '{}' of '{}' attachments", Ref.BucketSegment, Ref.HashKey, ToString(AcceptType), @@ -549,52 +465,39 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request CbObjectView CacheRecord(Body.Data()); std::vector<IoHash> ValidAttachments; - uint32_t AttachmentCount = 0; + int32_t TotalCount = 0; - CacheRecord.IterateAttachments([this, &AttachmentCount, &ValidAttachments](CbFieldView AttachmentHash) { + CacheRecord.IterateAttachments([this, &TotalCount, &ValidAttachments](CbFieldView AttachmentHash) { const IoHash Hash = AttachmentHash.AsHash(); if (m_CidStore.ContainsChunk(Hash)) { ValidAttachments.emplace_back(Hash); } - AttachmentCount++; + TotalCount++; }); - const uint32_t ValidCount = static_cast<uint32_t>(ValidAttachments.size()); - const bool ValidCacheRecord = ValidCount == AttachmentCount; - - if (ValidCacheRecord) - { - ZEN_DEBUG("PUT - '{}/{}' {} '{}', {} attachments", - Ref.BucketSegment, - Ref.HashKey, - NiceBytes(Body.Size()), - ToString(ContentType), - ValidCount); + ZEN_DEBUG("PUT - '{}/{}' {} '{}' attachments '{}/{}' (Valid/Total)", + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.Size()), + ToString(ContentType), + TotalCount, + ValidAttachments.size()); - m_CacheStore.Put(Ref.BucketSegment, Ref.HashKey, {.Value = Body}); + Body.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Ref.BucketSegment, Ref.HashKey, {.Value = Body}); - if (StoreUpstream) - { - ZEN_ASSERT(m_UpstreamCache); - m_UpstreamCache->EnqueueUpstream({.Type = ZenContentType::kCbObject, - .CacheKey = {Ref.BucketSegment, Ref.HashKey}, - .PayloadIds = std::move(ValidAttachments)}); - } + const bool IsPartialRecord = TotalCount != static_cast<int32_t>(ValidAttachments.size()); - Request.WriteResponse(HttpResponseCode::Created); - } - else + if (StoreUpstream && !IsPartialRecord) { - ZEN_WARN("PUT - '{}/{}' '{}' FAILED, found {}/{} attachments", - Ref.BucketSegment, - Ref.HashKey, - ToString(ContentType), - ValidCount, - AttachmentCount); - - Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Missing attachments"sv); + ZEN_ASSERT(m_UpstreamCache); + m_UpstreamCache->EnqueueUpstream({.Type = ZenContentType::kCbObject, + .CacheKey = {Ref.BucketSegment, Ref.HashKey}, + .PayloadIds = std::move(ValidAttachments)}); } + + Request.WriteResponse(HttpResponseCode::Created); } else if (ContentType == HttpContentType::kCbPackage) { @@ -606,16 +509,15 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"sv); } - CbObject CacheRecord = Package.GetObject(); - - std::span<const CbAttachment> Attachments = Package.GetAttachments(); - std::vector<IoHash> ValidAttachments; - int32_t NewAttachmentCount = 0; + CbObject CacheRecord = Package.GetObject(); + AttachmentCount Count; + std::vector<IoHash> ValidAttachments; - ValidAttachments.reserve(Attachments.size()); + ValidAttachments.reserve(Package.GetAttachments().size()); - CacheRecord.IterateAttachments([this, &Ref, &Package, &ValidAttachments, &NewAttachmentCount](CbFieldView AttachmentHash) { - if (const CbAttachment* Attachment = Package.FindAttachment(AttachmentHash.AsHash())) + CacheRecord.IterateAttachments([this, &Ref, &Package, &ValidAttachments, &Count](CbFieldView HashView) { + const IoHash Hash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(Hash)) { if (Attachment->IsCompressedBinary()) { @@ -626,8 +528,9 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request if (InsertResult.New) { - NewAttachmentCount++; + Count.New++; } + Count.Valid++; } else { @@ -635,40 +538,40 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request Ref.BucketSegment, Ref.HashKey, ToString(HttpContentType::kCbPackage), - AttachmentHash.AsHash()); + Hash); + Count.Invalid++; } } - else + else if (m_CidStore.ContainsChunk(Hash)) { - ZEN_WARN("PUT - '{}/{}' '{}' FAILED, missing attachment '{}'", - Ref.BucketSegment, - Ref.HashKey, - ToString(HttpContentType::kCbPackage), - AttachmentHash.AsHash()); + ValidAttachments.emplace_back(Hash); + Count.Valid++; } + Count.Total++; }); - const bool AttachmentsValid = ValidAttachments.size() == Attachments.size(); - - if (!AttachmentsValid) + if (Count.Invalid > 0) { - return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid attachments"sv); + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid attachment(s)"sv); } - ZEN_DEBUG("PUT - '{}/{}' {} '{}', {}/{} new attachments", + ZEN_DEBUG("PUT - '{}/{}' {} '{}', attachments '{}/{}/{}' (New/Valid/Total)", Ref.BucketSegment, Ref.HashKey, NiceBytes(Body.GetSize()), ToString(ContentType), - NewAttachmentCount, - Attachments.size()); + Count.New, + Count.Valid, + Count.Total); IoBuffer CacheRecordValue = CacheRecord.GetBuffer().AsIoBuffer(); CacheRecordValue.SetContentType(ZenContentType::kCbObject); m_CacheStore.Put(Ref.BucketSegment, Ref.HashKey, {.Value = CacheRecord.GetBuffer().AsIoBuffer()}); - if (StoreUpstream) + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (StoreUpstream && !IsPartialRecord) { ZEN_ASSERT(m_UpstreamCache); m_UpstreamCache->EnqueueUpstream({.Type = ZenContentType::kCbPackage, @@ -714,8 +617,7 @@ HttpStructuredCacheService::HandleGetCachePayload(zen::HttpServerRequest& Reques if (QueryUpstream) { - if (auto UpstreamResult = m_UpstreamCache->GetCachePayload({{Ref.BucketSegment, Ref.HashKey}, Ref.PayloadId}); - UpstreamResult.Success) + if (auto UpstreamResult = m_UpstreamCache->GetCachePayload({Ref.BucketSegment, Ref.HashKey}, Ref.PayloadId); UpstreamResult.Success) { if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value))) { @@ -876,6 +778,400 @@ HttpStructuredCacheService::ValidateKeyUri(HttpServerRequest& Request, CacheRef& } void +HttpStructuredCacheService::HandleRpcRequest(zen::HttpServerRequest& Request) +{ + switch (auto Verb = Request.RequestVerb()) + { + using enum HttpVerb; + + case kPost: + { + const HttpContentType ContentType = Request.RequestContentType(); + const HttpContentType AcceptType = Request.AcceptContentType(); + + if (ContentType != HttpContentType::kCbObject || AcceptType != HttpContentType::kCbPackage) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + Request.WriteResponseAsync( + [this, RpcRequest = zen::LoadCompactBinaryObject(Request.ReadPayload())](HttpServerRequest& AsyncRequest) { + const std::string_view Method = RpcRequest["Method"sv].AsString(); + if (Method == "GetCacheRecords"sv) + { + HandleRpcGetCacheRecords(AsyncRequest, RpcRequest); + } + else if (Method == "GetCachePayloads"sv) + { + HandleRpcGetCachePayloads(AsyncRequest, RpcRequest); + } + else + { + AsyncRequest.WriteResponse(HttpResponseCode::BadRequest); + } + }); + } + break; + default: + Request.WriteResponse(HttpResponseCode::BadRequest); + break; + } +} + +void +HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView RpcRequest) +{ + using namespace fmt::literals; + + CbPackage RpcResponse; + CacheRecordPolicy Policy; + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::vector<CacheKey> CacheKeys; + std::vector<IoBuffer> CacheValues; + std::vector<size_t> UpstreamRequests; + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv); + + CacheRecordPolicy::Load(Params["Policy"sv].AsObjectView(), Policy); + + const bool SkipAttachments = (Policy.GetRecordPolicy() & CachePolicy::SkipAttachments) == CachePolicy::SkipAttachments; + const bool QueryRemote = m_UpstreamCache && ((Policy.GetRecordPolicy() & CachePolicy::QueryRemote) == CachePolicy::QueryRemote); + + for (CbFieldView KeyView : Params["CacheKeys"sv]) + { + CbObjectView KeyObject = KeyView.AsObjectView(); + CacheKeys.push_back(CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash())); + } + + if (CacheKeys.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + CacheValues.resize(CacheKeys.size()); + + for (size_t KeyIndex = 0; const CacheKey& Key : CacheKeys) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue)) + { + CbObjectView CacheRecord(CacheValue.Value.Data()); + + if (!SkipAttachments) + { + CacheRecord.IterateAttachments([this, &RpcResponse](CbFieldView AttachmentHash) { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(AttachmentHash.AsHash())) + { + RpcResponse.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(Chunk)))); + } + }); + } + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL)", + Key.Bucket, + Key.Hash, + NiceBytes(CacheValue.Value.Size()), + ToString(CacheValue.Value.GetContentType())); + + CacheValues[KeyIndex] = CacheValue.Value; + m_CacheStats.HitCount++; + } + else if (QueryRemote) + { + UpstreamRequests.push_back(KeyIndex); + } + else + { + ZEN_DEBUG("MISS - '{}/{}'", Key.Bucket, Key.Hash); + m_CacheStats.MissCount++; + } + + ++KeyIndex; + } + + if (!UpstreamRequests.empty() && m_UpstreamCache) + { + const auto OnCacheRecordGetComplete = + [this, &CacheKeys, &CacheValues, &RpcResponse, SkipAttachments](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + AttachmentCount Count; + Params.Record.IterateAttachments([this, &RpcResponse, SkipAttachments, &Params, &Count](CbFieldView HashView) { + if (const CbAttachment* Attachment = Params.Package.FindAttachment(HashView.AsHash())) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + auto InsertResult = m_CidStore.AddChunk(Compressed); + if (InsertResult.New) + { + Count.New++; + } + Count.Valid++; + + if (!SkipAttachments) + { + RpcResponse.AddAttachment(CbAttachment(Compressed)); + } + } + else + { + ZEN_DEBUG("Uncompressed payload '{}' from upstream cache record '{}/{}'", + HashView.AsHash(), + Params.CacheKey.Bucket, + Params.CacheKey.Hash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(HashView.AsHash())) + { + Count.Valid++; + } + Count.Total++; + }); + + ZEN_DEBUG("HIT - '{}/{}' {} '{}' attachments '{}/{}/{}' (New/Valid/Total) (UPSTREAM)", + Params.CacheKey.Bucket, + Params.CacheKey.Hash, + NiceBytes(Params.Record.GetView().GetSize()), + ToString(HttpContentType::kCbPackage), + Count.New, + Count.Valid, + Count.Total); + + ZEN_ASSERT(Params.KeyIndex < CacheValues.size()); + + IoBuffer CacheValue = CbObject::Clone(Params.Record).GetBuffer().AsIoBuffer(); + CacheValue.SetContentType(ZenContentType::kCbObject); + + CacheValues[Params.KeyIndex] = CacheValue; + m_CacheStore.Put(Params.CacheKey.Bucket, Params.CacheKey.Hash, {.Value = CacheValue}); + + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + } + else + { + ZEN_DEBUG("MISS - '{}/{}'", Params.CacheKey.Bucket, Params.CacheKey.Hash); + m_CacheStats.MissCount++; + } + }; + + m_UpstreamCache->GetCacheRecords(CacheKeys, UpstreamRequests, Policy, std::move(OnCacheRecordGetComplete)); + } + + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"sv); + for (const IoBuffer& Value : CacheValues) + { + if (Value) + { + CbObjectView Record(Value.Data()); + ResponseObject << Record; + } + else + { + ResponseObject.AddNull(); + } + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + + BinaryWriter MemStream; + RpcResponse.Save(MemStream); + + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); +} + +void +HttpStructuredCacheService::HandleRpcGetCachePayloads(zen::HttpServerRequest& Request, CbObjectView RpcRequest) +{ + using namespace fmt::literals; + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCachePayloads"sv); + + std::vector<CacheChunkRequest> ChunkRequests; + std::vector<size_t> UpstreamRequests; + std::vector<IoBuffer> Chunks; + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + + for (CbFieldView RequestView : Params["ChunkRequests"sv]) + { + CbObjectView RequestObject = RequestView.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + const CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); + const IoHash ChunkId = RequestObject["ChunkId"sv].AsHash(); + const Oid PayloadId = RequestObject["PayloadId"sv].AsObjectId(); + const uint64_t RawOffset = RequestObject["RawOffset"sv].AsUInt64(); + const uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64(); + const uint32_t ChunkPolicy = RequestObject["Policy"sv].AsUInt32(); + + ChunkRequests.emplace_back(Key, ChunkId, PayloadId, RawOffset, RawSize, static_cast<CachePolicy>(ChunkPolicy)); + } + + if (ChunkRequests.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + Chunks.resize(ChunkRequests.size()); + + // Unreal uses a 12 byte ID to address cache record payloads. When the uncompressed hash (ChunkId) + // is missing, load the cache record and try to find the raw hash from the payload ID. + { + const auto GetChunkIdFromPayloadId = [](CbObjectView Record, const Oid& PayloadId) -> IoHash { + if (CbObjectView ValueObject = Record["Value"sv].AsObjectView()) + { + const Oid Id = ValueObject["Id"sv].AsObjectId(); + if (Id == PayloadId) + { + return ValueObject["RawHash"sv].AsHash(); + } + } + + for (CbFieldView AttachmentView : Record["Attachments"sv]) + { + CbObjectView AttachmentObject = AttachmentView.AsObjectView(); + const Oid Id = AttachmentObject["Id"sv].AsObjectId(); + + if (Id == PayloadId) + { + return AttachmentObject["RawHash"sv].AsHash(); + } + } + + return IoHash::Zero; + }; + + CacheKey CurrentKey = CacheKey::Empty; + IoBuffer CurrentRecordBuffer; + + for (CacheChunkRequest& ChunkRequest : ChunkRequests) + { + if (ChunkRequest.ChunkId != IoHash::Zero) + { + continue; + } + + if (ChunkRequest.Key != CurrentKey) + { + CurrentKey = ChunkRequest.Key; + + ZenCacheValue CacheValue; + if (m_CacheStore.Get(CurrentKey.Bucket, CurrentKey.Hash, CacheValue)) + { + CurrentRecordBuffer = CacheValue.Value; + } + } + + if (CurrentRecordBuffer) + { + ChunkRequest.ChunkId = GetChunkIdFromPayloadId(CbObjectView(CurrentRecordBuffer.GetData()), ChunkRequest.PayloadId); + } + } + } + + for (size_t RequestIndex = 0; const CacheChunkRequest& ChunkRequest : ChunkRequests) + { + const bool QueryLocal = (ChunkRequest.Policy & CachePolicy::QueryLocal) == CachePolicy::QueryLocal; + const bool QueryRemote = (ChunkRequest.Policy & CachePolicy::QueryRemote) == CachePolicy::QueryRemote; + + if (QueryLocal) + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkRequest.ChunkId)) + { + ZEN_DEBUG("HIT - '{}/{}/{}' {} '{}' ({})", + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + ChunkRequest.ChunkId, + NiceBytes(Chunk.Size()), + ToString(Chunk.GetContentType()), + "LOCAL"); + + Chunks[RequestIndex] = Chunk; + m_CacheStats.HitCount++; + } + else if (QueryRemote) + { + UpstreamRequests.push_back(RequestIndex); + } + else + { + ZEN_DEBUG("MISS - '{}/{}/{}'", ChunkRequest.Key.Bucket, ChunkRequest.Key.Hash, ChunkRequest.ChunkId); + m_CacheStats.MissCount++; + } + } + else + { + ZEN_DEBUG("SKIP - '{}/{}/{}'", ChunkRequest.Key.Bucket, ChunkRequest.Key.Hash, ChunkRequest.ChunkId); + } + + ++RequestIndex; + } + + if (!UpstreamRequests.empty() && m_UpstreamCache) + { + const auto OnCachePayloadGetComplete = [this, &ChunkRequests, &Chunks](CachePayloadGetCompleteParams&& Params) { + if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Params.Payload))) + { + auto InsertResult = m_CidStore.AddChunk(Compressed); + + ZEN_DEBUG("HIT - '{}/{}/{}' {} ({})", + Params.Request.Key.Bucket, + Params.Request.Key.Hash, + Params.Request.ChunkId, + NiceBytes(Params.Payload.GetSize()), + "UPSTREAM"); + + ZEN_ASSERT(Params.RequestIndex < Chunks.size()); + Chunks[Params.RequestIndex] = std::move(Params.Payload); + + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + } + else + { + ZEN_DEBUG("MISS - '{}/{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash, Params.Request.ChunkId); + m_CacheStats.MissCount++; + } + }; + + m_UpstreamCache->GetCachePayloads(ChunkRequests, UpstreamRequests, std::move(OnCachePayloadGetComplete)); + } + + CbPackage RpcResponse; + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"sv); + + for (size_t ChunkIndex = 0; ChunkIndex < Chunks.size(); ++ChunkIndex) + { + if (Chunks[ChunkIndex]) + { + ResponseObject << ChunkRequests[ChunkIndex].ChunkId; + RpcResponse.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunks[ChunkIndex]))))); + } + else + { + ResponseObject << IoHash::Zero; + } + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + + BinaryWriter MemStream; + RpcResponse.Save(MemStream); + + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); +} + +void HttpStructuredCacheService::HandleStatsRequest(zen::HttpServerRequest& Request) { CbObjectWriter Cbo; diff --git a/zenserver/cache/structuredcache.h b/zenserver/cache/structuredcache.h index ad7253f79..51073d05d 100644 --- a/zenserver/cache/structuredcache.h +++ b/zenserver/cache/structuredcache.h @@ -20,7 +20,7 @@ class CasStore; class CidStore; class UpstreamCache; class ZenCacheStore; -enum class CachePolicy : uint8_t; +enum class CachePolicy : uint32_t; /** * Structured cache service. Imposes constraints on keys, supports blobs and @@ -89,6 +89,9 @@ private: void HandleCachePayloadRequest(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy Policy); void HandleGetCachePayload(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy Policy); void HandlePutCachePayload(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy Policy); + void HandleRpcRequest(zen::HttpServerRequest& Request); + void HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView BatchRequest); + void HandleRpcGetCachePayloads(zen::HttpServerRequest& Request, CbObjectView BatchRequest); void HandleCacheBucketRequest(zen::HttpServerRequest& Request, std::string_view Bucket); virtual void HandleStatsRequest(zen::HttpServerRequest& Request) override; virtual void HandleStatusRequest(zen::HttpServerRequest& Request) override; diff --git a/zenserver/upstream/upstreamapply.cpp b/zenserver/upstream/upstreamapply.cpp index 19d02f753..3c67779c4 100644 --- a/zenserver/upstream/upstreamapply.cpp +++ b/zenserver/upstream/upstreamapply.cpp @@ -1157,8 +1157,9 @@ public: { if (m_RunState.IsRunning) { - const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); - const IoHash ActionId = ApplyRecord.Action.GetHash(); + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300); { std::scoped_lock Lock(m_ApplyTasksMutex); @@ -1169,8 +1170,8 @@ public: } std::chrono::steady_clock::time_point ExpireTime = - ApplyRecord.ExpireSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(ApplyRecord.ExpireSeconds) - : std::chrono::steady_clock::time_point::max(); + TimeoutSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds) + : std::chrono::steady_clock::time_point::max(); m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)}; } diff --git a/zenserver/upstream/upstreamapply.h b/zenserver/upstream/upstreamapply.h index e5f0e4faa..8196c3b40 100644 --- a/zenserver/upstream/upstreamapply.h +++ b/zenserver/upstream/upstreamapply.h @@ -36,7 +36,6 @@ struct UpstreamApplyRecord { CbObject WorkerDescriptor; CbObject Action; - uint32_t ExpireSeconds{}; }; struct UpstreamApplyOptions @@ -94,7 +93,7 @@ struct UpstreamApplyStatus std::chrono::steady_clock::time_point ExpireTime{}; }; -using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>; +using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>; struct UpstreamEndpointHealth { diff --git a/zenserver/upstream/upstreamcache.cpp b/zenserver/upstream/upstreamcache.cpp index 00555f2ce..ade71c5d2 100644 --- a/zenserver/upstream/upstreamcache.cpp +++ b/zenserver/upstream/upstreamcache.cpp @@ -70,7 +70,7 @@ namespace detail { virtual std::string_view DisplayName() const override { return m_DisplayName; } - virtual GetUpstreamCacheResult GetCacheRecord(UpstreamCacheKey CacheKey, ZenContentType Type) override + virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) override { try { @@ -144,12 +144,69 @@ namespace detail { } } - virtual GetUpstreamCacheResult GetCachePayload(UpstreamPayloadKey PayloadKey) override + virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys, + std::span<size_t> KeyIndex, + const CacheRecordPolicy& Policy, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_UNUSED(Policy); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (size_t Index : KeyIndex) + { + const CacheKey& CacheKey = CacheKeys[Index]; + CbPackage Package; + CbObject Record; + + if (!Result.Error) + { + CloudCacheResult RefResult = Session.GetRef(CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + AppendResult(RefResult, Result); + + if (RefResult.ErrorCode == 0) + { + const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All); + if (ValidationResult == CbValidateError::None) + { + Record = LoadCompactBinaryObject(RefResult.Response); + Record.IterateAttachments([this, &Session, &Result, &Package](CbFieldView AttachmentHash) { + CloudCacheResult BlobResult = Session.GetCompressedBlob(AttachmentHash.AsHash()); + AppendResult(BlobResult, Result); + + if (BlobResult.ErrorCode == 0) + { + if (CompressedBuffer Chunk = CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response))) + { + Package.AddAttachment(CbAttachment(Chunk)); + } + } + else + { + m_HealthOk = false; + } + }); + } + } + else + { + m_HealthOk = false; + } + } + + OnComplete({.CacheKey = CacheKey, .KeyIndex = Index, .Record = Record, .Package = Package}); + } + + return Result; + } + + virtual GetUpstreamCacheResult GetCachePayload(const CacheKey&, const IoHash& PayloadId) override { try { CloudCacheSession Session(m_Client); - const CloudCacheResult Result = Session.GetCompressedBlob(PayloadKey.PayloadId); + const CloudCacheResult Result = Session.GetCompressedBlob(PayloadId); if (Result.ErrorCode == 0) { @@ -171,6 +228,33 @@ namespace detail { } } + virtual GetUpstreamCacheResult GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, + std::span<size_t> RequestIndex, + OnCachePayloadGetComplete&& OnComplete) override final + { + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (size_t Index : RequestIndex) + { + const CacheChunkRequest& Request = CacheChunkRequests[Index]; + IoBuffer Payload; + + if (!Result.Error) + { + const CloudCacheResult BlobResult = Session.GetCompressedBlob(Request.ChunkId); + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + m_HealthOk = BlobResult.ErrorCode == 0; + } + + OnComplete({.Request = Request, .RequestIndex = Index, .Payload = Payload}); + } + + return Result; + } + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, IoBuffer RecordValue, std::span<IoBuffer const> Payloads) override @@ -323,6 +407,18 @@ namespace detail { virtual UpstreamEndpointStats& Stats() override { return m_Stats; } private: + static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out) + { + Out.Success &= Result.Success; + Out.Bytes += Result.Bytes; + Out.ElapsedSeconds += Result.ElapsedSeconds; + + if (Result.ErrorCode) + { + Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}; + } + }; + spdlog::logger& Log() { return m_Log; } spdlog::logger& m_Log; @@ -419,7 +515,7 @@ namespace detail { virtual std::string_view DisplayName() const override { return m_DisplayName; } - virtual GetUpstreamCacheResult GetCacheRecord(UpstreamCacheKey CacheKey, ZenContentType Type) override + virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) override { try { @@ -446,13 +542,80 @@ namespace detail { } } - virtual GetUpstreamCacheResult GetCachePayload(UpstreamPayloadKey PayloadKey) override + virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys, + std::span<size_t> KeyIndex, + const CacheRecordPolicy& Policy, + OnCacheRecordGetComplete&& OnComplete) override + { + std::vector<size_t> IndexMap; + IndexMap.reserve(KeyIndex.size()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheRecords"; + + BatchRequest.BeginObject("Params"sv); + { + BatchRequest.BeginArray("CacheKeys"sv); + for (size_t Index : KeyIndex) + { + const CacheKey& Key = CacheKeys[Index]; + IndexMap.push_back(Index); + + BatchRequest.BeginObject(); + BatchRequest << "Bucket"sv << Key.Bucket; + BatchRequest << "Hash"sv << Key.Hash; + BatchRequest.EndObject(); + } + BatchRequest.EndArray(); + + BatchRequest.BeginObject("Policy"sv); + CacheRecordPolicy::Save(Policy, BatchRequest); + BatchRequest.EndObject(); + } + BatchRequest.EndObject(); + + CbPackage BatchResponse; + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(*m_Client); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + if (Result.Success) + { + if (BatchResponse.TryLoad(Result.Response)) + { + for (size_t LocalIndex = 0; CbFieldView Record : BatchResponse.GetObject()["Result"sv]) + { + const size_t Index = IndexMap[LocalIndex++]; + OnComplete( + {.CacheKey = CacheKeys[Index], .KeyIndex = Index, .Record = Record.AsObjectView(), .Package = BatchResponse}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else if (Result.ErrorCode) + { + m_HealthOk = false; + } + + for (size_t Index : KeyIndex) + { + OnComplete({.CacheKey = CacheKeys[Index], .KeyIndex = Index, .Record = CbObjectView(), .Package = CbPackage()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) override { try { ZenStructuredCacheSession Session(*m_Client); - const ZenCacheResult Result = - Session.GetCachePayload(PayloadKey.CacheKey.Bucket, PayloadKey.CacheKey.Hash, PayloadKey.PayloadId); + const ZenCacheResult Result = Session.GetCachePayload(CacheKey.Bucket, CacheKey.Hash, PayloadId); if (Result.ErrorCode == 0) { @@ -474,6 +637,90 @@ namespace detail { } } + virtual GetUpstreamCacheResult GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, + std::span<size_t> RequestIndex, + OnCachePayloadGetComplete&& OnComplete) override final + { + std::vector<size_t> IndexMap; + IndexMap.reserve(RequestIndex.size()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCachePayloads"; + + BatchRequest.BeginObject("Params"sv); + { + BatchRequest.BeginArray("ChunkRequests"sv); + { + for (size_t Index : RequestIndex) + { + const CacheChunkRequest& Request = CacheChunkRequests[Index]; + IndexMap.push_back(Index); + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + + BatchRequest.AddObjectId("PayloadId"sv, Request.PayloadId); + BatchRequest << "ChunkId"sv << Request.ChunkId; + BatchRequest << "RawOffset"sv << Request.RawOffset; + BatchRequest << "RawSize"sv << Request.RawSize; + BatchRequest << "Policy"sv << static_cast<uint32_t>(Request.Policy); + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + CbPackage BatchResponse; + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(*m_Client); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + if (Result.Success) + { + if (BatchResponse.TryLoad(Result.Response)) + { + for (size_t LocalIndex = 0; CbFieldView AttachmentHash : BatchResponse.GetObject()["Result"sv]) + { + const size_t Index = IndexMap[LocalIndex++]; + IoBuffer Payload; + + if (const CbAttachment* Attachment = BatchResponse.FindAttachment(AttachmentHash.AsHash())) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + } + } + + OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Payload = std::move(Payload)}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else if (Result.ErrorCode) + { + m_HealthOk = false; + } + + for (size_t Index : RequestIndex) + { + OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Payload = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, IoBuffer RecordValue, std::span<IoBuffer const> Payloads) override @@ -758,7 +1005,7 @@ public: virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override { m_Endpoints.emplace_back(std::move(Endpoint)); } - virtual GetUpstreamCacheResult GetCacheRecord(UpstreamCacheKey CacheKey, ZenContentType Type) override + virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) override { if (m_Options.ReadUpstream) { @@ -780,7 +1027,83 @@ public: return {}; } - virtual GetUpstreamCacheResult GetCachePayload(UpstreamPayloadKey PayloadKey) override + virtual void GetCacheRecords(std::span<CacheKey> CacheKeys, + std::span<size_t> KeyIndex, + const CacheRecordPolicy& Policy, + OnCacheRecordGetComplete&& OnComplete) override final + { + std::vector<size_t> MissingKeys(KeyIndex.begin(), KeyIndex.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy() && !MissingKeys.empty()) + { + std::vector<size_t> Missing; + + auto Result = Endpoint->GetCacheRecords(CacheKeys, MissingKeys, Policy, [&](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); + } + else + { + Missing.push_back(Params.KeyIndex); + } + }); + + m_Stats.Add(m_Log, *Endpoint, Result, m_Endpoints); + MissingKeys = std::move(Missing); + } + } + } + + for (size_t Index : MissingKeys) + { + OnComplete({.CacheKey = CacheKeys[Index], .KeyIndex = Index, .Record = CbObjectView(), .Package = CbPackage()}); + } + } + + virtual void GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, + std::span<size_t> RequestIndex, + OnCachePayloadGetComplete&& OnComplete) override final + { + std::vector<size_t> MissingPayloads(RequestIndex.begin(), RequestIndex.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy() && !MissingPayloads.empty()) + { + std::vector<size_t> Missing; + + auto Result = + Endpoint->GetCachePayloads(CacheChunkRequests, MissingPayloads, [&](CachePayloadGetCompleteParams&& Params) { + if (Params.Payload) + { + OnComplete(std::forward<CachePayloadGetCompleteParams>(Params)); + } + else + { + Missing.push_back(Params.RequestIndex); + } + }); + + m_Stats.Add(m_Log, *Endpoint, Result, m_Endpoints); + MissingPayloads = std::move(Missing); + } + } + } + + for (size_t Index : MissingPayloads) + { + OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Payload = IoBuffer()}); + } + } + + virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) override { if (m_Options.ReadUpstream) { @@ -788,7 +1111,7 @@ public: { if (Endpoint->IsHealthy()) { - const GetUpstreamCacheResult Result = Endpoint->GetCachePayload(PayloadKey); + const GetUpstreamCacheResult Result = Endpoint->GetCachePayload(CacheKey, PayloadId); m_Stats.Add(m_Log, *Endpoint, Result, m_Endpoints); if (Result.Success) diff --git a/zenserver/upstream/upstreamcache.h b/zenserver/upstream/upstreamcache.h index edc995da6..e5c3521b9 100644 --- a/zenserver/upstream/upstreamcache.h +++ b/zenserver/upstream/upstreamcache.h @@ -5,34 +5,26 @@ #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/zencore.h> +#include <zenutil/cache/cache.h> #include <atomic> #include <chrono> +#include <functional> #include <memory> namespace zen { +class CbObjectView; +class CbPackage; class CbObjectWriter; class CidStore; class ZenCacheStore; struct CloudCacheClientOptions; -struct UpstreamCacheKey -{ - std::string Bucket; - IoHash Hash; -}; - -struct UpstreamPayloadKey -{ - UpstreamCacheKey CacheKey; - IoHash PayloadId; -}; - struct UpstreamCacheRecord { ZenContentType Type = ZenContentType::kBinary; - UpstreamCacheKey CacheKey; + CacheKey CacheKey; std::vector<IoHash> PayloadIds; }; @@ -88,6 +80,25 @@ struct UpstreamEndpointStats std::atomic<double> SecondsDown{}; }; +struct CacheRecordGetCompleteParams +{ + const CacheKey& CacheKey; + size_t KeyIndex = ~size_t(0); + const CbObjectView& Record; + const CbPackage& Package; +}; + +using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>; + +struct CachePayloadGetCompleteParams +{ + const CacheChunkRequest& Request; + size_t RequestIndex{~size_t(0)}; + IoBuffer Payload; +}; + +using OnCachePayloadGetComplete = std::function<void(CachePayloadGetCompleteParams&&)>; + /** * The upstream endpont is responsible for handling upload/downloading of cache records. */ @@ -104,9 +115,18 @@ public: virtual std::string_view DisplayName() const = 0; - virtual GetUpstreamCacheResult GetCacheRecord(UpstreamCacheKey CacheKey, ZenContentType Type) = 0; + virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0; - virtual GetUpstreamCacheResult GetCachePayload(UpstreamPayloadKey PayloadKey) = 0; + virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys, + std::span<size_t> KeyIndex, + const CacheRecordPolicy& Policy, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) = 0; + + virtual GetUpstreamCacheResult GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, + std::span<size_t> RequestIndex, + OnCachePayloadGetComplete&& OnComplete) = 0; virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, IoBuffer RecordValue, @@ -127,9 +147,18 @@ public: virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0; - virtual GetUpstreamCacheResult GetCacheRecord(UpstreamCacheKey CacheKey, ZenContentType Type) = 0; + virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0; + + virtual void GetCacheRecords(std::span<CacheKey> CacheKeys, + std::span<size_t> KeyIndex, + const CacheRecordPolicy& RecordPolicy, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) = 0; - virtual GetUpstreamCacheResult GetCachePayload(UpstreamPayloadKey PayloadKey) = 0; + virtual void GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, + std::span<size_t> RequestIndex, + OnCachePayloadGetComplete&& OnComplete) = 0; struct EnqueueResult { diff --git a/zenserver/upstream/zen.cpp b/zenserver/upstream/zen.cpp index 14333f45a..9ba767098 100644 --- a/zenserver/upstream/zen.cpp +++ b/zenserver/upstream/zen.cpp @@ -499,4 +499,33 @@ ZenStructuredCacheSession::PutCachePayload(std::string_view BucketId, const IoHa .Success = (Response.status_code == 200 || Response.status_code == 201)}; } +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client.ServiceUrl() << "/z$/$rpc"; + + BinaryWriter Body; + Request.CopyTo(Body); + + cpr::Session& Session = m_SessionState->Session; + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), .Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + } // namespace zen diff --git a/zenserver/upstream/zen.h b/zenserver/upstream/zen.h index 8e81d1cb6..1fbfed7dd 100644 --- a/zenserver/upstream/zen.h +++ b/zenserver/upstream/zen.h @@ -28,6 +28,8 @@ class logger; namespace zen { class CbObjectWriter; +class CbObjectView; +class CbPackage; class ZenStructuredCacheClient; /** Zen mesh tracker @@ -116,6 +118,7 @@ public: ZenCacheResult GetCachePayload(std::string_view BucketId, const IoHash& Key, const IoHash& PayloadId); ZenCacheResult PutCacheRecord(std::string_view BucketId, const IoHash& Key, IoBuffer Value, ZenContentType Type); ZenCacheResult PutCachePayload(std::string_view BucketId, const IoHash& Key, const IoHash& PayloadId, IoBuffer Payload); + ZenCacheResult InvokeRpc(const CbObjectView& Request); private: inline spdlog::logger& Log() { return m_Log; } diff --git a/zenutil/cache/cachekey.cpp b/zenutil/cache/cachekey.cpp new file mode 100644 index 000000000..545b47f11 --- /dev/null +++ b/zenutil/cache/cachekey.cpp @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cachekey.h> + +namespace zen { + +const CacheKey CacheKey::Empty = CacheKey{.Bucket = std::string(), .Hash = IoHash()}; + +} // namespace zen diff --git a/zenutil/cache/cachepolicy.cpp b/zenutil/cache/cachepolicy.cpp new file mode 100644 index 000000000..f718bf841 --- /dev/null +++ b/zenutil/cache/cachepolicy.cpp @@ -0,0 +1,167 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cachepolicy.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> + +namespace zen { + +using namespace std::literals; + +namespace detail { namespace cacheopt { + constexpr std::string_view Local = "local"sv; + constexpr std::string_view Remote = "remote"sv; + constexpr std::string_view Data = "data"sv; + constexpr std::string_view Meta = "meta"sv; + constexpr std::string_view Value = "value"sv; + constexpr std::string_view Attachments = "attachments"sv; +}} // namespace detail::cacheopt + +CachePolicy +ParseQueryCachePolicy(std::string_view QueryPolicy, CachePolicy Default) +{ + if (QueryPolicy.empty()) + { + return Default; + } + + CachePolicy Result = CachePolicy::None; + + ForEachStrTok(QueryPolicy, ',', [&Result](const std::string_view& Token) { + if (Token == detail::cacheopt::Local) + { + Result |= CachePolicy::QueryLocal; + } + if (Token == detail::cacheopt::Remote) + { + Result |= CachePolicy::QueryRemote; + } + return true; + }); + + return Result; +} + +CachePolicy +ParseStoreCachePolicy(std::string_view StorePolicy, CachePolicy Default) +{ + if (StorePolicy.empty()) + { + return Default; + } + + CachePolicy Result = CachePolicy::None; + + ForEachStrTok(StorePolicy, ',', [&Result](const std::string_view& Token) { + if (Token == detail::cacheopt::Local) + { + Result |= CachePolicy::StoreLocal; + } + if (Token == detail::cacheopt::Remote) + { + Result |= CachePolicy::StoreRemote; + } + return true; + }); + + return Result; +} + +CachePolicy +ParseSkipCachePolicy(std::string_view SkipPolicy, CachePolicy Default) +{ + if (SkipPolicy.empty()) + { + return Default; + } + + CachePolicy Result = CachePolicy::None; + + ForEachStrTok(SkipPolicy, ',', [&Result](const std::string_view& Token) { + if (Token == detail::cacheopt::Meta) + { + Result |= CachePolicy::SkipMeta; + } + if (Token == detail::cacheopt::Value) + { + Result |= CachePolicy::SkipValue; + } + if (Token == detail::cacheopt::Attachments) + { + Result |= CachePolicy::SkipAttachments; + } + if (Token == detail::cacheopt::Data) + { + Result |= CachePolicy::SkipData; + } + return true; + }); + + return Result; +} + +CacheRecordPolicy::CacheRecordPolicy(const CachePolicy RecordPolicy, const CachePolicy PayloadPolicy) +: m_RecordPolicy(RecordPolicy) +, m_DefaultPayloadPolicy(PayloadPolicy) +{ +} + +CachePolicy +CacheRecordPolicy::GetPayloadPolicy(const Oid& PayloadId) const +{ + if (const auto It = m_PayloadPolicies.find(PayloadId); It != m_PayloadPolicies.end()) + { + return It->second; + } + + return m_DefaultPayloadPolicy; +} + +bool +CacheRecordPolicy::Load(CbObjectView RecordPolicyObject, CacheRecordPolicy& OutRecordPolicy) +{ + using namespace std::literals; + + const uint32_t RecordPolicy = RecordPolicyObject["RecordPolicy"sv].AsUInt32(static_cast<uint32_t>(CachePolicy::Default)); + const uint32_t DefaultPayloadPolicy = + RecordPolicyObject["DefaultPayloadPolicy"sv].AsUInt32(static_cast<uint32_t>(CachePolicy::Default)); + + OutRecordPolicy.m_RecordPolicy = static_cast<CachePolicy>(RecordPolicy); + OutRecordPolicy.m_DefaultPayloadPolicy = static_cast<CachePolicy>(DefaultPayloadPolicy); + + for (CbFieldView PayloadPolicyView : RecordPolicyObject["PayloadPolicies"sv]) + { + CbObjectView PayloadPolicyObject = PayloadPolicyView.AsObjectView(); + const Oid PayloadId = PayloadPolicyObject["Id"sv].AsObjectId(); + const uint32_t PayloadPolicy = PayloadPolicyObject["Policy"sv].AsUInt32(); + + if (PayloadId != Oid::Zero && PayloadPolicy != 0) + { + OutRecordPolicy.m_PayloadPolicies.emplace(PayloadId, static_cast<CachePolicy>(PayloadPolicy)); + } + } + + return true; +} + +void +CacheRecordPolicy::Save(const CacheRecordPolicy& Policy, CbWriter& Writer) +{ + Writer << "RecordPolicy"sv << static_cast<uint32_t>(Policy.GetRecordPolicy()); + Writer << "DefaultPayloadPolicy"sv << static_cast<uint32_t>(Policy.GetDefaultPayloadPolicy()); + + if (!Policy.m_PayloadPolicies.empty()) + { + Writer.BeginArray("PayloadPolicies"sv); + for (const auto& Kv : Policy.m_PayloadPolicies) + { + Writer.AddObjectId("Id"sv, Kv.first); + Writer << "Policy"sv << static_cast<uint32_t>(Kv.second); + } + Writer.EndArray(); + } +} + +} // namespace zen diff --git a/zenutil/include/zenutil/cache/cache.h b/zenutil/include/zenutil/cache/cache.h new file mode 100644 index 000000000..1a1dd9386 --- /dev/null +++ b/zenutil/include/zenutil/cache/cache.h @@ -0,0 +1,6 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenutil/cache/cachekey.h> +#include <zenutil/cache/cachepolicy.h> diff --git a/zenutil/include/zenutil/cache/cachekey.h b/zenutil/include/zenutil/cache/cachekey.h new file mode 100644 index 000000000..fb36c7759 --- /dev/null +++ b/zenutil/include/zenutil/cache/cachekey.h @@ -0,0 +1,83 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <zenutil/cache/cachepolicy.h> + +namespace zen { + +struct CacheKey +{ + std::string Bucket; + IoHash Hash; + + static CacheKey Create(std::string_view Bucket, const IoHash& Hash) { return {.Bucket = ToLower(Bucket), .Hash = Hash}; } + + static const CacheKey Empty; +}; + +inline bool +operator==(const CacheKey& A, const CacheKey& B) +{ + return A.Bucket == B.Bucket && A.Hash == B.Hash; +} + +inline bool +operator!=(const CacheKey& A, const CacheKey& B) +{ + return A.Bucket != B.Bucket || A.Hash != B.Hash; +} + +inline bool +operator<(const CacheKey& A, const CacheKey& B) +{ + const std::string& BucketA = A.Bucket; + const std::string& BucketB = B.Bucket; + return BucketA == BucketB ? A.Hash < B.Hash : BucketA < BucketB; +} + +struct CacheChunkRequest +{ + CacheKey Key; + IoHash ChunkId; + Oid PayloadId; + uint64_t RawOffset = 0ull; + uint64_t RawSize = ~uint64_t(0); + CachePolicy Policy = CachePolicy::Default; +}; + +inline bool +operator<(const CacheChunkRequest& A, const CacheChunkRequest& B) +{ + if (A.Key < B.Key) + { + return true; + } + if (B.Key < A.Key) + { + return false; + } + if (A.ChunkId < B.ChunkId) + { + return true; + } + if (B.ChunkId < A.ChunkId) + { + return false; + } + if (A.PayloadId < B.PayloadId) + { + return true; + } + if (B.PayloadId < A.PayloadId) + { + return false; + } + return A.RawOffset < B.RawOffset; +} + +} // namespace zen diff --git a/zenutil/include/zenutil/cache/cachepolicy.h b/zenutil/include/zenutil/cache/cachepolicy.h new file mode 100644 index 000000000..5cf19238e --- /dev/null +++ b/zenutil/include/zenutil/cache/cachepolicy.h @@ -0,0 +1,109 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <gsl/gsl-lite.hpp> +#include <unordered_map> + +namespace zen { + +class CbObjectView; +class CbWriter; + +enum class CachePolicy : uint32_t +{ + /** A value without any flags set. */ + None = 0, + + /** Allow a cache request to query local caches. */ + QueryLocal = 1 << 0, + /** Allow a cache request to query remote caches. */ + QueryRemote = 1 << 1, + /** Allow a cache request to query any caches. */ + Query = QueryLocal | QueryRemote, + + /** Allow cache records and values to be stored in local caches. */ + StoreLocal = 1 << 2, + /** Allow cache records and values to be stored in remote caches. */ + StoreRemote = 1 << 3, + /** Allow cache records and values to be stored in any caches. */ + Store = StoreLocal | StoreRemote, + + /** Skip fetching the metadata for record requests. */ + SkipMeta = 1 << 4, + /** Skip fetching the value for record, chunk, or value requests. */ + SkipValue = 1 << 5, + /** Skip fetching the attachments for record requests. */ + SkipAttachments = 1 << 6, + /** + * Skip fetching the data for any requests. + * + * Put requests with skip flags may assume that record existence implies payload existence. + */ + SkipData = SkipMeta | SkipValue | SkipAttachments, + + /** + * Keep records in the cache for at least the duration of the session. + * + * This is a hint that the record may be accessed again in this session. This is mainly meant + * to be used when subsequent accesses will not tolerate a cache miss. + */ + KeepAlive = 1 << 7, + + /** + * Partial output will be provided with the error status when a required payload is missing. + * + * This is meant for cases when the missing payloads can be individually recovered or rebuilt + * without rebuilding the whole record. The cache automatically adds this flag when there are + * other cache stores that it may be able to recover missing payloads from. + * + * Requests for records would return records where the missing payloads have a hash and size, + * but no data. Requests for chunks or values would return the hash and size, but no data. + */ + PartialOnError = 1 << 8, + + /** Allow cache requests to query and store records and values in local caches. */ + Local = QueryLocal | StoreLocal, + /** Allow cache requests to query and store records and values in remote caches. */ + Remote = QueryRemote | StoreRemote, + + /** Allow cache requests to query and store records and values in any caches. */ + Default = Query | Store, + + /** Do not allow cache requests to query or store records and values in any caches. */ + Disable = None, +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy); + +CachePolicy ParseQueryCachePolicy(std::string_view QueryPolicy, CachePolicy Default = CachePolicy::Query); + +CachePolicy ParseStoreCachePolicy(std::string_view StorePolicy, CachePolicy Default = CachePolicy::Store); + +CachePolicy ParseSkipCachePolicy(std::string_view SkipPolicy, CachePolicy Default = CachePolicy::None); + +class CacheRecordPolicy +{ +public: + CacheRecordPolicy() = default; + CacheRecordPolicy(const CachePolicy RecordPolicy, const CachePolicy DefaultPayloadPolicy = CachePolicy::Default); + + CachePolicy GetRecordPolicy() const { return m_RecordPolicy; } + CachePolicy GetPayloadPolicy(const Oid& PayloadId) const; + CachePolicy GetDefaultPayloadPolicy() const { return m_DefaultPayloadPolicy; } + + static bool Load(CbObjectView RecordPolicyObject, CacheRecordPolicy& OutRecordPolicy); + static void Save(const CacheRecordPolicy& Policy, CbWriter& Writer); + +private: + using PayloadPolicyMap = std::unordered_map<Oid, CachePolicy, Oid::Hasher>; + + CachePolicy m_RecordPolicy = CachePolicy::Default; + CachePolicy m_DefaultPayloadPolicy = CachePolicy::Default; + PayloadPolicyMap m_PayloadPolicies; +}; + +} // namespace zen diff --git a/zenutil/zenutil.vcxproj b/zenutil/zenutil.vcxproj index 3bf6111f7..f5db7c5b0 100644 --- a/zenutil/zenutil.vcxproj +++ b/zenutil/zenutil.vcxproj @@ -97,9 +97,14 @@ </Link> </ItemDefinitionGroup> <ItemGroup> + <ClCompile Include="cache\cachekey.cpp" /> + <ClCompile Include="cache\cachepolicy.cpp" /> <ClCompile Include="zenserverprocess.cpp" /> </ItemGroup> <ItemGroup> + <ClInclude Include="include\zenutil\cache\cache.h" /> + <ClInclude Include="include\zenutil\cache\cachekey.h" /> + <ClInclude Include="include\zenutil\cache\cachepolicy.h" /> <ClInclude Include="include\zenutil\zenserverprocess.h" /> </ItemGroup> <ItemGroup> diff --git a/zenutil/zenutil.vcxproj.filters b/zenutil/zenutil.vcxproj.filters index 9952e7159..368a827c2 100644 --- a/zenutil/zenutil.vcxproj.filters +++ b/zenutil/zenutil.vcxproj.filters @@ -2,11 +2,31 @@ <Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> <ItemGroup> <ClCompile Include="zenserverprocess.cpp" /> + <ClCompile Include="cache\cachekey.cpp"> + <Filter>cache</Filter> + </ClCompile> + <ClCompile Include="cache\cachepolicy.cpp"> + <Filter>cache</Filter> + </ClCompile> </ItemGroup> <ItemGroup> <ClInclude Include="include\zenutil\zenserverprocess.h" /> + <ClInclude Include="include\zenutil\cache\cache.h"> + <Filter>cache</Filter> + </ClInclude> + <ClInclude Include="include\zenutil\cache\cachekey.h"> + <Filter>cache</Filter> + </ClInclude> + <ClInclude Include="include\zenutil\cache\cachepolicy.h"> + <Filter>cache</Filter> + </ClInclude> </ItemGroup> <ItemGroup> <None Include="xmake.lua" /> </ItemGroup> + <ItemGroup> + <Filter Include="cache"> + <UniqueIdentifier>{a441c536-6a01-4ac4-85a0-2667c95027d0}</UniqueIdentifier> + </Filter> + </ItemGroup> </Project>
\ No newline at end of file |