diff options
| author | zousar <[email protected]> | 2025-02-26 11:49:05 -0700 |
|---|---|---|
| committer | zousar <[email protected]> | 2025-02-26 11:49:05 -0700 |
| commit | 0bbd1fb43bbd9f878a2aa326ef06f2dc503a3b3f (patch) | |
| tree | b399f6bf1b3c2fdb50596cfae71c598cd57b6f40 /src/zenstore/cache/cacherpc.cpp | |
| parent | Expand and fix unit tests for overwrite behavior (diff) | |
| download | zen-0bbd1fb43bbd9f878a2aa326ef06f2dc503a3b3f.tar.xz zen-0bbd1fb43bbd9f878a2aa326ef06f2dc503a3b3f.zip | |
Enforce Overwrite Prevention According To Cache Policy
Overwrite with differing value should be denied if QueryLocal is not present and StoreLocal is present. Overwrite with equal value should succeed regardless of policy flags.
Diffstat (limited to 'src/zenstore/cache/cacherpc.cpp')
| -rw-r--r-- | src/zenstore/cache/cacherpc.cpp | 170 |
1 files changed, 106 insertions, 64 deletions
diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp index cca51e63e..94072d22d 100644 --- a/src/zenstore/cache/cacherpc.cpp +++ b/src/zenstore/cache/cacherpc.cpp @@ -422,7 +422,19 @@ CacheRpcHandler::PutCacheRecord(PutRequestData& Request, const CbPackage* Packag CacheValue.Value = IoBuffer(Record.GetSize()); Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize())); CacheValue.Value.SetContentType(ZenContentType::kCbObject); - m_CacheStore.Put(Request.Context, Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue, ReferencedAttachments, nullptr); + bool Overwrite = EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreLocal) && + !EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal); + if (!m_CacheStore.Put(Request.Context, + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + CacheValue, + ReferencedAttachments, + Overwrite, + nullptr)) + { + return PutResult::Conflict; + } m_CacheStats.WriteCount++; if (!WriteAttachmentBuffers.empty()) @@ -753,18 +765,23 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal) && AreDiskWritesAllowed(); if (StoreLocal) { + bool Overwrite = !EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::QueryLocal); std::vector<IoHash> ReferencedAttachments; ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) { const IoHash ValueHash = HashView.AsHash(); ReferencedAttachments.push_back(ValueHash); }); - m_CacheStore.Put(Context, - *Namespace, - Key.Bucket, - Key.Hash, - {.Value = {Request.RecordCacheValue}}, - ReferencedAttachments, - nullptr); + if (!m_CacheStore.Put(Context, + *Namespace, + Key.Bucket, + Key.Hash, + {.Value = {Request.RecordCacheValue}}, + ReferencedAttachments, + Overwrite, + nullptr)) + { + return; + } m_CacheStats.WriteCount++; } ParseValues(Request); @@ -962,20 +979,25 @@ CacheRpcHandler::HandleRpcPutCacheValues(const CacheRequestContext& Context, con if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) { - IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer(); + bool Overwrite = !EnumHasAllFlags(Policy, CachePolicy::QueryLocal); + IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer(); Value.SetContentType(ZenContentType::kCompressedBinary); if (RawSize == 0) { RawSize = Chunk.DecodeRawSize(); } - m_CacheStore.Put(Context, - *Namespace, - Key.Bucket, - Key.Hash, - {.Value = Value, .RawSize = RawSize, .RawHash = RawHash}, - {}, - Batch.get()); - m_CacheStats.WriteCount++; + bool PutSucceeded = m_CacheStore.Put(Context, + *Namespace, + Key.Bucket, + Key.Hash, + {.Value = Value, .RawSize = RawSize, .RawHash = RawHash}, + {}, + Overwrite, + Batch.get()); + if (PutSucceeded) + { + m_CacheStats.WriteCount++; + } if (Batch) { BatchResultIndexes.push_back(Results.size()); @@ -983,7 +1005,7 @@ CacheRpcHandler::HandleRpcPutCacheValues(const CacheRequestContext& Context, con } else { - Results.push_back(true); + Results.push_back(PutSucceeded); } TransferredSize = Chunk.GetCompressedSize(); } @@ -1225,6 +1247,7 @@ CacheRpcHandler::HandleRpcGetCacheValues(const CacheRequestContext& Context, CbO const bool HasData = IsCompressedBinary(Params.Value.GetContentType()); const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData); const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + const bool Overwrite = StoreData && !EnumHasAllFlags(Request.Policy, CachePolicy::QueryLocal); const bool IsHit = SkipData || HasData; if (IsHit) { @@ -1235,14 +1258,18 @@ CacheRpcHandler::HandleRpcGetCacheValues(const CacheRequestContext& Context, CbO if (HasData && StoreData) { - m_CacheStore.Put(Context, - *Namespace, - Request.Key.Bucket, - Request.Key.Hash, - ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash}, - {}, - nullptr); - m_CacheStats.WriteCount++; + if (m_CacheStore.Put( + Context, + *Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash}, + {}, + Overwrite, + nullptr)) + { + m_CacheStats.WriteCount++; + } } ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", @@ -1510,36 +1537,47 @@ CacheRpcHandler::GetLocalCacheRecords(const CacheRequestContext& Context, if (!UpstreamRecordRequests.empty()) { - const auto OnCacheRecordGetComplete = [this, Namespace, &RecordKeys, &Records, &RecordRequests, Context]( - CacheRecordGetCompleteParams&& Params) { - if (!Params.Record) - { - return; - } - CacheKeyRequest& RecordKey = Params.Request; - size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey); - RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); - RecordBody& Record = Records[RecordIndex]; - - const CacheKey& Key = RecordKey.Key; - Record.Exists = true; - CbObject ObjectBuffer = CbObject::Clone(Params.Record); - Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); - Record.CacheValue.SetContentType(ZenContentType::kCbObject); - Record.Source = Params.Source; - - bool StoreLocal = EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); - if (StoreLocal) - { - std::vector<IoHash> ReferencedAttachments; - ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) { - const IoHash ValueHash = HashView.AsHash(); - ReferencedAttachments.push_back(ValueHash); - }); - m_CacheStore.Put(Context, Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue}, ReferencedAttachments, nullptr); - m_CacheStats.WriteCount++; - } - }; + const auto OnCacheRecordGetComplete = + [this, Namespace, &RecordKeys, &Records, &RecordRequests, Context](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + CacheKeyRequest& RecordKey = Params.Request; + size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey); + RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + RecordBody& Record = Records[RecordIndex]; + + const CacheKey& Key = RecordKey.Key; + Record.Exists = true; + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Record.CacheValue.SetContentType(ZenContentType::kCbObject); + Record.Source = Params.Source; + + bool StoreLocal = EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); + if (StoreLocal) + { + bool Overwrite = !EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal); + std::vector<IoHash> ReferencedAttachments; + ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + ReferencedAttachments.push_back(ValueHash); + }); + if (!m_CacheStore.Put(Context, + Namespace, + Key.Bucket, + Key.Hash, + {.Value = Record.CacheValue}, + ReferencedAttachments, + Overwrite, + nullptr)) + { + return; + } + m_CacheStats.WriteCount++; + } + }; m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete)); } @@ -1748,20 +1786,24 @@ CacheRpcHandler::GetUpstreamCacheChunks(const CacheRequestContext& Context, bool StoreLocal = EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed(); if (StoreLocal) { + bool Overwrite = !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::QueryLocal); if (Request.IsRecordRequest) { m_CidStore.AddChunk(Params.Value, Params.RawHash); } else { - m_CacheStore.Put(Context, - Namespace, - Key.Key.Bucket, - Key.Key.Hash, - {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash}, - {}, - nullptr); - m_CacheStats.WriteCount++; + if (m_CacheStore.Put(Context, + Namespace, + Key.Key.Bucket, + Key.Key.Hash, + {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash}, + {}, + Overwrite, + nullptr)) + { + m_CacheStats.WriteCount++; + } } } if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) |