diff options
| author | Dan Engelbrecht <[email protected]> | 2025-10-03 15:57:42 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2025-10-03 15:57:42 +0200 |
| commit | 42a2c2582b10a598ce5ef50f7feb4bab394b8fc1 (patch) | |
| tree | 267816281dcdbeda9900a38e6863265ecc257f15 /src/zenstore/cache/cachepolicy.cpp | |
| parent | 5.7.5-pre0 (diff) | |
| download | zen-42a2c2582b10a598ce5ef50f7feb4bab394b8fc1.tar.xz zen-42a2c2582b10a598ce5ef50f7feb4bab394b8fc1.zip | |
cacherequests helpers test only (#551)
* don't use cacherequests utils in cache_cmd.cpp
* make zenutil/cacherequests code into test code helpers only
Diffstat (limited to 'src/zenstore/cache/cachepolicy.cpp')
| -rw-r--r-- | src/zenstore/cache/cachepolicy.cpp | 426 |
1 files changed, 426 insertions, 0 deletions
diff --git a/src/zenstore/cache/cachepolicy.cpp b/src/zenstore/cache/cachepolicy.cpp new file mode 100644 index 000000000..ca8a95ca1 --- /dev/null +++ b/src/zenstore/cache/cachepolicy.cpp @@ -0,0 +1,426 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenstore/cache/cachepolicy.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/enumflags.h> +#include <zencore/string.h> + +#include <algorithm> +#include <unordered_map> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif + +namespace zen::Private { +class CacheRecordPolicyShared; +} + +namespace zen { + +using namespace std::literals; + +namespace DerivedData::Private { + + constinit char CachePolicyDelimiter = ','; + + struct CachePolicyToTextData + { + CachePolicy Policy; + std::string_view Text; + }; + + constinit CachePolicyToTextData CachePolicyToText[]{ + // Flags with multiple bits are ordered by bit count to minimize token count in the text format. + {CachePolicy::Default, "Default"sv}, + {CachePolicy::Remote, "Remote"sv}, + {CachePolicy::Local, "Local"sv}, + {CachePolicy::Store, "Store"sv}, + {CachePolicy::Query, "Query"sv}, + // Flags with only one bit can be in any order. Match the order in CachePolicy. + {CachePolicy::QueryLocal, "QueryLocal"sv}, + {CachePolicy::QueryRemote, "QueryRemote"sv}, + {CachePolicy::StoreLocal, "StoreLocal"sv}, + {CachePolicy::StoreRemote, "StoreRemote"sv}, + {CachePolicy::SkipMeta, "SkipMeta"sv}, + {CachePolicy::SkipData, "SkipData"sv}, + {CachePolicy::PartialRecord, "PartialRecord"sv}, + {CachePolicy::KeepAlive, "KeepAlive"sv}, + // None must be last because it matches every policy. + {CachePolicy::None, "None"sv}, + }; + + constinit CachePolicy CachePolicyKnownFlags = + CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive; + + StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy) + { + // Mask out unknown flags. None will be written if no flags are known. + Policy &= CachePolicyKnownFlags; + for (const CachePolicyToTextData& Pair : CachePolicyToText) + { + if (EnumHasAllFlags(Policy, Pair.Policy)) + { + EnumRemoveFlags(Policy, Pair.Policy); + Builder << Pair.Text << CachePolicyDelimiter; + if (Policy == CachePolicy::None) + { + break; + } + } + } + Builder.RemoveSuffix(1); + return Builder; + } + + CachePolicy ParseCachePolicy(const std::string_view Text) + { + ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string + CachePolicy Policy = CachePolicy::None; + ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable { + const int32_t EndIndex = Index; + for (; size_t(Index) < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + for (Index = 0; Index < EndIndex; ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + return true; + }); + return Policy; + } + +} // namespace DerivedData::Private + +StringBuilderBase& +operator<<(StringBuilderBase& Builder, CachePolicy Policy) +{ + return DerivedData::Private::CachePolicyToString(Builder, Policy); +} + +CachePolicy +ParseCachePolicy(std::string_view Text) +{ + return DerivedData::Private::ParseCachePolicy(Text); +} + +CachePolicy +ConvertToUpstream(CachePolicy Policy) +{ + // Set Local flags equal to downstream's Remote flags. + // Delete Skip flags if StoreLocal is true, otherwise use the downstream value. + // Use the downstream value for all other flags. + + CachePolicy UpstreamPolicy = CachePolicy::None; + + if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + UpstreamPolicy |= CachePolicy::QueryLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + UpstreamPolicy |= CachePolicy::StoreLocal; + } + + if (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + UpstreamPolicy |= (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta)); + } + + UpstreamPolicy |= Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta); + + return UpstreamPolicy; +} + +class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared +{ +public: + inline void AddValuePolicy(const CacheValuePolicy& Value) final + { + ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null. + const auto Insert = + std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) { + return Existing.Id < New.Id; + }); + ZEN_ASSERT( + !(Insert < Values.end() && + Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ") + Values.insert(Insert, Value); + } + + inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; } + +private: + std::vector<CacheValuePolicy> Values; +}; + +CachePolicy +CacheRecordPolicy::GetValuePolicy(const Oid& Id) const +{ + if (Shared) + { + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + const auto Iter = + std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; }); + if (Iter != Values.end() && Iter->Id == Id) + { + return Iter->Policy; + } + } + return DefaultValuePolicy; +} + +void +CacheRecordPolicy::Save(CbWriter& Writer) const +{ + Writer.BeginObject(); + // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately. + Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy())); + if (!IsUniform()) + { + Writer.BeginArray("ValuePolicies"sv); + for (const CacheValuePolicy& Value : GetValuePolicies()) + { + Writer.BeginObject(); + Writer.AddObjectId("Id"sv, Value.Id); + Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy)); + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); +} + +OptionalCacheRecordPolicy +CacheRecordPolicy::Load(const CbObjectView Object) +{ + std::string_view BasePolicyText = Object["BasePolicy"sv].AsString(); + if (BasePolicyText.empty()) + { + return {}; + } + + CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText)); + for (CbFieldView ValueField : Object["ValuePolicies"sv]) + { + const CbObjectView Value = ValueField.AsObjectView(); + const Oid Id = Value["Id"sv].AsObjectId(); + const std::string_view PolicyText = Value["Policy"sv].AsString(); + if (!Id || PolicyText.empty()) + { + return {}; + } + CachePolicy Policy = ParseCachePolicy(PolicyText); + if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask)) + { + return {}; + } + Builder.AddValuePolicy(Id, Policy); + } + + return Builder.Build(); +} + +CacheRecordPolicy +CacheRecordPolicy::ConvertToUpstream() const +{ + CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy())); + for (const CacheValuePolicy& ValuePolicy : GetValuePolicies()) + { + Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy)); + } + return Builder.Build(); +} + +void +CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value) +{ + ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy, + ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s + if (Value.Policy == (BasePolicy & Value.PolicyMask)) + { + return; + } + if (!Shared) + { + Shared = new Private::CacheRecordPolicyShared; + } + Shared->AddValuePolicy(Value); +} + +CacheRecordPolicy +CacheRecordPolicyBuilder::Build() +{ + CacheRecordPolicy Policy(BasePolicy); + if (Shared) + { + const auto Add = [](const CachePolicy A, const CachePolicy B) { + return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData); + }; + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + Policy.RecordPolicy = BasePolicy; + for (const CacheValuePolicy& ValuePolicy : Values) + { + Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy); + } + Policy.Shared = std::move(Shared); + } + return Policy; +} + +#if ZEN_WITH_TESTS +TEST_CASE("cachepolicy") +{ + SUBCASE("atomics serialization") + { + CachePolicy SomeAtomics[] = {CachePolicy::None, + CachePolicy::QueryLocal, + CachePolicy::StoreRemote, + CachePolicy::SkipData, + CachePolicy::KeepAlive}; + for (CachePolicy Atomic : SomeAtomics) + { + CHECK(ParseCachePolicy(WriteToString<128>(Atomic)) == Atomic); + } + // Also verify that we ignore unrecognized bits + for (CachePolicy Atomic : SomeAtomics) + { + CHECK(ParseCachePolicy(WriteToString<128>(Atomic | (CachePolicy)0x10000000)) == Atomic); + } + } + SUBCASE("aliases serialization") + { + CachePolicy SomeAliases[] = {CachePolicy::Query, CachePolicy::Local}; + for (CachePolicy Alias : SomeAliases) + { + CHECK(ParseCachePolicy(WriteToString<128>(Alias)) == Alias); + } + // Also verify that we ignore unrecognized bits + for (CachePolicy Alias : SomeAliases) + { + CHECK(ParseCachePolicy(WriteToString<128>(Alias | (CachePolicy)0x10000000)) == Alias); + } + } + SUBCASE("aliases take priority over atomics") + { + CHECK(WriteToString<128>(CachePolicy::Default).ToView() == "Default"sv); + CHECK(WriteToString<128>(CachePolicy::Query).ToView() == "Query"sv); + CHECK(WriteToString<128>(CachePolicy::Local).ToView() == "Local"sv); + } + SUBCASE("policies requiring multiple strings work") + { + char Delimiter = ','; + CachePolicy Combination = CachePolicy::SkipData | CachePolicy::QueryLocal; + CHECK(WriteToString<128>(Combination).ToView().find(Delimiter) != std::string_view::npos); + CHECK(ParseCachePolicy(WriteToString<128>(Combination)) == Combination); + } + SUBCASE("parsing invalid text") + { + CHECK(ParseCachePolicy(",,,") == CachePolicy::None); + CHECK(ParseCachePolicy("fee,fie,foo,fum") == CachePolicy::None); + CHECK(ParseCachePolicy("fee,KeepAlive,foo,fum") == CachePolicy::KeepAlive); + } +} + +TEST_CASE("cacherecordpolicy") +{ + SUBCASE("policy with no values") + { + CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal | CachePolicy::PartialRecord; + CachePolicy ValuePolicy = Policy & CacheValuePolicy::PolicyMask; + CacheRecordPolicy RecordPolicy; + CacheRecordPolicyBuilder Builder(Policy); + RecordPolicy = Builder.Build(); + SUBCASE("construct") + { + CHECK(RecordPolicy.IsUniform()); + CHECK(RecordPolicy.GetRecordPolicy() == Policy); + CHECK(RecordPolicy.GetBasePolicy() == Policy); + CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == ValuePolicy); + CHECK(RecordPolicy.GetValuePolicies().size() == 0); + } + SUBCASE("saveload") + { + CbWriter Writer; + RecordPolicy.Save(Writer); + CbObject Saved = Writer.Save()->AsObject(); + CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); + CHECK(Loaded.IsUniform()); + CHECK(Loaded.GetRecordPolicy() == Policy); + CHECK(Loaded.GetBasePolicy() == Policy); + CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == ValuePolicy); + CHECK(Loaded.GetValuePolicies().size() == 0); + } + } + + SUBCASE("policy with values") + { + CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal | CachePolicy::PartialRecord; + CachePolicy DefaultValuePolicy = DefaultPolicy & CacheValuePolicy::PolicyMask; + CachePolicy PartialOverlap = CachePolicy::StoreRemote; + CachePolicy NoOverlap = CachePolicy::QueryRemote; + CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap | CachePolicy::PartialRecord; + + CacheRecordPolicy RecordPolicy; + CacheRecordPolicyBuilder Builder(DefaultPolicy); + Oid PartialOid = Oid::NewOid(); + Oid NoOverlapOid = Oid::NewOid(); + Oid OtherOid = Oid::NewOid(); + Builder.AddValuePolicy(PartialOid, PartialOverlap); + Builder.AddValuePolicy(NoOverlapOid, NoOverlap); + RecordPolicy = Builder.Build(); + SUBCASE("construct") + { + CHECK(!RecordPolicy.IsUniform()); + CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); + CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); + CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); + CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); + CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); + CHECK(RecordPolicy.GetValuePolicies().size() == 2); + } + SUBCASE("saveload") + { + CbWriter Writer; + RecordPolicy.Save(Writer); + CbObject Saved = Writer.Save()->AsObject(); + CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); + CHECK(!RecordPolicy.IsUniform()); + CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); + CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); + CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); + CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); + CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); + CHECK(RecordPolicy.GetValuePolicies().size() == 2); + } + } + + SUBCASE("parsing invalid text") + { + OptionalCacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject()); + CHECK(Loaded.IsNull()); + } +} +#endif + +void +cachepolicy_forcelink() +{ +} + +} // namespace zen |