aboutsummaryrefslogtreecommitdiff
path: root/src/zenstore/cache/cacherpc.cpp
diff options
context:
space:
mode:
authorzousar <[email protected]>2025-02-26 11:49:05 -0700
committerzousar <[email protected]>2025-02-26 11:49:05 -0700
commit0bbd1fb43bbd9f878a2aa326ef06f2dc503a3b3f (patch)
treeb399f6bf1b3c2fdb50596cfae71c598cd57b6f40 /src/zenstore/cache/cacherpc.cpp
parentExpand and fix unit tests for overwrite behavior (diff)
downloadzen-0bbd1fb43bbd9f878a2aa326ef06f2dc503a3b3f.tar.xz
zen-0bbd1fb43bbd9f878a2aa326ef06f2dc503a3b3f.zip
Enforce Overwrite Prevention According To Cache Policy
Overwrite with differing value should be denied if QueryLocal is not present and StoreLocal is present. Overwrite with equal value should succeed regardless of policy flags.
Diffstat (limited to 'src/zenstore/cache/cacherpc.cpp')
-rw-r--r--src/zenstore/cache/cacherpc.cpp170
1 files changed, 106 insertions, 64 deletions
diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp
index cca51e63e..94072d22d 100644
--- a/src/zenstore/cache/cacherpc.cpp
+++ b/src/zenstore/cache/cacherpc.cpp
@@ -422,7 +422,19 @@ CacheRpcHandler::PutCacheRecord(PutRequestData& Request, const CbPackage* Packag
CacheValue.Value = IoBuffer(Record.GetSize());
Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize()));
CacheValue.Value.SetContentType(ZenContentType::kCbObject);
- m_CacheStore.Put(Request.Context, Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue, ReferencedAttachments, nullptr);
+ bool Overwrite = EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreLocal) &&
+ !EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::QueryLocal);
+ if (!m_CacheStore.Put(Request.Context,
+ Request.Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ CacheValue,
+ ReferencedAttachments,
+ Overwrite,
+ nullptr))
+ {
+ return PutResult::Conflict;
+ }
m_CacheStats.WriteCount++;
if (!WriteAttachmentBuffers.empty())
@@ -753,18 +765,23 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb
EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal) && AreDiskWritesAllowed();
if (StoreLocal)
{
+ bool Overwrite = !EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::QueryLocal);
std::vector<IoHash> ReferencedAttachments;
ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) {
const IoHash ValueHash = HashView.AsHash();
ReferencedAttachments.push_back(ValueHash);
});
- m_CacheStore.Put(Context,
- *Namespace,
- Key.Bucket,
- Key.Hash,
- {.Value = {Request.RecordCacheValue}},
- ReferencedAttachments,
- nullptr);
+ if (!m_CacheStore.Put(Context,
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ {.Value = {Request.RecordCacheValue}},
+ ReferencedAttachments,
+ Overwrite,
+ nullptr))
+ {
+ return;
+ }
m_CacheStats.WriteCount++;
}
ParseValues(Request);
@@ -962,20 +979,25 @@ CacheRpcHandler::HandleRpcPutCacheValues(const CacheRequestContext& Context, con
if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
{
- IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ bool Overwrite = !EnumHasAllFlags(Policy, CachePolicy::QueryLocal);
+ IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer();
Value.SetContentType(ZenContentType::kCompressedBinary);
if (RawSize == 0)
{
RawSize = Chunk.DecodeRawSize();
}
- m_CacheStore.Put(Context,
- *Namespace,
- Key.Bucket,
- Key.Hash,
- {.Value = Value, .RawSize = RawSize, .RawHash = RawHash},
- {},
- Batch.get());
- m_CacheStats.WriteCount++;
+ bool PutSucceeded = m_CacheStore.Put(Context,
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ {.Value = Value, .RawSize = RawSize, .RawHash = RawHash},
+ {},
+ Overwrite,
+ Batch.get());
+ if (PutSucceeded)
+ {
+ m_CacheStats.WriteCount++;
+ }
if (Batch)
{
BatchResultIndexes.push_back(Results.size());
@@ -983,7 +1005,7 @@ CacheRpcHandler::HandleRpcPutCacheValues(const CacheRequestContext& Context, con
}
else
{
- Results.push_back(true);
+ Results.push_back(PutSucceeded);
}
TransferredSize = Chunk.GetCompressedSize();
}
@@ -1225,6 +1247,7 @@ CacheRpcHandler::HandleRpcGetCacheValues(const CacheRequestContext& Context, CbO
const bool HasData = IsCompressedBinary(Params.Value.GetContentType());
const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData);
const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ const bool Overwrite = StoreData && !EnumHasAllFlags(Request.Policy, CachePolicy::QueryLocal);
const bool IsHit = SkipData || HasData;
if (IsHit)
{
@@ -1235,14 +1258,18 @@ CacheRpcHandler::HandleRpcGetCacheValues(const CacheRequestContext& Context, CbO
if (HasData && StoreData)
{
- m_CacheStore.Put(Context,
- *Namespace,
- Request.Key.Bucket,
- Request.Key.Hash,
- ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash},
- {},
- nullptr);
- m_CacheStats.WriteCount++;
+ if (m_CacheStore.Put(
+ Context,
+ *Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash},
+ {},
+ Overwrite,
+ nullptr))
+ {
+ m_CacheStats.WriteCount++;
+ }
}
ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
@@ -1510,36 +1537,47 @@ CacheRpcHandler::GetLocalCacheRecords(const CacheRequestContext& Context,
if (!UpstreamRecordRequests.empty())
{
- const auto OnCacheRecordGetComplete = [this, Namespace, &RecordKeys, &Records, &RecordRequests, Context](
- CacheRecordGetCompleteParams&& Params) {
- if (!Params.Record)
- {
- return;
- }
- CacheKeyRequest& RecordKey = Params.Request;
- size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey);
- RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
- RecordBody& Record = Records[RecordIndex];
-
- const CacheKey& Key = RecordKey.Key;
- Record.Exists = true;
- CbObject ObjectBuffer = CbObject::Clone(Params.Record);
- Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
- Record.CacheValue.SetContentType(ZenContentType::kCbObject);
- Record.Source = Params.Source;
-
- bool StoreLocal = EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
- if (StoreLocal)
- {
- std::vector<IoHash> ReferencedAttachments;
- ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) {
- const IoHash ValueHash = HashView.AsHash();
- ReferencedAttachments.push_back(ValueHash);
- });
- m_CacheStore.Put(Context, Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue}, ReferencedAttachments, nullptr);
- m_CacheStats.WriteCount++;
- }
- };
+ const auto OnCacheRecordGetComplete =
+ [this, Namespace, &RecordKeys, &Records, &RecordRequests, Context](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
+ CacheKeyRequest& RecordKey = Params.Request;
+ size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey);
+ RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ RecordBody& Record = Records[RecordIndex];
+
+ const CacheKey& Key = RecordKey.Key;
+ Record.Exists = true;
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Record.CacheValue.SetContentType(ZenContentType::kCbObject);
+ Record.Source = Params.Source;
+
+ bool StoreLocal = EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
+ if (StoreLocal)
+ {
+ bool Overwrite = !EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal);
+ std::vector<IoHash> ReferencedAttachments;
+ ObjectBuffer.IterateAttachments([&ReferencedAttachments](CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ ReferencedAttachments.push_back(ValueHash);
+ });
+ if (!m_CacheStore.Put(Context,
+ Namespace,
+ Key.Bucket,
+ Key.Hash,
+ {.Value = Record.CacheValue},
+ ReferencedAttachments,
+ Overwrite,
+ nullptr))
+ {
+ return;
+ }
+ m_CacheStats.WriteCount++;
+ }
+ };
m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete));
}
@@ -1748,20 +1786,24 @@ CacheRpcHandler::GetUpstreamCacheChunks(const CacheRequestContext& Context,
bool StoreLocal = EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) && AreDiskWritesAllowed();
if (StoreLocal)
{
+ bool Overwrite = !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::QueryLocal);
if (Request.IsRecordRequest)
{
m_CidStore.AddChunk(Params.Value, Params.RawHash);
}
else
{
- m_CacheStore.Put(Context,
- Namespace,
- Key.Key.Bucket,
- Key.Key.Hash,
- {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash},
- {},
- nullptr);
- m_CacheStats.WriteCount++;
+ if (m_CacheStore.Put(Context,
+ Namespace,
+ Key.Key.Bucket,
+ Key.Key.Hash,
+ {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash},
+ {},
+ Overwrite,
+ nullptr))
+ {
+ m_CacheStats.WriteCount++;
+ }
}
}
if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))