diff options
| author | zousar <[email protected]> | 2022-02-01 21:33:57 -0700 |
|---|---|---|
| committer | zousar <[email protected]> | 2022-02-01 21:33:57 -0700 |
| commit | 7f12ac47b01639e8e67d04388a6f971271d25c52 (patch) | |
| tree | 53ac2f2a844785bf5543cd62a692985f0b0ff385 | |
| parent | Merge branch 'main' into non-elevated-asio (diff) | |
| parent | CacheRecordPolicy: Fix inverted PolicyMask expression that caused parsing Val... (diff) | |
| download | zen-7f12ac47b01639e8e67d04388a6f971271d25c52.tar.xz zen-7f12ac47b01639e8e67d04388a6f971271d25c52.zip | |
Merge branch 'main' into non-elevated-asio
| -rw-r--r-- | generate_projects.bat | 1 | ||||
| -rw-r--r-- | scripts/remote_build.py | 270 | ||||
| -rw-r--r-- | thirdparty/trace/trace.h | 2 | ||||
| -rw-r--r-- | xmake.lua | 2 | ||||
| -rw-r--r-- | zencore/crypto.cpp | 235 | ||||
| -rw-r--r-- | zencore/include/zencore/crypto.h | 50 | ||||
| -rw-r--r-- | zencore/include/zencore/trace.h | 1 | ||||
| -rw-r--r-- | zencore/trace.cpp | 11 | ||||
| -rw-r--r-- | zencore/zencore.cpp | 2 | ||||
| -rw-r--r-- | zenserver-test/cachepolicy-tests.cpp | 47 | ||||
| -rw-r--r-- | zenserver-test/zenserver-test.cpp | 172 | ||||
| -rw-r--r-- | zenserver/cache/structuredcache.cpp | 1431 | ||||
| -rw-r--r-- | zenserver/cache/structuredcache.h | 57 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.cpp | 340 | ||||
| -rw-r--r-- | zenserver/upstream/upstreamcache.h | 43 | ||||
| -rw-r--r-- | zenserver/upstream/zen.cpp | 8 | ||||
| -rw-r--r-- | zenserver/upstream/zen.h | 4 | ||||
| -rw-r--r-- | zenserver/zenserver.cpp | 4 | ||||
| -rw-r--r-- | zenutil/cache/cachepolicy.cpp | 326 | ||||
| -rw-r--r-- | zenutil/include/zenutil/cache/cachekey.h | 12 | ||||
| -rw-r--r-- | zenutil/include/zenutil/cache/cachepolicy.h | 143 |
21 files changed, 2425 insertions, 736 deletions
diff --git a/generate_projects.bat b/generate_projects.bat new file mode 100644 index 000000000..92a9a33ae --- /dev/null +++ b/generate_projects.bat @@ -0,0 +1 @@ +@xmake project --yes --kind=vsxmake2022 -m release,debug -a x64 diff --git a/scripts/remote_build.py b/scripts/remote_build.py new file mode 100644 index 000000000..c5787f635 --- /dev/null +++ b/scripts/remote_build.py @@ -0,0 +1,270 @@ +import os +import sys +import argparse +import subprocess +from pathlib import Path + +# {{{1 misc -------------------------------------------------------------------- + +# Disables output of ANSI codes if the terminal doesn't support them +if os.name == "nt": + from ctypes import windll, c_int, byref + stdout_handle = windll.kernel32.GetStdHandle(c_int(-11)) + mode = c_int(0) + windll.kernel32.GetConsoleMode(c_int(stdout_handle), byref(mode)) + ansi_on = (mode.value & 4) != 0 +else: + ansi_on = True + +#------------------------------------------------------------------------------- +def _header(*args, ansi=96): + if ansi_on: + print(f"\x1b[{ansi}m##", *args, end="") + print("\x1b[0m") + else: + print("\n##", *args) + +#------------------------------------------------------------------------------- +def _run_checked(cmd, *args, **kwargs): + _header(cmd, *args, ansi=97) + ret = subprocess.run((cmd, *args), **kwargs, bufsize=0) + if ret.returncode: + raise RuntimeError("Failed running " + str(cmd)) + +#------------------------------------------------------------------------------- +def _get_ip(): + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("172.31.255.255", 1)) + return s.getsockname()[0] + except: + return "127.0.0.1" + finally: + s.close() + +#------------------------------------------------------------------------------- +def _find_binary(name): + name += ".exe" if os.name == "nt" else "" + for prefix in os.getenv("PATH", "").split(os.pathsep): + path = Path(prefix) / name + if path.is_file(): + return path + raise EnvironmentError(f"Unable to find '{name}' in the path") + +#------------------------------------------------------------------------------- +class _AutoKill(object): + def __init__(self, proc): + self._proc = proc + + def __del__(self): + self._proc.kill() + self._proc.wait() + pass + + + +# {{{1 local ------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +def _local(args): + # Parse arguments + desc = "Build Zen on a remote host" + parser = argparse.ArgumentParser(description=desc) + parser.add_argument("remotehost", help="") + parser.add_argument("action", default="build", nargs="?", help="") + parser.add_argument("--keyfile", default=None, help="SSH key file") + args = parser.parse_args(args) + + # Find the binaries we'll need + _header("Finding tools") + git_bin = _find_binary("git") + print(f"Using git from '{git_bin.name}' from '{git_bin.parent}'") + + def find_git_tool(git_bin, tool_name): + print(f"Locating {tool_name}...") + tool_suffix = "usr/bin/" + tool_name + tool_suffix += ".exe" if os.name == "nt" else "" + for parent in git_bin.parents: + tool_path = parent / tool_suffix + if tool_path.is_file(): + return tool_path + return _find_binary(tool_name) + + ssh_bin = find_git_tool(git_bin, "ssh") + scp_bin = find_git_tool(git_bin, "scp") + print(f"Using '{ssh_bin.name}' from '{ssh_bin.parent}'") + print(f"Using '{scp_bin.name}' from '{scp_bin.parent}'") + + # Find the Zen repository root + for parent in Path(__file__).resolve().parents: + if (parent / ".git").is_dir(): + zen_dir = parent + break; + else: + raise EnvironmentError("Unable to find '.git/' directory") + + _header("Validating remote host and credentials") + + # Validate key file. OpenSSL needs a trailing EOL, LibreSSL doesn't + if args.keyfile: + with open(args.keyfile, "rt") as key_file: + lines = [x for x in key_file] + if not lines[-1].endswith("\n"): + print("!! ERROR: key file must end with a new line") + return 1 + identity = ("-i", args.keyfile) + else: + identity = () + + # Validate remote host + host = args.remotehost + if host == "linux": host = os.getenv("ZEN_REMOTE_HOST_LINUX", "arn-lin-12345") + if host == "mac": host = os.getenv("ZEN_REMOTE_HOST_MAC", "imacpro-arn.local") + """ + keygen_bin = find_git_tool(git_bin, "ssh-keygen") + print(f"Using '{keygen_bin.name}' from '{keygen_bin.parent}'") + known_host = subprocess.run((keygen_bin, "-F", host)).returncode + if not known_host: + print("Adding", host, "as a known host") + print("ANSWER 'yes'") + known_host = subprocess.run((ssh_bin, *identity, "zenbuild@" + host, "uname -a")).returncode + raise IndexError + """ + host = "zenbuild@" + host + print(f"Using host '{host}'") + + # Start a git daemon to use as a transfer mechanism + _header("Starting a git daemon") + print("Port: 4493") + print("Base-path: ", zen_dir) + print("Host: ", _get_ip()) + daemon = subprocess.Popen( + ( git_bin, + "daemon", + "--port=4493", + "--export-all", + "--reuseaddr", + "--verbose", + "--informative-errors", + "--base-path=" + str(zen_dir) ), + #stdout = daemon_log, + stderr = subprocess.STDOUT + ) + daemon_killer = _AutoKill(daemon) + + # Run this script on the remote machine + _header("Running SSH") + + remote_zen_dir = "%s_%s" % (os.getlogin(), _get_ip()) + print(f"Using zen '~/{remote_zen_dir}'") + + print(f"Running {__file__} remotely") + with open(__file__, "rt") as self_file: + _run_checked( + ssh_bin, + *identity, + "-tA", + host, + f"python3 -u - !remote {_get_ip()} '{remote_zen_dir}' main '{args.action}'", + stdin=self_file) + + # If we're bundling, collect zip files from the remote machine + if args.action == "bundle": + build_dir = zen_dir / "build" + build_dir.mkdir(exist_ok=True) + scp_args = (*identity, host + f":zen/{remote_zen_dir}/build/*.zip", build_dir) + _run_checked("scp", *scp_args) + + + +# {{{1 remote ------------------------------------------------------------------ + +#------------------------------------------------------------------------------- +def _remote(args): + # Parse arguments + desc = "Build Zen on a remote host" + parser = argparse.ArgumentParser(description=desc) + parser.add_argument("ip", help="Host's IP address") + parser.add_argument("reponame", help="Repository name clone into and work in") + parser.add_argument("branch", help="Zen branch to operate on") + parser.add_argument("action", help="The action to do") + args = parser.parse_args(args) + + # Homeward bound and out + zen_dir = Path().home() / "zen" + os.chdir(zen_dir) + + # Mutual exclusion + """ + lock_path = zen_dir / "../.remote_lock" + try: lock_file = open(lock_path, "xb") + except: raise RuntimeError("Failed to lock", lock_path) + """ + + # Check for a clone, create it, chdir to it + _header("REMOTE:", f"Clone/pull from {args.ip}") + clone_dir = zen_dir / args.reponame + if not clone_dir.is_dir(): + _run_checked("git", "clone", f"git://{args.ip}:4493/", clone_dir) + os.chdir(clone_dir) + + _run_checked("git", "checkout", args.branch) + _run_checked("git", "pull", "-r") + + _header("REMOTE:", f"Performing action '{args.action}'") + + # Find xmake + xmake_bin = max(x for x in (zen_dir / "xmake").glob("*")) + xmake_bin /= "usr/local/bin/xmake" + + # Run xmake + xmake_env = {} + xmake_env["VCPKG_ROOT"] = zen_dir / "vcpkg" + if sys.platform == "linux": + xmake_env["CXX"] = "g++-11" + print("xmake environment:") + for key, value in xmake_env.items(): + print(" ", key, "=", value) + xmake_env.update(os.environ) + + def run_xmake(*args): + print("starting xmake...", end="\r") + _run_checked(xmake_bin, args[0], "--yes", *args[1:], env=xmake_env) + + if args.action.startswith("build"): + mode = "debug" if args.action == "build.debug" else "release" + run_xmake("config", "--mode=" + mode) + run_xmake("build") + + elif args.action == "bundle": + run_xmake("bundle") + + elif args.action == "test": + run_xmake("config", "--mode=debug") + run_xmake("test") + + elif args.action == "clean": + _run_checked("git", "reset") + _run_checked("git", "checkout", ".") + _run_checked("git", "clean", "-xdf") + + + +# {{{1 entry ------------------------------------------------------------------- +if __name__ == "__main__": + if "!remote" in sys.argv[1:2]: + ret = _remote(sys.argv[2:]) + raise SystemExit(ret) + + try: + ret = _local(sys.argv[1:]) + raise SystemExit(ret) + except: + raise + finally: + # Roundabout way to avoid orphaned git-daemon processes + if os.name == "nt": + os.system("taskkill /f /im git-daemon.exe") + +# vim: expandtab foldlevel=1 foldmethod=marker diff --git a/thirdparty/trace/trace.h b/thirdparty/trace/trace.h index caa862ffe..d7fbbb71f 100644 --- a/thirdparty/trace/trace.h +++ b/thirdparty/trace/trace.h @@ -263,7 +263,7 @@ class FChannel; EventProps_Meta const EventProps_Private = {}; \
typedef std::conditional<bIsImportant, UE::Trace::Private::FImportantLogScope, UE::Trace::Private::FLogScope>::type LogScopeType; \
explicit operator bool () const { return true; } \
- enum { EventFlags = PartialEventFlags|(EventProps_Meta::NumAuxFields ? UE::Trace::Private::FEventInfo::Flag_MaybeHasAux : 0), }; \
+ enum { EventFlags = PartialEventFlags|((EventProps_Meta::NumAuxFields != 0) ? UE::Trace::Private::FEventInfo::Flag_MaybeHasAux : 0), }; \
static_assert( \
!bIsImportant || (uint32(EventFlags) & uint32(UE::Trace::Private::FEventInfo::Flag_NoSync)), \
"Trace events flagged as Important events must be marked NoSync" \
@@ -130,7 +130,7 @@ option_end() add_define_by_config("ZEN_ENABLE_MESH", "zenmesh") option("zentrace") - set_default(false) + set_default(true) set_showmenu(true) set_description("Enable UE's Trace support") option_end() diff --git a/zencore/crypto.cpp b/zencore/crypto.cpp new file mode 100644 index 000000000..880d7b495 --- /dev/null +++ b/zencore/crypto.cpp @@ -0,0 +1,235 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/crypto.h> +#include <zencore/intmath.h> +#include <zencore/testing.h> + +#include <openssl/conf.h> +#include <openssl/err.h> +#include <openssl/evp.h> + +#include <string> +#include <string_view> + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "crypt32.lib") +# pragma comment(lib, "ws2_32.lib") +#endif + +namespace zen { + +class NullCipher final : public SymmetricCipher +{ +public: + NullCipher() = default; + virtual ~NullCipher() = default; + + virtual bool Initialize(MemoryView, MemoryView) override final { return true; } + + virtual CipherSettings Settings() override final { return {}; } + + virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView) override final { return Data; } + + virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView) override final { return Data; } +}; + +std::unique_ptr<SymmetricCipher> +MakeNullCipher() +{ + return std::make_unique<NullCipher>(); +} + +#if ZEN_PLATFORM_WINDOWS +class Aes final : public SymmetricCipher +{ +public: + Aes(const EVP_CIPHER* Cipher = EVP_aes_256_cbc()) : m_Cipher(Cipher) + { + ZEN_ASSERT(Cipher); + m_KeySize = static_cast<size_t>(EVP_CIPHER_key_length(m_Cipher)); + m_InitVectorSize = static_cast<size_t>(EVP_CIPHER_iv_length(m_Cipher)); + m_BlockSize = static_cast<size_t>(EVP_CIPHER_block_size(m_Cipher)); + } + + virtual ~Aes() + { + if (m_EncryptionCtx) + { + EVP_CIPHER_CTX_free(m_EncryptionCtx); + } + + if (m_DecryptionCtx) + { + EVP_CIPHER_CTX_free(m_DecryptionCtx); + } + } + + virtual bool Initialize(MemoryView Key, MemoryView InitVector) override final + { + ZEN_ASSERT(m_EncryptionCtx == nullptr && m_DecryptionCtx == nullptr); + ZEN_ASSERT(Key.GetSize() == m_KeySize); + ZEN_ASSERT(InitVector.GetSize() == m_InitVectorSize); + + m_EncryptionCtx = EVP_CIPHER_CTX_new(); + m_DecryptionCtx = EVP_CIPHER_CTX_new(); + + if (int ErrorCode = EVP_EncryptInit_ex(m_EncryptionCtx, + m_Cipher, + nullptr, + reinterpret_cast<const unsigned char*>(Key.GetData()), + reinterpret_cast<const unsigned char*>(InitVector.GetData())); + ErrorCode != 1) + { + return false; + } + + if (int ErrorCode = EVP_DecryptInit_ex(m_DecryptionCtx, + m_Cipher, + nullptr, + reinterpret_cast<const unsigned char*>(Key.GetData()), + reinterpret_cast<const unsigned char*>(InitVector.GetData())); + ErrorCode != 1) + { + return false; + } + + return true; + } + + virtual CipherSettings Settings() override final + { + return {.KeySize = m_KeySize, .InitVectorSize = m_InitVectorSize, .BlockSize = m_BlockSize}; + } + + virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) + { + ZEN_ASSERT(m_EncryptionCtx); + + const uint64_t InputSize = Data.GetSize(); + const uint64_t NeededSize = RoundUp(InputSize, m_BlockSize); + + if (NeededSize > EncryptionBuffer.GetSize()) + { + return MemoryView(); + } + + int TotalSize = 0; + int EncryptedSize = 0; + int ErrorCode = EVP_EncryptUpdate(m_EncryptionCtx, + reinterpret_cast<unsigned char*>(EncryptionBuffer.GetData()), + &EncryptedSize, + reinterpret_cast<const unsigned char*>(Data.GetData()), + static_cast<int>(Data.GetSize())); + + if (ErrorCode != 1) + { + return MemoryView(); + } + + TotalSize = EncryptedSize; + MutableMemoryView Remaining = EncryptionBuffer.RightChop(uint64_t(EncryptedSize)); + + ErrorCode = EVP_EncryptFinal_ex(m_EncryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedSize); + + if (ErrorCode != 1) + { + return MemoryView(); + } + + TotalSize += EncryptedSize; + + return EncryptionBuffer.Left(uint64_t(TotalSize)); + } + + virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) override final + { + ZEN_ASSERT(m_DecryptionCtx); + + int TotalSize = 0; + int DecryptedSize = 0; + int ErrorCode = EVP_DecryptUpdate(m_DecryptionCtx, + reinterpret_cast<unsigned char*>(DecryptionBuffer.GetData()), + &DecryptedSize, + reinterpret_cast<const unsigned char*>(Data.GetData()), + static_cast<int>(Data.GetSize())); + + if (ErrorCode != 1) + { + return MemoryView(); + } + + TotalSize = DecryptedSize; + MutableMemoryView Remaining = DecryptionBuffer.RightChop(uint64_t(DecryptedSize)); + + ErrorCode = EVP_DecryptFinal_ex(m_DecryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &DecryptedSize); + + TotalSize += DecryptedSize; + + return DecryptionBuffer.Left(uint64_t(TotalSize)); + } + +private: + const EVP_CIPHER* m_Cipher = nullptr; + EVP_CIPHER_CTX* m_EncryptionCtx = nullptr; + EVP_CIPHER_CTX* m_DecryptionCtx = nullptr; + size_t m_BlockSize = 0; + size_t m_KeySize = 0; + size_t m_InitVectorSize = 0; +}; + +std::unique_ptr<SymmetricCipher> +MakeAesCipher() +{ + return std::make_unique<Aes>(); +} + +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_WITH_TESTS + +using namespace std::literals; + +void +crypto_forcelink() +{ +} + +TEST_CASE("crypto.aes") +{ + SUBCASE("basic") + { +# if ZEN_PLATFORM_WINDOWS + auto Cipher = std::make_unique<Aes>(); + + std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; + + std::vector<uint8_t> Key = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + std::vector<uint8_t> Seed = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + std::vector<uint8_t> EncryptionBuffer; + std::vector<uint8_t> DecryptionBuffer; + + bool Ok = Cipher->Initialize(MakeMemoryView(Key), MakeMemoryView(Seed)); + CHECK(Ok); + + EncryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); + DecryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize); + + MemoryView EncryptedView = Cipher->Encrypt(MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer)); + CHECK(EncryptedView.IsEmpty() == false); + + MemoryView DecryptedView = Cipher->Decrypt(EncryptedView, MakeMutableMemoryView(DecryptionBuffer)); + CHECK(DecryptedView.IsEmpty() == false); + + std::string_view EncryptedDecryptedText = + std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize()); + + CHECK(EncryptedDecryptedText == PlainText); + } +# endif +} + +#endif + +} // namespace zen diff --git a/zencore/include/zencore/crypto.h b/zencore/include/zencore/crypto.h new file mode 100644 index 000000000..4d6ddba47 --- /dev/null +++ b/zencore/include/zencore/crypto.h @@ -0,0 +1,50 @@ + +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/memory.h> +#include <zencore/zencore.h> + +#include <memory> + +namespace zen { + +/** + * Experimental interface for a symmetric encryption/decryption algorithm. + * Currenlty only AES 256 bit CBC is supported using OpenSSL. + */ +class SymmetricCipher +{ +public: + virtual ~SymmetricCipher() = default; + + virtual bool Initialize(MemoryView Key, MemoryView InitVector) = 0; + + struct CipherSettings + { + size_t KeySize = 0; + size_t InitVectorSize = 0; + size_t BlockSize = 0; + }; + + virtual CipherSettings Settings() = 0; + + virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) = 0; + + virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) = 0; +}; + +std::unique_ptr<SymmetricCipher> MakeNullCipher(); + +#if ZEN_PLATFORM_WINDOWS +/** + * Create a new instance of a 256 bit AES CBC symmetric cipher. + * NOTE: Currenlty only tested on Windows + */ +std::unique_ptr<SymmetricCipher> MakeAesCipher(); +#endif + +void crypto_forcelink(); + +} // namespace zen diff --git a/zencore/include/zencore/trace.h b/zencore/include/zencore/trace.h index f28fdeeaf..0af490f23 100644 --- a/zencore/include/zencore/trace.h +++ b/zencore/include/zencore/trace.h @@ -22,6 +22,7 @@ enum class TraceType { File, Network, + None }; void TraceInit(const char* HostOrPath, TraceType Type); diff --git a/zencore/trace.cpp b/zencore/trace.cpp index 6a35571e6..788dcec07 100644 --- a/zencore/trace.cpp +++ b/zencore/trace.cpp @@ -12,6 +12,8 @@ void TraceInit(const char* HostOrPath, TraceType Type) { + bool EnableEvents = true; + switch (Type) { case TraceType::Network: @@ -21,6 +23,10 @@ TraceInit(const char* HostOrPath, TraceType Type) case TraceType::File: trace::WriteTo(HostOrPath); break; + + case TraceType::None: + EnableEvents = false; + break; } trace::FInitializeDesc Desc = { @@ -28,7 +34,10 @@ TraceInit(const char* HostOrPath, TraceType Type) }; trace::Initialize(Desc); - trace::ToggleChannel("cpu", true); + if (EnableEvents) + { + trace::ToggleChannel("cpu", true); + } } #endif // ZEN_WITH_TRACE diff --git a/zencore/zencore.cpp b/zencore/zencore.cpp index 19acdd1f5..8b45d273d 100644 --- a/zencore/zencore.cpp +++ b/zencore/zencore.cpp @@ -16,6 +16,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compositebuffer.h> #include <zencore/compress.h> +#include <zencore/crypto.h> #include <zencore/filesystem.h> #include <zencore/intmath.h> #include <zencore/iobuffer.h> @@ -117,6 +118,7 @@ zencore_forcelinktests() zen::uson_forcelink(); zen::usonbuilder_forcelink(); zen::usonpackage_forcelink(); + zen::crypto_forcelink(); } #endif diff --git a/zenserver-test/cachepolicy-tests.cpp b/zenserver-test/cachepolicy-tests.cpp index 686ff818c..d3135439c 100644 --- a/zenserver-test/cachepolicy-tests.cpp +++ b/zenserver-test/cachepolicy-tests.cpp @@ -23,8 +23,7 @@ TEST_CASE("cachepolicy") CachePolicy::QueryLocal, CachePolicy::StoreRemote, CachePolicy::SkipData, - CachePolicy::KeepAlive, - CachePolicy::Disable}; + CachePolicy::KeepAlive}; for (CachePolicy Atomic : SomeAtomics) { CHECK(ParseCachePolicy(WriteToString<128>(Atomic)) == Atomic); @@ -73,7 +72,8 @@ TEST_CASE("cacherecordpolicy") { SUBCASE("policy with no values") { - CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal; + CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal | CachePolicy::PartialRecord; + CachePolicy ValuePolicy = Policy & CacheValuePolicy::PolicyMask; CacheRecordPolicy RecordPolicy; CacheRecordPolicyBuilder Builder(Policy); RecordPolicy = Builder.Build(); @@ -81,8 +81,8 @@ TEST_CASE("cacherecordpolicy") { CHECK(RecordPolicy.IsUniform()); CHECK(RecordPolicy.GetRecordPolicy() == Policy); - CHECK(RecordPolicy.GetDefaultValuePolicy() == Policy); - CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == Policy); + CHECK(RecordPolicy.GetBasePolicy() == Policy); + CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == ValuePolicy); CHECK(RecordPolicy.GetValuePolicies().size() == 0); } SUBCASE("saveload") @@ -90,21 +90,22 @@ TEST_CASE("cacherecordpolicy") CbWriter Writer; RecordPolicy.Save(Writer); CbObject Saved = Writer.Save()->AsObject(); - CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved); + CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); CHECK(Loaded.IsUniform()); CHECK(Loaded.GetRecordPolicy() == Policy); - CHECK(Loaded.GetDefaultValuePolicy() == Policy); - CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == Policy); + CHECK(Loaded.GetBasePolicy() == Policy); + CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == ValuePolicy); CHECK(Loaded.GetValuePolicies().size() == 0); } } SUBCASE("policy with values") { - CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal; + CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal | CachePolicy::PartialRecord; + CachePolicy DefaultValuePolicy = DefaultPolicy & CacheValuePolicy::PolicyMask; CachePolicy PartialOverlap = CachePolicy::StoreRemote; CachePolicy NoOverlap = CachePolicy::QueryRemote; - CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap; + CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap | CachePolicy::PartialRecord; CacheRecordPolicy RecordPolicy; CacheRecordPolicyBuilder Builder(DefaultPolicy); @@ -118,10 +119,10 @@ TEST_CASE("cacherecordpolicy") { CHECK(!RecordPolicy.IsUniform()); CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); - CHECK(RecordPolicy.GetDefaultValuePolicy() == DefaultPolicy); + CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); - CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultPolicy); + CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); CHECK(RecordPolicy.GetValuePolicies().size() == 2); } SUBCASE("saveload") @@ -129,33 +130,21 @@ TEST_CASE("cacherecordpolicy") CbWriter Writer; RecordPolicy.Save(Writer); CbObject Saved = Writer.Save()->AsObject(); - CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved); + CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); CHECK(!RecordPolicy.IsUniform()); CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); - CHECK(RecordPolicy.GetDefaultValuePolicy() == DefaultPolicy); + CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); - CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultPolicy); + CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); CHECK(RecordPolicy.GetValuePolicies().size() == 2); } } SUBCASE("parsing invalid text") { - CacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject()); - CHECK(Loaded.IsUniform()); - CHECK(Loaded.GetRecordPolicy() == CachePolicy::Default); - CHECK(Loaded.GetDefaultValuePolicy() == CachePolicy::Default); - CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == CachePolicy::Default); - CHECK(Loaded.GetValuePolicies().size() == 0); - - CachePolicy Policy = CachePolicy::SkipData; - Loaded = CacheRecordPolicy::Load(CbObject(), Policy); - CHECK(Loaded.IsUniform()); - CHECK(Loaded.GetRecordPolicy() == Policy); - CHECK(Loaded.GetDefaultValuePolicy() == Policy); - CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == Policy); - CHECK(Loaded.GetValuePolicies().size() == 0); + OptionalCacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject()); + CHECK(Loaded.IsNull()); } } diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp index f51fd1504..425f43946 100644 --- a/zenserver-test/zenserver-test.cpp +++ b/zenserver-test/zenserver-test.cpp @@ -1814,7 +1814,12 @@ TEST_CASE("zcache.rpc") { using namespace std::literals; - auto CreateCacheRecord = [](const zen::CacheKey& CacheKey, size_t PayloadSize) -> zen::CbPackage { + auto AppendCacheRecord = [](CbPackage& Package, + CbWriter& Writer, + const zen::CacheKey& CacheKey, + size_t PayloadSize, + CachePolicy /* BatchDefaultPolicy */, + CachePolicy RecordPolicy) { std::vector<uint8_t> Data; Data.resize(PayloadSize); for (size_t Idx = 0; Idx < PayloadSize; ++Idx) @@ -1822,19 +1827,37 @@ TEST_CASE("zcache.rpc") Data[Idx] = Idx % 255; } - zen::CbAttachment Attachment(zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()))); + CompressedBuffer Value = zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size())); + zen::CbAttachment Attachment(Value); - zen::CbObjectWriter CacheRecord; - CacheRecord.BeginObject("CacheKey"sv); - CacheRecord << "Bucket"sv << CacheKey.Bucket << "Hash"sv << CacheKey.Hash; - CacheRecord.EndObject(); - CacheRecord << "Data"sv << Attachment; + Writer.BeginObject(); + { + Writer.BeginObject("Record"sv); + { + Writer.BeginObject("Key"sv); + { + Writer << "Bucket"sv << CacheKey.Bucket << "Hash"sv << CacheKey.Hash; + } + Writer.EndObject(); + Writer.BeginArray("Values"sv); + { + Writer.BeginObject(); + { + Writer.AddObjectId("Id"sv, Oid::NewOid()); + Writer.AddBinaryAttachment("RawHash"sv, IoHash::FromBLAKE3(Value.GetRawHash())); + Writer.AddInteger("RawSize"sv, Value.GetRawSize()); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + Writer.SetName("Policy"sv); + CacheRecordPolicy(RecordPolicy).Save(Writer); + } + Writer.EndObject(); - zen::CbPackage Package; - Package.SetObject(CacheRecord.Save()); Package.AddAttachment(Attachment); - - return Package; }; auto ToIoBuffer = [](zen::CbPackage Package) -> zen::IoBuffer { @@ -1843,27 +1866,46 @@ TEST_CASE("zcache.rpc") return zen::IoBuffer(zen::IoBuffer::Clone, MemStream.Data(), MemStream.Size()); }; - auto PutCacheRecords = [&CreateCacheRecord, &ToIoBuffer](std::string_view BaseUri, - std::string_view Query, - std::string_view Bucket, - size_t Num, - size_t PayloadSize = 1024) -> std::vector<CacheKey> { + auto PutCacheRecords = + [&AppendCacheRecord, + &ToIoBuffer](std::string_view BaseUri, std::string_view Bucket, size_t Num, size_t PayloadSize = 1024) -> std::vector<CacheKey> { std::vector<zen::CacheKey> OutKeys; for (uint32_t Key = 1; Key <= Num; ++Key) { - const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, zen::IoHash::HashBuffer(&Key, sizeof(uint32_t))); - CbPackage CacheRecord = CreateCacheRecord(CacheKey, PayloadSize); + zen::IoHash KeyHash; + ((uint32_t*)(KeyHash.Hash))[0] = Key; + const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, KeyHash); + CbPackage Package; + CbWriter Writer; - OutKeys.push_back(CacheKey); + Writer.BeginObject(); + { + Writer << "Method"sv + << "PutCacheRecords"sv; + Writer.BeginObject("Params"sv); + { + CachePolicy BatchDefaultPolicy = CachePolicy::Default; + Writer << "DefaultPolicy"sv << WriteToString<128>(BatchDefaultPolicy); + Writer.BeginArray("Requests"sv); + { + AppendCacheRecord(Package, Writer, CacheKey, PayloadSize, BatchDefaultPolicy, CachePolicy::Default); + } + Writer.EndArray(); + } + Writer.EndObject(); + } + Writer.EndObject(); + Package.SetObject(Writer.Save().AsObject()); - IoBuffer Payload = ToIoBuffer(CacheRecord); + OutKeys.push_back(CacheKey); - cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}{}", BaseUri, CacheKey.Bucket, CacheKey.Hash, Query)}, - cpr::Body{(const char*)Payload.Data(), Payload.Size()}, - cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); - CHECK(Result.status_code == 201); + CHECK(Result.status_code == 200); } return OutKeys; @@ -1876,27 +1918,30 @@ TEST_CASE("zcache.rpc") bool Success; }; - auto GetCacheRecords = - [](std::string_view BaseUri, std::span<zen::CacheKey> Keys, const zen::CacheRecordPolicy& Policy) -> GetCacheRecordResult { + auto GetCacheRecords = [](std::string_view BaseUri, std::span<zen::CacheKey> Keys, zen::CachePolicy Policy) -> GetCacheRecordResult { using namespace zen; CbObjectWriter Request; Request << "Method"sv << "GetCacheRecords"sv; Request.BeginObject("Params"sv); - - Request.BeginArray("CacheKeys"sv); - for (const CacheKey& Key : Keys) { - Request.BeginObject(); - Request << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash; - Request.EndObject(); + Request << "DefaultPolicy"sv << WriteToString<128>(Policy); + Request.BeginArray("Requests"sv); + for (const CacheKey& Key : Keys) + { + Request.BeginObject(); + { + Request.BeginObject("Key"sv); + { + Request << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash; + } + Request.EndObject(); + } + Request.EndObject(); + } + Request.EndArray(); } - Request.EndArray(); - - Request.SetName("Policy"sv); - Policy.Save(Request); - Request.EndObject(); BinaryWriter Body; @@ -1947,8 +1992,8 @@ TEST_CASE("zcache.rpc") Inst.SpawnServer(PortNumber); Inst.WaitUntilReady(); - CacheRecordPolicy Policy; - std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, ""sv, "mastodon"sv, 128); + CachePolicy Policy = CachePolicy::Default; + std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, "mastodon"sv, 128); GetCacheRecordResult Result = GetCacheRecords(BaseUri, Keys, Policy); CHECK(Result.Records.size() == Keys.size()); @@ -1957,11 +2002,18 @@ TEST_CASE("zcache.rpc") { const CacheKey& ExpectedKey = Keys[Index++]; - CbObjectView RecordObj = RecordView.AsObjectView(); - CbObjectView KeyObj = RecordObj["CacheKey"sv].AsObjectView(); - const CacheKey Key = CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash()); - const IoHash AttachmentHash = RecordObj["Data"sv].AsHash(); - const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash); + CbObjectView RecordObj = RecordView.AsObjectView(); + CbObjectView KeyObj = RecordObj["Key"sv].AsObjectView(); + const CacheKey Key = CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash()); + IoHash AttachmentHash; + size_t NumValues = 0; + for (CbFieldView Value : RecordObj["Values"sv]) + { + AttachmentHash = Value.AsObjectView()["RawHash"sv].AsHash(); + ++NumValues; + } + CHECK(NumValues == 1); + const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash); CHECK(Key == ExpectedKey); CHECK(Attachment != nullptr); @@ -1979,8 +2031,8 @@ TEST_CASE("zcache.rpc") Inst.SpawnServer(PortNumber); Inst.WaitUntilReady(); - CacheRecordPolicy Policy; - std::vector<zen::CacheKey> ExistingKeys = PutCacheRecords(BaseUri, ""sv, "mastodon"sv, 128); + CachePolicy Policy = CachePolicy::Default; + std::vector<zen::CacheKey> ExistingKeys = PutCacheRecords(BaseUri, "mastodon"sv, 128); std::vector<zen::CacheKey> Keys; for (const zen::CacheKey& Key : ExistingKeys) @@ -2004,12 +2056,18 @@ TEST_CASE("zcache.rpc") } else { - const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++]; - CbObjectView RecordObj = RecordView.AsObjectView(); - zen::CacheKey Key = LoadKey(RecordObj["CacheKey"sv]); - const IoHash AttachmentHash = RecordObj["Data"sv].AsHash(); - const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash); - + const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++]; + CbObjectView RecordObj = RecordView.AsObjectView(); + zen::CacheKey Key = LoadKey(RecordObj["Key"sv]); + IoHash AttachmentHash; + size_t NumValues = 0; + for (CbFieldView Value : RecordObj["Values"sv]) + { + AttachmentHash = Value.AsObjectView()["RawHash"sv].AsHash(); + ++NumValues; + } + CHECK(NumValues == 1); + const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash); CHECK(Key == ExpectedKey); CHECK(Attachment != nullptr); } @@ -2028,9 +2086,9 @@ TEST_CASE("zcache.rpc") SpawnServer(UpstreamServer, UpstreamCfg); SpawnServer(LocalServer, LocalCfg); - std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, ""sv, "mastodon"sv, 4); + std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "mastodon"sv, 4); - CacheRecordPolicy Policy(CachePolicy::QueryLocal); + CachePolicy Policy = CachePolicy::QueryLocal; GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, Keys, Policy); CHECK(Result.Records.size() == Keys.size()); @@ -2053,9 +2111,9 @@ TEST_CASE("zcache.rpc") SpawnServer(UpstreamServer, UpstreamCfg); SpawnServer(LocalServer, LocalCfg); - std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, ""sv, "mastodon"sv, 4); + std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "mastodon"sv, 4); - CacheRecordPolicy Policy(CachePolicy::QueryLocal | CachePolicy::QueryRemote); + CachePolicy Policy = (CachePolicy::QueryLocal | CachePolicy::QueryRemote); GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, Keys, Policy); CHECK(Result.Records.size() == Keys.size()); @@ -2064,7 +2122,7 @@ TEST_CASE("zcache.rpc") { const zen::CacheKey& ExpectedKey = Keys[Index++]; CbObjectView RecordObj = RecordView.AsObjectView(); - zen::CacheKey Key = LoadKey(RecordObj["CacheKey"sv]); + zen::CacheKey Key = LoadKey(RecordObj["Key"sv]); CHECK(Key == ExpectedKey); } } diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp index 0f385116b..49e5896d1 100644 --- a/zenserver/cache/structuredcache.cpp +++ b/zenserver/cache/structuredcache.cpp @@ -15,6 +15,7 @@ #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/httpserver.h> +#include <zenhttp/httpshared.h> #include <zenstore/cas.h> #include <zenutil/cache/cache.h> @@ -47,6 +48,13 @@ ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams) return !PolicyText.empty() ? zen::ParseCachePolicy(PolicyText) : CachePolicy::Default; } +CacheRecordPolicy +LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default) +{ + OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object); + return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy); +} + struct AttachmentCount { uint32_t New = 0; @@ -55,6 +63,13 @@ struct AttachmentCount uint32_t Total = 0; }; +struct PutRequestData +{ + CacheKey Key; + CbObjectView RecordObject; + CacheRecordPolicy Policy; +}; + ////////////////////////////////////////////////////////////////////////// HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCacheStore, @@ -134,13 +149,13 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) CachePolicy PolicyFromURL = ParseCachePolicy(Request.GetQueryParams()); - if (Ref.PayloadId == IoHash::Zero) + if (Ref.ValueContentId == IoHash::Zero) { return HandleCacheRecordRequest(Request, Ref, PolicyFromURL); } else { - return HandleCachePayloadRequest(Request, Ref, PolicyFromURL); + return HandleCacheValueRequest(Request, Ref, PolicyFromURL); } return; @@ -452,14 +467,14 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request Body.SetContentType(ContentType); - if (ContentType == HttpContentType::kBinary) + if (ContentType == HttpContentType::kBinary || ContentType == HttpContentType::kCompressedBinary) { ZEN_DEBUG("PUT - '{}/{}' {} '{}'", Ref.BucketSegment, Ref.HashKey, NiceBytes(Body.Size()), ToString(ContentType)); m_CacheStore.Put(Ref.BucketSegment, Ref.HashKey, {.Value = Body}); if (EnumHasAllFlags(PolicyFromURL, CachePolicy::StoreRemote)) { - m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kBinary, .Key = {Ref.BucketSegment, Ref.HashKey}}); + m_UpstreamCache.EnqueueUpstream({.Type = ContentType, .Key = {Ref.BucketSegment, Ref.HashKey}}); } Request.WriteResponse(HttpResponseCode::Created); @@ -503,8 +518,9 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord) { - m_UpstreamCache.EnqueueUpstream( - {.Type = ZenContentType::kCbObject, .Key = {Ref.BucketSegment, Ref.HashKey}, .PayloadIds = std::move(ValidAttachments)}); + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbObject, + .Key = {Ref.BucketSegment, Ref.HashKey}, + .ValueContentIds = std::move(ValidAttachments)}); } Request.WriteResponse(HttpResponseCode::Created); @@ -585,8 +601,9 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord) { - m_UpstreamCache.EnqueueUpstream( - {.Type = ZenContentType::kCbPackage, .Key = {Ref.BucketSegment, Ref.HashKey}, .PayloadIds = std::move(ValidAttachments)}); + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage, + .Key = {Ref.BucketSegment, Ref.HashKey}, + .ValueContentIds = std::move(ValidAttachments)}); } Request.WriteResponse(HttpResponseCode::Created); @@ -598,16 +615,16 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request } void -HttpStructuredCacheService::HandleCachePayloadRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL) +HttpStructuredCacheService::HandleCacheValueRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL) { switch (Request.RequestVerb()) { case HttpVerb::kHead: case HttpVerb::kGet: - HandleGetCachePayload(Request, Ref, PolicyFromURL); + HandleGetCacheValue(Request, Ref, PolicyFromURL); break; case HttpVerb::kPut: - HandlePutCachePayload(Request, Ref, PolicyFromURL); + HandlePutCacheValue(Request, Ref, PolicyFromURL); break; default: break; @@ -615,16 +632,17 @@ HttpStructuredCacheService::HandleCachePayloadRequest(HttpServerRequest& Request } void -HttpStructuredCacheService::HandleGetCachePayload(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL) +HttpStructuredCacheService::HandleGetCacheValue(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL) { - IoBuffer Payload = m_CidStore.FindChunkByCid(Ref.PayloadId); + IoBuffer Value = m_CidStore.FindChunkByCid(Ref.ValueContentId); bool InUpstreamCache = false; CachePolicy Policy = PolicyFromURL; - const bool QueryUpstream = !Payload && EnumHasAllFlags(Policy, CachePolicy::QueryRemote); + const bool QueryUpstream = !Value && EnumHasAllFlags(Policy, CachePolicy::QueryRemote); if (QueryUpstream) { - if (auto UpstreamResult = m_UpstreamCache.GetCachePayload({Ref.BucketSegment, Ref.HashKey}, Ref.PayloadId); UpstreamResult.Success) + if (auto UpstreamResult = m_UpstreamCache.GetCacheValue({Ref.BucketSegment, Ref.HashKey}, Ref.ValueContentId); + UpstreamResult.Success) { if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value))) { @@ -638,9 +656,9 @@ HttpStructuredCacheService::HandleGetCachePayload(zen::HttpServerRequest& Reques } } - if (!Payload) + if (!Value) { - ZEN_DEBUG("MISS - '{}/{}/{}' '{}'", Ref.BucketSegment, Ref.HashKey, Ref.PayloadId, ToString(Request.AcceptContentType())); + ZEN_DEBUG("MISS - '{}/{}/{}' '{}'", Ref.BucketSegment, Ref.HashKey, Ref.ValueContentId, ToString(Request.AcceptContentType())); m_CacheStats.MissCount++; return Request.WriteResponse(HttpResponseCode::NotFound); } @@ -648,9 +666,9 @@ HttpStructuredCacheService::HandleGetCachePayload(zen::HttpServerRequest& Reques ZEN_DEBUG("HIT - '{}/{}/{}' {} '{}' ({})", Ref.BucketSegment, Ref.HashKey, - Ref.PayloadId, - NiceBytes(Payload.Size()), - ToString(Payload.GetContentType()), + Ref.ValueContentId, + NiceBytes(Value.Size()), + ToString(Value.GetContentType()), InUpstreamCache ? "UPSTREAM" : "LOCAL"); m_CacheStats.HitCount++; @@ -665,12 +683,12 @@ HttpStructuredCacheService::HandleGetCachePayload(zen::HttpServerRequest& Reques } else { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Payload); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); } } void -HttpStructuredCacheService::HandlePutCachePayload(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL) +HttpStructuredCacheService::HandlePutCacheValue(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL) { // Note: Individual cacherecord values are not propagated upstream until a valid cache record has been stored ZEN_UNUSED(PolicyFromURL); @@ -691,9 +709,11 @@ HttpStructuredCacheService::HandlePutCachePayload(zen::HttpServerRequest& Reques return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Attachments must be compressed"sv); } - if (IoHash::FromBLAKE3(Compressed.GetRawHash()) != Ref.PayloadId) + if (IoHash::FromBLAKE3(Compressed.GetRawHash()) != Ref.ValueContentId) { - return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "ValueId does not match attachment hash"sv); + return Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "ValueContentId does not match attachment hash"sv); } CidStore::InsertResult Result = m_CidStore.AddChunk(Compressed); @@ -701,7 +721,7 @@ HttpStructuredCacheService::HandlePutCachePayload(zen::HttpServerRequest& Reques ZEN_DEBUG("PUT - '{}/{}/{}' {} '{}' ({})", Ref.BucketSegment, Ref.HashKey, - Ref.PayloadId, + Ref.ValueContentId, NiceBytes(Body.Size()), ToString(Body.GetContentType()), Result.New ? "NEW" : "OLD"); @@ -730,13 +750,13 @@ HttpStructuredCacheService::ValidateKeyUri(HttpServerRequest& Request, CacheRef& } std::string_view HashSegment; - std::string_view PayloadSegment; + std::string_view ValueSegment; - std::string_view::size_type PayloadSplitOffset = Key.find_last_of('/'); + std::string_view::size_type ValueSplitOffset = Key.find_last_of('/'); // We know there is a slash so no need to check for npos return - if (PayloadSplitOffset == BucketSplitOffset) + if (ValueSplitOffset == BucketSplitOffset) { // Basic cache record lookup HashSegment = Key.substr(BucketSplitOffset + 1); @@ -744,8 +764,8 @@ HttpStructuredCacheService::ValidateKeyUri(HttpServerRequest& Request, CacheRef& else { // Cache record + valueid lookup - HashSegment = Key.substr(BucketSplitOffset + 1, PayloadSplitOffset - BucketSplitOffset - 1); - PayloadSegment = Key.substr(PayloadSplitOffset + 1); + HashSegment = Key.substr(BucketSplitOffset + 1, ValueSplitOffset - BucketSplitOffset - 1); + ValueSegment = Key.substr(ValueSplitOffset + 1); } if (HashSegment.size() != IoHash::StringLength) @@ -753,9 +773,9 @@ HttpStructuredCacheService::ValidateKeyUri(HttpServerRequest& Request, CacheRef& return false; } - if (!PayloadSegment.empty() && PayloadSegment.size() == IoHash::StringLength) + if (!ValueSegment.empty() && ValueSegment.size() == IoHash::StringLength) { - const bool IsOk = ParseHexBytes(PayloadSegment.data(), PayloadSegment.size(), OutRef.PayloadId.Hash); + const bool IsOk = ParseHexBytes(ValueSegment.data(), ValueSegment.size(), OutRef.ValueContentId.Hash); if (!IsOk) { @@ -764,7 +784,7 @@ HttpStructuredCacheService::ValidateKeyUri(HttpServerRequest& Request, CacheRef& } else { - OutRef.PayloadId = IoHash::Zero; + OutRef.ValueContentId = IoHash::Zero; } const bool IsOk = ParseHexBytes(HashSegment.data(), HashSegment.size(), OutRef.HashKey.Hash); @@ -787,27 +807,52 @@ HttpStructuredCacheService::HandleRpcRequest(zen::HttpServerRequest& Request) const HttpContentType ContentType = Request.RequestContentType(); const HttpContentType AcceptType = Request.AcceptContentType(); - if (ContentType != HttpContentType::kCbObject || AcceptType != HttpContentType::kCbPackage) + if ((ContentType != HttpContentType::kCbObject && ContentType != HttpContentType::kCbPackage) || + AcceptType != HttpContentType::kCbPackage) { return Request.WriteResponse(HttpResponseCode::BadRequest); } - Request.WriteResponseAsync( - [this, RpcRequest = zen::LoadCompactBinaryObject(Request.ReadPayload())](HttpServerRequest& AsyncRequest) { - const std::string_view Method = RpcRequest["Method"sv].AsString(); - if (Method == "GetCacheRecords"sv) - { - HandleRpcGetCacheRecords(AsyncRequest, RpcRequest); - } - else if (Method == "GetCacheValues"sv) - { - HandleRpcGetCachePayloads(AsyncRequest, RpcRequest); - } - else - { - AsyncRequest.WriteResponse(HttpResponseCode::BadRequest); - } - }); + Request.WriteResponseAsync([this, Body = Request.ReadPayload(), ContentType](HttpServerRequest& AsyncRequest) mutable { + CbPackage Package; + CbObjectView Object; + CbObject ObjectBuffer; + if (ContentType == HttpContentType::kCbObject) + { + ObjectBuffer = zen::LoadCompactBinaryObject(std::move(Body)); + Object = ObjectBuffer; + } + else + { + Package = ParsePackageMessage(Body); + Object = Package.GetObject(); + } + const std::string_view Method = Object["Method"sv].AsString(); + if (Method == "PutCacheRecords"sv) + { + HandleRpcPutCacheRecords(AsyncRequest, Package); + } + else if (Method == "GetCacheRecords"sv) + { + HandleRpcGetCacheRecords(AsyncRequest, Object); + } + else if (Method == "PutCacheValues"sv) + { + HandleRpcPutCacheValues(AsyncRequest, Package); + } + else if (Method == "GetCacheValues"sv) + { + HandleRpcGetCacheValues(AsyncRequest, Object); + } + else if (Method == "GetCacheChunks"sv) + { + HandleRpcGetCacheChunks(AsyncRequest, Object); + } + else + { + AsyncRequest.WriteResponse(HttpResponseCode::BadRequest); + } + }); } break; default: @@ -817,13 +862,149 @@ HttpStructuredCacheService::HandleRpcRequest(zen::HttpServerRequest& Request) } void -HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView RpcRequest) +HttpStructuredCacheService::HandleRpcPutCacheRecords(zen::HttpServerRequest& Request, const CbPackage& BatchRequest) +{ + ZEN_TRACE_CPU("Z$::RpcPutCacheRecords"); + CbObjectView BatchObject = BatchRequest.GetObject(); + + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + CachePolicy DefaultPolicy; + + ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv); + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::vector<bool> Results; + for (CbFieldView RequestField : Params["Requests"sv]) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView(); + CbObjectView KeyView = RecordObject["Key"sv].AsObjectView(); + CbFieldView BucketField = KeyView["Bucket"sv]; + CbFieldView HashField = KeyView["Hash"sv]; + CacheKey Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash()); + if (BucketField.HasError() || HashField.HasError() || Key.Bucket.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + PutRequestData PutRequest{std::move(Key), RecordObject, std::move(Policy)}; + + PutResult Result = PutCacheRecord(PutRequest, &BatchRequest); + + if (Result == PutResult::Invalid) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + Results.push_back(Result == PutResult::Success); + } + if (Results.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + + BinaryWriter MemStream; + RpcResponse.Save(MemStream); + + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); +} + +HttpStructuredCacheService::PutResult +HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPackage* Package) +{ + std::vector<IoHash> ValidAttachments; + AttachmentCount Count; + CbObjectView Record = Request.RecordObject; + uint64_t RecordObjectSize = Record.GetSize(); + uint64_t TransferredSize = RecordObjectSize; + + Request.RecordObject.IterateAttachments([this, &Request, Package, &ValidAttachments, &Count, &TransferredSize](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr) + { + if (Attachment->IsCompressedBinary()) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk); + + ValidAttachments.emplace_back(InsertResult.DecompressedId); + + if (InsertResult.New) + { + Count.New++; + } + Count.Valid++; + TransferredSize += Chunk.GetCompressedSize(); + } + else + { + ZEN_WARN("PUT - '{}/{}' '{}' FAILED, attachment '{}' is not compressed", + Request.Key.Bucket, + Request.Key.Hash, + ToString(HttpContentType::kCbPackage), + ValueHash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(ValueHash)) + { + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + Count.Total++; + }); + + if (Count.Invalid > 0) + { + return PutResult::Invalid; + } + + ZEN_DEBUG("PUT - '{}/{}' {}, attachments '{}/{}/{}' (new/valid/total)", + Request.Key.Bucket, + Request.Key.Hash, + NiceBytes(TransferredSize), + Count.New, + Count.Valid, + Count.Total); + + ZenCacheValue CacheValue; + CacheValue.Value = IoBuffer(Record.GetSize()); + Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize())); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Request.Key.Bucket, Request.Key.Hash, CacheValue); + + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream( + {.Type = ZenContentType::kCbPackage, .Key = Request.Key, .ValueContentIds = std::move(ValidAttachments)}); + } + return PutResult::Success; +} + +#if BACKWARDS_COMPATABILITY_JAN2022 +void +HttpStructuredCacheService::HandleRpcGetCacheRecordsLegacy(zen::HttpServerRequest& Request, CbObjectView RpcRequest) { ZEN_TRACE_CPU("Z$::RpcGetCacheRecords"); CbPackage RpcResponse; CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); - CacheRecordPolicy BatchPolicy = CacheRecordPolicy::Load(Params["Policy"sv].AsObjectView()); + CacheRecordPolicy BatchPolicy = LoadCacheRecordPolicy(Params["Policy"sv].AsObjectView()); std::vector<CacheKey> CacheKeys; std::vector<IoBuffer> CacheValues; std::vector<size_t> UpstreamRequests; @@ -849,7 +1030,8 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req uint32_t MissingCount = 0; uint32_t MissingReadFromUpstreamCount = 0; - if (EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::QueryLocal) && m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue)) + if (EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::QueryLocal) && m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue) && + CacheValue.Value.GetContentType() == ZenContentType::kCbObject) { CbObjectView CacheRecord(CacheValue.Value.Data()); CacheRecord.IterateAttachments( @@ -895,12 +1077,8 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req }); } - if ((!CacheValue.Value && EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::QueryRemote)) || - MissingReadFromUpstreamCount != 0) - { - UpstreamRequests.push_back(KeyIndex); - } - else if (CacheValue.Value && (MissingCount == 0 || EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))) + // Searching upstream is not implemented in this legacy support function + if (CacheValue.Value && (MissingCount == 0 || EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))) { ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL) {}", Key.Bucket, @@ -929,116 +1107,445 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req ++KeyIndex; } - if (!UpstreamRequests.empty()) + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"sv); + for (const IoBuffer& Value : CacheValues) + { + if (Value) + { + CbObjectView Record(Value.Data()); + ResponseObject << Record; + } + else + { + ResponseObject.AddNull(); + } + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + + BinaryWriter MemStream; + RpcResponse.Save(MemStream); + + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); +} +#endif + +void +HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& HttpRequest, CbObjectView RpcRequest) +{ +#if BACKWARDS_COMPATABILITY_JAN2022 + // Backwards compatability; + if (RpcRequest["Params"sv].AsObjectView()["CacheKeys"sv]) + { + return HandleRpcGetCacheRecordsLegacy(HttpRequest, RpcRequest); + } +#endif + ZEN_TRACE_CPU("Z$::RpcGetCacheRecords"); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv); + + struct ValueRequestData + { + Oid ValueId; + IoHash ContentId; + CompressedBuffer Payload; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool ReadFromUpstream = false; + }; + struct RecordRequestData { - const auto OnCacheRecordGetComplete = [this, &CacheValues, &RpcResponse, &BatchPolicy](CacheRecordGetCompleteParams&& Params) { - ZEN_ASSERT(Params.KeyIndex < CacheValues.size()); + CacheKeyRequest Upstream; + CbObjectView RecordObject; + IoBuffer RecordCacheValue; + CacheRecordPolicy DownstreamPolicy; + std::vector<ValueRequestData> Values; + bool Complete = false; + bool UsedUpstream = false; + }; + + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::vector<RecordRequestData> Requests; + std::vector<size_t> UpstreamIndexes; + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + + auto ParseValues = [](RecordRequestData& Request) { + CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView(); + Request.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Request.Values.push_back({ValueId, RawHash}); + Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId); + } + } + }; - IoBuffer CacheValue; - AttachmentCount Count; + for (CbFieldView RequestField : RequestsArray) + { + RecordRequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CbFieldView BucketField = KeyObject["Bucket"sv]; + CbFieldView HashField = KeyObject["Hash"sv]; + CacheKey& Key = Request.Upstream.Key; + Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash()); + if (HashField.HasError() || Key.Bucket.empty()) + { + return HttpRequest.WriteResponse(HttpResponseCode::BadRequest); + } + Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + const CacheRecordPolicy& Policy = Request.DownstreamPolicy; - if (Params.Record) + ZenCacheValue CacheValue; + bool NeedUpstreamAttachment = false; + bool FoundLocalInvalid = false; + ZenCacheValue RecordCacheValue; + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) && m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue)) + { + Request.RecordCacheValue = std::move(RecordCacheValue.Value); + if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject) + { + FoundLocalInvalid = true; + } + else { - Params.Record.IterateAttachments([this, &RpcResponse, &Params, &Count, &BatchPolicy](CbFieldView HashView) { - CachePolicy ValuePolicy = BatchPolicy.GetRecordPolicy(); - bool FoundInUpstream = false; - if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData()); + ParseValues(Request); + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal)) { - if (const CbAttachment* Attachment = Params.Package.FindAttachment(HashView.AsHash())) + // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we + // didn't ask for it and thus the record is complete in its absence. + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) { - FoundInUpstream = true; - if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) - { - FoundInUpstream = true; - if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) - { - auto InsertResult = m_CidStore.AddChunk(Compressed); - if (InsertResult.New) - { - Count.New++; - } - } - Count.Valid++; - - if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) - { - RpcResponse.AddAttachment(CbAttachment(Compressed)); - } - } - else + Value.Exists = true; + } + else + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + Request.Complete = false; + } + } + else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + if (m_CidStore.ContainsChunk(Value.ContentId)) + { + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) { - ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}'", - HashView.AsHash(), - Params.Key.Bucket, - Params.Key.Hash); - Count.Invalid++; + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; } + Request.Complete = false; } } - if (!FoundInUpstream && EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal) && - m_CidStore.ContainsChunk(HashView.AsHash())) + else { - // We added the attachment for this Value in the local loop before calling m_UpstreamCache - Count.Valid++; + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId)) + { + ZEN_ASSERT(Chunk.GetSize() > 0); + Value.Payload = CompressedBuffer::FromCompressed(SharedBuffer(Chunk)); + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } } - Count.Total++; - }); + } + } + } + if (!Request.Complete) + { + bool NeedUpstreamRecord = + !Request.RecordObject && !FoundLocalInvalid && EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote); + if (NeedUpstreamRecord || NeedUpstreamAttachment) + { + UpstreamIndexes.push_back(Requests.size() - 1); + } + } + } + if (Requests.empty()) + { + return HttpRequest.WriteResponse(HttpResponseCode::BadRequest); + } - if ((Count.Valid == Count.Total) || EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)) + if (!UpstreamIndexes.empty()) + { + std::vector<CacheKeyRequest*> UpstreamRequests; + UpstreamRequests.reserve(UpstreamIndexes.size()); + for (size_t Index : UpstreamIndexes) + { + RecordRequestData& Request = Requests[Index]; + UpstreamRequests.push_back(&Request.Upstream); + + if (Request.Values.size()) + { + // We will be returning the local object and know all the value Ids that exist in it + // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have. + CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta; + CacheRecordPolicyBuilder Builder(UpstreamBasePolicy); + for (ValueRequestData& Value : Request.Values) { - CacheValue = CbObject::Clone(Params.Record).GetBuffer().AsIoBuffer(); + CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy); + UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None; + Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy); } + Request.Upstream.Policy = Builder.Build(); + } + else + { + // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants, + // and convert the CacheRecordPolicy to an upstream policy + Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream(); + } + } + + const auto OnCacheRecordGetComplete = [this, &ParseValues](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; } - if (CacheValue) + RecordRequestData& Request = + *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream)); + const CacheKey& Key = Request.Upstream.Key; + if (!Request.RecordObject) { - ZEN_DEBUG("HIT - '{}/{}' {} '{}' attachments '{}/{}/{}' (new/valid/total) (UPSTREAM)", - Params.Key.Bucket, - Params.Key.Hash, - NiceBytes(CacheValue.GetSize()), - ToString(HttpContentType::kCbPackage), - Count.New, - Count.Valid, - Count.Total); - - CacheValue.SetContentType(ZenContentType::kCbObject); - CacheValues[Params.KeyIndex] = CacheValue; - if (EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::StoreLocal)) + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject); + Request.RecordObject = ObjectBuffer; + if (EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal)) { - m_CacheStore.Put(Params.Key.Bucket, Params.Key.Hash, {.Value = CacheValue}); + m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}}); } - - m_CacheStats.HitCount++; - m_CacheStats.UpstreamHitCount++; + ParseValues(Request); + Request.UsedUpstream = true; } - else + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) { - const bool IsPartial = Count.Valid != Count.Total; - ZEN_DEBUG("MISS - '{}/{}' {}", Params.Key.Bucket, Params.Key.Hash, IsPartial ? "(partial)"sv : ""sv); - m_CacheStats.MissCount++; + if (Value.Exists) + { + continue; + } + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Request.Complete = false; + continue; + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + Request.UsedUpstream = true; + Value.Exists = true; + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + m_CidStore.AddChunk(Compressed); + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Value.Payload = Compressed; + } + } + else + { + ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}'", Value.ContentId, Key.Bucket, Key.Hash); + } + } + if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Request.Complete = false; + } + // Request.Complete does not need to be set to false for upstream SkipData attachments. + // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment + // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of + // any missing SkipData attachments. + } } }; - m_UpstreamCache.GetCacheRecords(CacheKeys, UpstreamRequests, BatchPolicy, std::move(OnCacheRecordGetComplete)); + m_UpstreamCache.GetCacheRecords(UpstreamRequests, std::move(OnCacheRecordGetComplete)); } + CbPackage ResponsePackage; CbObjectWriter ResponseObject; ResponseObject.BeginArray("Result"sv); - for (const IoBuffer& Value : CacheValues) + for (RecordRequestData& Request : Requests) { - if (Value) + const CacheKey& Key = Request.Upstream.Key; + if (Request.Complete || + (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))) { - CbObjectView Record(Value.Data()); - ResponseObject << Record; + ResponseObject << Request.RecordObject; + for (ValueRequestData& Value : Request.Values) + { + if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload) + { + ResponsePackage.AddAttachment(CbAttachment(Value.Payload)); + } + } + + ZEN_DEBUG("HIT - '{}/{}' {}{}{}", + Key.Bucket, + Key.Hash, + NiceBytes(Request.RecordCacheValue.Size()), + Request.Complete ? ""sv : " (PARTIAL)"sv, + Request.UsedUpstream ? " (UPSTREAM)"sv : ""sv); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount += Request.UsedUpstream ? 1 : 0; } else { ResponseObject.AddNull(); + + if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("DISABLEDQUERY - '{}/{}'", Key.Bucket, Key.Hash); + } + else + { + ZEN_DEBUG("MISS - '{}/{}' {}", Key.Bucket, Key.Hash, Request.RecordObject ? ""sv : "(PARTIAL)"sv); + m_CacheStats.MissCount++; + } + } + } + ResponseObject.EndArray(); + ResponsePackage.SetObject(ResponseObject.Save()); + + BinaryWriter MemStream; + ResponsePackage.Save(MemStream); + + HttpRequest.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); +} + +void +HttpStructuredCacheService::HandleRpcPutCacheValues(zen::HttpServerRequest& Request, const CbPackage& BatchRequest) +{ + ZEN_TRACE_CPU("Z$::RpcPutCacheValues"); + CbObjectView BatchObject = BatchRequest.GetObject(); + + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + + ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheValues"sv); + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::vector<bool> Results; + for (CbFieldView RequestField : Params["Requests"sv]) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyView = RequestObject["Key"sv].AsObjectView(); + CbFieldView BucketField = KeyView["Bucket"sv]; + CbFieldView HashField = KeyView["Hash"sv]; + CacheKey Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash()); + if (BucketField.HasError() || HashField.HasError() || Key.Bucket.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); } + PolicyText = RequestObject["Policy"sv].AsString(); + CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment(); + bool Succeeded = false; + uint64_t TransferredSize = 0; + + if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash)) + { + if (Attachment->IsCompressedBinary()) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + // TODO: Implement upstream puts of CacheValues with StoreLocal == false. + // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream. + Policy |= CachePolicy::StoreLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = Value}); + TransferredSize = Chunk.GetCompressedSize(); + } + Succeeded = true; + } + else + { + ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}' FAILED, value is not compressed", Key.Bucket, Key.Hash, RawHash); + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + } + else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + ZenCacheValue ExistingValue; + if (m_CacheStore.Get(Key.Bucket, Key.Hash, ExistingValue) && IsCompressedBinary(ExistingValue.Value.GetContentType())) + { + Succeeded = true; + } + } + // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put. + // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream. + + if (Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kBinary, .Key = Key}); + } + Results.push_back(Succeeded); + ZEN_DEBUG("PUTCACHEVALUES - '{}/{}' {}, '{}'", Key.Bucket, Key.Hash, NiceBytes(TransferredSize), Succeeded ? "Added"sv : "Invalid"); + } + if (Results.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); } ResponseObject.EndArray(); + CbPackage RpcResponse; RpcResponse.SetObject(ResponseObject.Save()); BinaryWriter MemStream; @@ -1050,216 +1557,610 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req } void -HttpStructuredCacheService::HandleRpcGetCachePayloads(zen::HttpServerRequest& Request, CbObjectView RpcRequest) +HttpStructuredCacheService::HandleRpcGetCacheValues(zen::HttpServerRequest& HttpRequest, CbObjectView RpcRequest) { - ZEN_TRACE_CPU("Z$::RpcGetCachePayloads"); - - ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv); +#if BACKWARDS_COMPATABILITY_JAN2022 + if (RpcRequest["Params"sv].AsObjectView()["ChunkRequests"]) + { + return HandleRpcGetCacheChunks(HttpRequest, RpcRequest); + } +#endif - std::vector<CacheChunkRequest> ChunkRequests; - std::vector<size_t> UpstreamRequests; - std::vector<IoBuffer> Chunks; - CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + ZEN_TRACE_CPU("Z$::RpcGetCacheValues"); - for (CbFieldView RequestView : Params["ChunkRequests"sv]) + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + struct RequestData { - CbObjectView RequestObject = RequestView.AsObjectView(); - CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); - const CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash()); - const IoHash ChunkId = RequestObject["ChunkId"sv].AsHash(); - const Oid PayloadId = RequestObject["ValueId"sv].AsObjectId(); - const uint64_t RawOffset = RequestObject["RawOffset"sv].AsUInt64(); - const uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64(); - std::string_view PolicyText = RequestObject["Policy"sv].AsString(); - const CachePolicy ChunkPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + CacheKey Key; + CachePolicy Policy; + CompressedBuffer Result; + }; + std::vector<RequestData> Requests; - // Note we could use emplace_back here but [Apple] LLVM-12's C++ library - // can't infer a constructor like other platforms (or can't handle an - // initializer list like others do). - ChunkRequests.push_back({Key, ChunkId, PayloadId, RawOffset, RawSize, ChunkPolicy}); - } + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv); - if (ChunkRequests.empty()) + for (CbFieldView RequestField : Params["Requests"sv]) { - return Request.WriteResponse(HttpResponseCode::BadRequest); - } + RequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CbFieldView BucketField = KeyObject["Bucket"sv]; + CbFieldView HashField = KeyObject["Hash"sv]; + Request.Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash()); + if (BucketField.HasError() || HashField.HasError() || Request.Key.Bucket.empty()) + { + return HttpRequest.WriteResponse(HttpResponseCode::BadRequest); + } + PolicyText = RequestObject["Policy"sv].AsString(); + Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; - Chunks.resize(ChunkRequests.size()); + CacheKey& Key = Request.Key; + CachePolicy Policy = Request.Policy; + CompressedBuffer& Result = Request.Result; - // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId) - // is missing, load the cache record and try to find the raw hash from the ValueId. - { - const auto GetChunkIdFromPayloadId = [](CbObjectView Record, const Oid& PayloadId) -> IoHash { - if (PayloadId) + ZenCacheValue CacheValue; + std::string_view Source; + if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + if (m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue) && IsCompressedBinary(CacheValue.Value.GetContentType())) { - // A valid ValueId indicates that the caller is searching for a Value in a Record - // that was Put with ICacheStore::Put - for (CbFieldView ValueView : Record["Values"sv]) + Result = CompressedBuffer::FromCompressed(SharedBuffer(CacheValue.Value)); + if (Result) { - CbObjectView ValueObject = ValueView.AsObjectView(); - const Oid Id = ValueObject["Id"sv].AsObjectId(); - - if (Id == PayloadId) - { - return ValueObject["RawHash"sv].AsHash(); - } + Source = "LOCAL"sv; } - - // Legacy fields from previous version of CacheRecord serialization: - if (CbObjectView ValueObject = Record["Value"sv].AsObjectView()) + } + } + if (!Result && EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + GetUpstreamCacheResult UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kBinary); + if (UpstreamResult.Success && IsCompressedBinary(UpstreamResult.Value.GetContentType())) + { + Result = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value)); + if (Result) { - const Oid Id = ValueObject["Id"sv].AsObjectId(); - if (Id == PayloadId) + UpstreamResult.Value.SetContentType(ZenContentType::kCompressedBinary); + Source = "UPSTREAM"sv; + // TODO: Respect the StoreLocal flag once we have upstream existence-only checks. For now the requirement + // that we copy data from upstream even when SkipData and !StoreLocal are true means that it is too expensive + // for us to keep the data only on the upstream server. + // if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) { - return ValueObject["RawHash"sv].AsHash(); + m_CacheStore.Put(Key.Bucket, Key.Hash, ZenCacheValue{UpstreamResult.Value}); } } + } + } - for (CbFieldView AttachmentView : Record["Attachments"sv]) - { - CbObjectView AttachmentObject = AttachmentView.AsObjectView(); - const Oid Id = AttachmentObject["Id"sv].AsObjectId(); + if (Result) + { + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}' {} ({})", Key.Bucket, Key.Hash, NiceBytes(Result.GetCompressed().GetSize()), Source); + m_CacheStats.HitCount++; + } + else if (!EnumHasAnyFlags(Policy, CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}'", Key.Bucket, Key.Hash); + } + else + { + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}'", Key.Bucket, Key.Hash); + m_CacheStats.MissCount++; + } + } + if (Requests.empty()) + { + return HttpRequest.WriteResponse(HttpResponseCode::BadRequest); + } - if (Id == PayloadId) - { - return AttachmentObject["RawHash"sv].AsHash(); - } - } - return IoHash::Zero; - } - else + CbPackage RpcResponse; + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (const RequestData& Request : Requests) + { + ResponseObject.BeginObject(); + { + const CompressedBuffer& Result = Request.Result; + if (Result) { - // An invalid ValueId indicates that the caller is requesting a Value that - // was Put with ICacheStore::PutValue - return Record["RawHash"sv].AsHash(); + ResponseObject.AddHash("RawHash"sv, IoHash::FromBLAKE3(Result.GetRawHash())); + if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Result)); + } + else + { + ResponseObject.AddInteger("RawSize"sv, Result.GetRawSize()); + } } - }; + } + ResponseObject.EndObject(); + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + + BinaryWriter MemStream; + RpcResponse.Save(MemStream); + + HttpRequest.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); +} + +namespace GetCacheChunks::detail { + + struct ValueData + { + Oid ValueId; + IoHash ContentId; + uint64_t RawSize; + }; + struct KeyRequestData + { + CacheKeyRequest Upstream; + IoBuffer CacheValue; + std::vector<ValueData> Values; + CachePolicy DownstreamRecordPolicy; + CachePolicy DownstreamPolicy; + std::string_view Source; + bool Exists = false; + bool HasRequest = false; + bool HasRecordRequest = false; + bool HasValueRequest = false; + bool ValuesRead = false; + }; + struct ChunkRequestData + { + CacheChunkRequest Upstream; + KeyRequestData* KeyRequest; + size_t KeyRequestIndex; + CachePolicy DownstreamPolicy; + CompressedBuffer Value; + std::string_view Source; + uint64_t TotalSize = 0; + bool Exists = false; + bool IsRecordRequest = false; + bool TotalSizeKnown = false; + }; + +} // namespace GetCacheChunks::detail + +void +HttpStructuredCacheService::HandleRpcGetCacheChunks(zen::HttpServerRequest& HttpRequest, CbObjectView RpcRequest) +{ + using namespace GetCacheChunks::detail; + + ZEN_TRACE_CPU("Z$::RpcGetCacheChunks"); + + std::vector<KeyRequestData> KeyRequests; + std::vector<ChunkRequestData> Chunks; + BACKWARDS_COMPATABILITY_JAN2022_CODE(bool SendValueOnly = false;) + if (!TryGetCacheChunks_Parse(KeyRequests, Chunks BACKWARDS_COMPATABILITY_JAN2022_CODE(, SendValueOnly), RpcRequest)) + { + return HttpRequest.WriteResponse(HttpResponseCode::BadRequest); + } + GetCacheChunks_LoadKeys(KeyRequests); + GetCacheChunks_LoadChunks(Chunks); + GetCacheChunks_SendResults(Chunks, HttpRequest BACKWARDS_COMPATABILITY_JAN2022_CODE(, SendValueOnly)); +} + +bool +HttpStructuredCacheService::TryGetCacheChunks_Parse(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests, + std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks, + BACKWARDS_COMPATABILITY_JAN2022_CODE(bool& SendValueOnly, ) CbObjectView RpcRequest) +{ + using namespace GetCacheChunks::detail; + +#if BACKWARDS_COMPATABILITY_JAN2022 + SendValueOnly = RpcRequest["MethodVersion"sv].AsInt32() < 1; +#else + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv); +#endif + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default; + + KeyRequestData* PreviousKeyRequest = nullptr; + CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView(); + Chunks.reserve(ChunkRequestsArray.Num()); + for (CbFieldView RequestView : ChunkRequestsArray) + { + ChunkRequestData& Chunk = Chunks.emplace_back(); + CbObjectView RequestObject = RequestView.AsObjectView(); + + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + CbFieldView HashField = KeyObject["Hash"sv]; + Chunk.Upstream.Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), HashField.AsHash()); + if (Chunk.Upstream.Key.Bucket.empty() || HashField.HasError()) + { + ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest."); + return false; + } + + KeyRequestData* KeyRequest = nullptr; + if (!PreviousKeyRequest || PreviousKeyRequest->Upstream.Key < Chunk.Upstream.Key) + { + KeyRequest = &KeyRequests.emplace_back(); + KeyRequest->Upstream.Key = Chunk.Upstream.Key; + PreviousKeyRequest = KeyRequest; + } + else if (!(Chunk.Upstream.Key < PreviousKeyRequest->Upstream.Key)) + { + KeyRequest = PreviousKeyRequest; + } + else + { + ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.", + Chunk.Upstream.Key.Bucket, + Chunk.Upstream.Key.Hash, + PreviousKeyRequest->Upstream.Key.Bucket, + PreviousKeyRequest->Upstream.Key.Hash); + return false; + } + Chunk.KeyRequestIndex = std::distance(KeyRequests.data(), KeyRequest); + + Chunk.Upstream.ChunkId = RequestObject["ChunkId"sv].AsHash(); + Chunk.Upstream.ValueId = RequestObject["ValueId"sv].AsObjectId(); + Chunk.Upstream.RawOffset = RequestObject["RawOffset"sv].AsUInt64(); + Chunk.Upstream.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX); + std::string_view PolicyText = RequestObject["Policy"sv].AsString(); + Chunk.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; +#if BACKWARDS_COMPATABILITY_JAN2022 + if (SendValueOnly) + { + Chunk.DownstreamPolicy = Chunk.DownstreamPolicy & (~CachePolicy::SkipData); + } +#endif + Chunk.IsRecordRequest = (bool)Chunk.Upstream.ValueId; - CacheKey CurrentKey = CacheKey::Empty; - IoBuffer CurrentRecordBuffer; + if (!Chunk.IsRecordRequest || Chunk.Upstream.ChunkId == IoHash::Zero) + { + KeyRequest->DownstreamPolicy = + KeyRequest->HasRequest ? Union(KeyRequest->DownstreamPolicy, Chunk.DownstreamPolicy) : Chunk.DownstreamPolicy; + KeyRequest->HasRequest = true; + (Chunk.IsRecordRequest ? KeyRequest->HasRecordRequest : KeyRequest->HasValueRequest) = true; + } + } + if (Chunks.empty()) + { + return false; + } + for (ChunkRequestData& Chunk : Chunks) + { + Chunk.KeyRequest = &KeyRequests[Chunk.KeyRequestIndex]; + } + return true; +} + +void +HttpStructuredCacheService::GetCacheChunks_LoadKeys(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests) +{ + using namespace GetCacheChunks::detail; - for (CacheChunkRequest& ChunkRequest : ChunkRequests) + std::vector<CacheKeyRequest*> UpstreamRecordRequests; + std::vector<KeyRequestData*> UpstreamValueRequests; + for (KeyRequestData& KeyRequest : KeyRequests) + { + if (KeyRequest.HasRequest) { - if (ChunkRequest.ChunkId != IoHash::Zero) + if (KeyRequest.HasRecordRequest) { - continue; + KeyRequest.DownstreamRecordPolicy = KeyRequest.DownstreamPolicy | CachePolicy::SkipData | CachePolicy::SkipMeta; } - if (ChunkRequest.Key != CurrentKey) + if (!KeyRequest.Exists && EnumHasAllFlags(KeyRequest.DownstreamPolicy, CachePolicy::QueryLocal)) { - CurrentKey = ChunkRequest.Key; - + // There's currently no interface for checking only whether a CacheValue exists without loading it, + // so we load it here even if SkipData is true and its a CacheValue request. ZenCacheValue CacheValue; - if (m_CacheStore.Get(CurrentKey.Bucket, CurrentKey.Hash, CacheValue)) + if (m_CacheStore.Get(KeyRequest.Upstream.Key.Bucket, KeyRequest.Upstream.Key.Hash, CacheValue)) { - CurrentRecordBuffer = CacheValue.Value; + KeyRequest.Exists = true; + KeyRequest.CacheValue = std::move(CacheValue.Value); + KeyRequest.Source = "LOCAL"sv; } } - - if (CurrentRecordBuffer) + if (!KeyRequest.Exists) { - ChunkRequest.ChunkId = GetChunkIdFromPayloadId(CbObjectView(CurrentRecordBuffer.GetData()), ChunkRequest.PayloadId); + // At most one of RecordRequest or ValueRequest will succeed for the upstream request of the key a given key, but we don't + // know which, + // and if the requests (from arbitrary Unreal Class code) includes both types of request for a key, we want to ask for both + // kinds and pass the request that uses the one that succeeds. + if (KeyRequest.HasRecordRequest && EnumHasAllFlags(KeyRequest.DownstreamRecordPolicy, CachePolicy::QueryRemote)) + { + KeyRequest.Upstream.Policy = CacheRecordPolicy(ConvertToUpstream(KeyRequest.DownstreamRecordPolicy)); + UpstreamRecordRequests.push_back(&KeyRequest.Upstream); + } + if (KeyRequest.HasValueRequest && EnumHasAllFlags(KeyRequest.DownstreamPolicy, CachePolicy::QueryRemote)) + { + UpstreamValueRequests.push_back(&KeyRequest); + } } } } - for (size_t RequestIndex = 0; const CacheChunkRequest& ChunkRequest : ChunkRequests) + if (!UpstreamRecordRequests.empty()) { - const bool QueryLocal = EnumHasAllFlags(ChunkRequest.Policy, CachePolicy::QueryLocal); - const bool QueryRemote = EnumHasAllFlags(ChunkRequest.Policy, CachePolicy::QueryRemote); + const auto OnCacheRecordGetComplete = [this](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + + KeyRequestData& KeyRequest = + *reinterpret_cast<KeyRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(KeyRequestData, Upstream)); + const CacheKey& Key = KeyRequest.Upstream.Key; + KeyRequest.Exists = true; + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + KeyRequest.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + KeyRequest.CacheValue.SetContentType(ZenContentType::kCbObject); + KeyRequest.Source = "UPSTREAM"sv; + + if (EnumHasAllFlags(KeyRequest.DownstreamPolicy, CachePolicy::StoreLocal)) + { + m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = KeyRequest.CacheValue}); + } + }; + m_UpstreamCache.GetCacheRecords(UpstreamRecordRequests, std::move(OnCacheRecordGetComplete)); + } - if (QueryLocal) + if (!UpstreamValueRequests.empty()) + { + for (KeyRequestData* KeyRequestPtr : UpstreamValueRequests) { - if (IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkRequest.ChunkId)) + KeyRequestData& KeyRequest = *KeyRequestPtr; + CacheKey& Key = KeyRequest.Upstream.Key; + GetUpstreamCacheResult UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kBinary); + if (UpstreamResult.Success && IsCompressedBinary(UpstreamResult.Value.GetContentType())) { - ZEN_ASSERT(Chunk.GetSize() > 0); - - ZEN_DEBUG("HIT - '{}/{}/{}' {} '{}' ({})", - ChunkRequest.Key.Bucket, - ChunkRequest.Key.Hash, - ChunkRequest.ChunkId, - NiceBytes(Chunk.Size()), - ToString(Chunk.GetContentType()), - "LOCAL"); - - Chunks[RequestIndex] = Chunk; - m_CacheStats.HitCount++; + CompressedBuffer Result = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value)); + if (Result) + { + KeyRequest.CacheValue = std::move(UpstreamResult.Value); + KeyRequest.CacheValue.SetContentType(ZenContentType::kCompressedBinary); + KeyRequest.Exists = true; + KeyRequest.Source = "UPSTREAM"sv; + // TODO: Respect the StoreLocal flag once we have upstream existence-only checks. For now the requirement + // that we copy data from upstream even when SkipData and !StoreLocal are true means that it is too expensive + // for us to keep the data only on the upstream server. + // if (EnumHasAllFlags(KeyRequest->DownstreamValuePolicy, CachePolicy::StoreLocal)) + { + m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = KeyRequest.CacheValue}); + } + } } - else if (QueryRemote) + } + } +} + +void +HttpStructuredCacheService::GetCacheChunks_LoadChunks(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks) +{ + using namespace GetCacheChunks::detail; + + std::vector<CacheChunkRequest*> UpstreamPayloadRequests; + for (ChunkRequestData& Chunk : Chunks) + { + if (Chunk.IsRecordRequest) + { + if (Chunk.Upstream.ChunkId == IoHash::Zero) { - UpstreamRequests.push_back(RequestIndex); + // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId) + // is missing, parse the cache record and try to find the raw hash from the ValueId. + KeyRequestData& KeyRequest = *Chunk.KeyRequest; + if (!KeyRequest.ValuesRead) + { + KeyRequest.ValuesRead = true; + if (KeyRequest.CacheValue && KeyRequest.CacheValue.GetContentType() == ZenContentType::kCbObject) + { + CbObjectView RecordObject = CbObjectView(KeyRequest.CacheValue.GetData()); + CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView(); + KeyRequest.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + KeyRequest.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()}); + } + } + } + } + + for (const ValueData& Value : KeyRequest.Values) + { + if (Value.ValueId == Chunk.Upstream.ValueId) + { + Chunk.Upstream.ChunkId = Value.ContentId; + Chunk.TotalSize = Value.RawSize; + Chunk.TotalSizeKnown = true; + break; + } + } } - else + + // Now load the ContentId from the local ContentIdStore or from the upstream + if (Chunk.Upstream.ChunkId != IoHash::Zero) { - ZEN_DEBUG("MISS - '{}/{}/{}'", ChunkRequest.Key.Bucket, ChunkRequest.Key.Hash, ChunkRequest.ChunkId); - m_CacheStats.MissCount++; + if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::QueryLocal)) + { + if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData) && Chunk.TotalSizeKnown) + { + if (m_CidStore.ContainsChunk(Chunk.Upstream.ChunkId)) + { + Chunk.Exists = true; + Chunk.Source = "LOCAL"sv; + } + } + else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Chunk.Upstream.ChunkId)) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Payload)); + if (Compressed) + { + if (!EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData)) + { + Chunk.Value = Compressed; + } + Chunk.Exists = true; + Chunk.TotalSize = Compressed.GetRawSize(); + Chunk.TotalSizeKnown = true; + Chunk.Source = "LOCAL"sv; + } + } + } + if (!Chunk.Exists && EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::QueryRemote)) + { + Chunk.Upstream.Policy = ConvertToUpstream(Chunk.DownstreamPolicy); + UpstreamPayloadRequests.push_back(&Chunk.Upstream); + } } } else { - ZEN_DEBUG("SKIP - '{}/{}/{}'", ChunkRequest.Key.Bucket, ChunkRequest.Key.Hash, ChunkRequest.ChunkId); + if (Chunk.KeyRequest->Exists) + { + if (Chunk.KeyRequest->CacheValue && IsCompressedBinary(Chunk.KeyRequest->CacheValue.GetContentType())) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk.KeyRequest->CacheValue)); + if (Compressed) + { + if (!EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData)) + { + Chunk.Value = Compressed; + } + Chunk.Exists = true; + Chunk.TotalSize = Compressed.GetRawSize(); + Chunk.TotalSizeKnown = true; + Chunk.Source = Chunk.KeyRequest->Source; + Chunk.Upstream.ChunkId = IoHash::FromBLAKE3(Compressed.GetRawHash()); + } + } + } } - - ++RequestIndex; } - if (!UpstreamRequests.empty()) + if (!UpstreamPayloadRequests.empty()) { - const auto OnCachePayloadGetComplete = [this, &Chunks](CachePayloadGetCompleteParams&& Params) { - if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Params.Payload))) + const auto OnCacheValueGetComplete = [this](CacheValueGetCompleteParams&& Params) { + if (Params.RawHash == Params.RawHash.Zero) { - m_CidStore.AddChunk(Compressed); - - ZEN_DEBUG("HIT - '{}/{}/{}' {} ({})", - Params.Request.Key.Bucket, - Params.Request.Key.Hash, - Params.Request.ChunkId, - NiceBytes(Params.Payload.GetSize()), - "UPSTREAM"); - - ZEN_ASSERT(Params.RequestIndex < Chunks.size()); - Chunks[Params.RequestIndex] = std::move(Params.Payload); - - m_CacheStats.HitCount++; - m_CacheStats.UpstreamHitCount++; + return; } - else + + ChunkRequestData& Chunk = + *reinterpret_cast<ChunkRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(ChunkRequestData, Upstream)); + if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::StoreLocal) || + !EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData)) { - ZEN_DEBUG("MISS - '{}/{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash, Params.Request.ChunkId); - m_CacheStats.MissCount++; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Params.Value)); + if (!Compressed || Compressed.GetRawSize() != Params.RawSize) + { + return; + } + + if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::StoreLocal)) + { + m_CidStore.AddChunk(Compressed); + } + if (!EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData)) + { + Chunk.Value = std::move(Compressed); + } } + Chunk.Exists = true; + Chunk.TotalSize = Params.RawSize; + Chunk.TotalSizeKnown = true; + Chunk.Source = "UPSTREAM"sv; + + m_CacheStats.UpstreamHitCount++; }; - m_UpstreamCache.GetCachePayloads(ChunkRequests, UpstreamRequests, std::move(OnCachePayloadGetComplete)); + m_UpstreamCache.GetCacheValues(UpstreamPayloadRequests, std::move(OnCacheValueGetComplete)); } +} - CbPackage RpcResponse; - CbObjectWriter ResponseObject; +void +HttpStructuredCacheService::GetCacheChunks_SendResults(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks, + zen::HttpServerRequest& HttpRequest + BACKWARDS_COMPATABILITY_JAN2022_CODE(, bool SendValueOnly)) +{ + using namespace GetCacheChunks::detail; - ResponseObject.BeginArray("Result"sv); + CbPackage RpcResponse; + CbObjectWriter Writer; - for (size_t ChunkIndex = 0; ChunkIndex < Chunks.size(); ++ChunkIndex) + Writer.BeginArray("Result"sv); + for (ChunkRequestData& Chunk : Chunks) { - if (Chunks[ChunkIndex]) +#if BACKWARDS_COMPATABILITY_JAN2022 + if (SendValueOnly) { - ResponseObject << ChunkRequests[ChunkIndex].ChunkId; - RpcResponse.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunks[ChunkIndex]))))); + if (Chunk.Value) + { + Writer << Chunk.Upstream.ChunkId; + RpcResponse.AddAttachment(CbAttachment(Chunk.Value)); + } + else + { + Writer << IoHash::Zero; + } } else +#endif { - ResponseObject << IoHash::Zero; + Writer.BeginObject(); + { + if (Chunk.Exists) + { + Writer.AddHash("RawHash"sv, Chunk.Upstream.ChunkId); + if (Chunk.Value && !EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Chunk.Value)); + } + else + { + Writer.AddInteger("RawSize"sv, Chunk.TotalSize); + } + + ZEN_DEBUG("CHUNKHIT - '{}/{}/{}' {} '{}' ({})", + Chunk.Upstream.Key.Bucket, + Chunk.Upstream.Key.Hash, + Chunk.Upstream.ValueId, + NiceBytes(Chunk.TotalSize), + Chunk.IsRecordRequest ? "Record"sv : "Value"sv, + Chunk.Source); + m_CacheStats.HitCount++; + } + else if (!EnumHasAnyFlags(Chunk.DownstreamPolicy, CachePolicy::Query)) + { + ZEN_DEBUG("CHUNKSKIP - '{}/{}/{}'", Chunk.Upstream.Key.Bucket, Chunk.Upstream.Key.Hash, Chunk.Upstream.ValueId); + } + else + { + ZEN_DEBUG("MISS - '{}/{}/{}'", Chunk.Upstream.Key.Bucket, Chunk.Upstream.Key.Hash, Chunk.Upstream.ValueId); + m_CacheStats.MissCount++; + } + } + Writer.EndObject(); } } - ResponseObject.EndArray(); + Writer.EndArray(); - RpcResponse.SetObject(ResponseObject.Save()); + RpcResponse.SetObject(Writer.Save()); BinaryWriter MemStream; RpcResponse.Save(MemStream); - Request.WriteResponse(HttpResponseCode::OK, - HttpContentType::kCbPackage, - IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); + HttpRequest.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); } void diff --git a/zenserver/cache/structuredcache.h b/zenserver/cache/structuredcache.h index a7ecba845..14b001e48 100644 --- a/zenserver/cache/structuredcache.h +++ b/zenserver/cache/structuredcache.h @@ -9,6 +9,10 @@ #include "monitoring/httpstatus.h" #include <memory> +#include <vector> + +// Include the define for BACKWARDS_COMPATABILITY_JAN2022 +#include <zenutil/cache/cachepolicy.h> namespace spdlog { class logger; @@ -19,11 +23,17 @@ namespace zen { class CasStore; class CidStore; class CbObjectView; +struct PutRequestData; class ScrubContext; class UpstreamCache; class ZenCacheStore; enum class CachePolicy : uint32_t; +namespace GetCacheChunks::detail { + struct KeyRequestData; + struct ChunkRequestData; +} // namespace GetCacheChunks::detail + /** * Structured cache service. Imposes constraints on keys, supports blobs and * structured values @@ -73,7 +83,7 @@ private: { std::string BucketSegment; IoHash HashKey; - IoHash PayloadId; + IoHash ValueContentId; }; struct CacheStats @@ -82,20 +92,41 @@ private: std::atomic_uint64_t UpstreamHitCount{}; std::atomic_uint64_t MissCount{}; }; + enum class PutResult + { + Success, + Fail, + Invalid, + }; [[nodiscard]] bool ValidateKeyUri(zen::HttpServerRequest& Request, CacheRef& OutRef); void HandleCacheRecordRequest(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); void HandleGetCacheRecord(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); void HandlePutCacheRecord(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); - void HandleCachePayloadRequest(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); - void HandleGetCachePayload(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); - void HandlePutCachePayload(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); + void HandleCacheValueRequest(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); + void HandleGetCacheValue(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); + void HandlePutCacheValue(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL); void HandleRpcRequest(zen::HttpServerRequest& Request); - void HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView BatchRequest); - void HandleRpcGetCachePayloads(zen::HttpServerRequest& Request, CbObjectView BatchRequest); - void HandleCacheBucketRequest(zen::HttpServerRequest& Request, std::string_view Bucket); - virtual void HandleStatsRequest(zen::HttpServerRequest& Request) override; - virtual void HandleStatusRequest(zen::HttpServerRequest& Request) override; + void HandleRpcPutCacheRecords(zen::HttpServerRequest& Request, const CbPackage& BatchRequest); +#if BACKWARDS_COMPATABILITY_JAN2022 + void HandleRpcGetCacheRecordsLegacy(zen::HttpServerRequest& Request, CbObjectView BatchRequest); +#endif + void HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView BatchRequest); + void HandleRpcPutCacheValues(zen::HttpServerRequest& Request, const CbPackage& BatchRequest); + void HandleRpcGetCacheValues(zen::HttpServerRequest& Request, CbObjectView BatchRequest); + void HandleRpcGetCacheChunks(zen::HttpServerRequest& Request, CbObjectView BatchRequest); + void HandleCacheBucketRequest(zen::HttpServerRequest& Request, std::string_view Bucket); + virtual void HandleStatsRequest(zen::HttpServerRequest& Request) override; + virtual void HandleStatusRequest(zen::HttpServerRequest& Request) override; + PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package); + + bool TryGetCacheChunks_Parse(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests, + std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks, + BACKWARDS_COMPATABILITY_JAN2022_CODE(bool& SendValueOnly, ) CbObjectView RpcRequest); + void GetCacheChunks_LoadKeys(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests); + void GetCacheChunks_LoadChunks(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks); + void GetCacheChunks_SendResults(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks, + zen::HttpServerRequest& HttpRequest BACKWARDS_COMPATABILITY_JAN2022_CODE(, bool SendValueOnly)); spdlog::logger& Log() { return m_Log; } spdlog::logger& m_Log; @@ -110,4 +141,12 @@ private: CacheStats m_CacheStats; }; +/** Recognize both kBinary and kCompressedBinary as kCompressedBinary for structured cache value keys. + * We need this until the content type is preserved for kCompressedBinary when passing to and from upstream servers. */ +inline bool +IsCompressedBinary(ZenContentType Type) +{ + return Type == ZenContentType::kBinary || Type == ZenContentType::kCompressedBinary; +} + } // namespace zen diff --git a/zenserver/upstream/upstreamcache.cpp b/zenserver/upstream/upstreamcache.cpp index 091406db3..9d3ed2f94 100644 --- a/zenserver/upstream/upstreamcache.cpp +++ b/zenserver/upstream/upstreamcache.cpp @@ -18,6 +18,7 @@ #include <zenstore/cas.h> #include <zenstore/cidstore.h> +#include "cache/structuredcache.h" #include "cache/structuredcachestore.h" #include "diag/logging.h" @@ -215,21 +216,16 @@ namespace detail { } } - virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys, - std::span<size_t> KeyIndex, - const CacheRecordPolicy& Policy, - OnCacheRecordGetComplete&& OnComplete) override + virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) override { ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords"); - ZEN_UNUSED(Policy); - CloudCacheSession Session(m_Client); GetUpstreamCacheResult Result; - for (size_t Index : KeyIndex) + for (CacheKeyRequest* Request : Requests) { - const CacheKey& CacheKey = CacheKeys[Index]; + const CacheKey& CacheKey = Request->Key; CbPackage Package; CbObject Record; @@ -264,20 +260,20 @@ namespace detail { } } - OnComplete({.Key = CacheKey, .KeyIndex = Index, .Record = Record, .Package = Package}); + OnComplete({.Request = *Request, .Record = Record, .Package = Package}); } return Result; } - virtual GetUpstreamCacheResult GetCachePayload(const CacheKey&, const IoHash& PayloadId) override + virtual GetUpstreamCacheResult GetCacheValue(const CacheKey&, const IoHash& ValueContentId) override { - ZEN_TRACE_CPU("Upstream::Horde::GetSingleCachePayload"); + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheValue"); try { CloudCacheSession Session(m_Client); - const CloudCacheResult Result = Session.GetCompressedBlob(PayloadId); + const CloudCacheResult Result = Session.GetCompressedBlob(ValueContentId); m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); @@ -301,20 +297,20 @@ namespace detail { } } - virtual GetUpstreamCacheResult GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, - std::span<size_t> RequestIndex, - OnCachePayloadGetComplete&& OnComplete) override final + virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheValueGetComplete&& OnComplete) override final { - ZEN_TRACE_CPU("Upstream::Horde::GetCachePayloads"); + ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues"); CloudCacheSession Session(m_Client); GetUpstreamCacheResult Result; - for (size_t Index : RequestIndex) + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) { - const CacheChunkRequest& Request = CacheChunkRequests[Index]; - IoBuffer Payload; + CacheChunkRequest& Request = *RequestPtr; + IoBuffer Payload; + CompressedBuffer Compressed; if (!Result.Error) { const CloudCacheResult BlobResult = Session.GetCompressedBlob(Request.ChunkId); @@ -323,9 +319,23 @@ namespace detail { AppendResult(BlobResult, Result); m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload && IsCompressedBinary(Payload.GetContentType())) + { + Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Payload)); + } } - OnComplete({.Request = Request, .RequestIndex = Index, .Payload = Payload}); + if (Compressed) + { + OnComplete({.Request = Request, + .RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash()), + .RawSize = Compressed.GetRawSize(), + .Value = Payload}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } } return Result; @@ -333,11 +343,11 @@ namespace detail { virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, IoBuffer RecordValue, - std::span<IoBuffer const> Payloads) override + std::span<IoBuffer const> Values) override { ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord"); - ZEN_ASSERT(CacheRecord.PayloadIds.size() == Payloads.size()); + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); const int32_t MaxAttempts = 3; try @@ -371,30 +381,31 @@ namespace detail { int64_t TotalBytes = 0ull; double TotalElapsedSeconds = 0.0; - const auto PutBlobs = [&](std::span<IoHash> PayloadIds, std::string& OutReason) -> bool { - for (const IoHash& PayloadId : PayloadIds) + const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool { + for (const IoHash& ValueContentId : ValueContentIds) { - const auto It = std::find(std::begin(CacheRecord.PayloadIds), std::end(CacheRecord.PayloadIds), PayloadId); + const auto It = + std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId); - if (It == std::end(CacheRecord.PayloadIds)) + if (It == std::end(CacheRecord.ValueContentIds)) { - OutReason = fmt::format("value '{}' MISSING from local cache", PayloadId); + OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId); return false; } - const size_t Idx = std::distance(std::begin(CacheRecord.PayloadIds), It); + const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It); CloudCacheResult BlobResult; for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++) { - BlobResult = Session.PutCompressedBlob(CacheRecord.PayloadIds[Idx], Payloads[Idx]); + BlobResult = Session.PutCompressedBlob(CacheRecord.ValueContentIds[Idx], Values[Idx]); } m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); if (!BlobResult.Success) { - OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", PayloadId, BlobResult.Reason); + OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason); return false; } @@ -619,15 +630,10 @@ namespace detail { } } - virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys, - std::span<size_t> KeyIndex, - const CacheRecordPolicy& Policy, - OnCacheRecordGetComplete&& OnComplete) override + virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) override { ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords"); - - std::vector<size_t> IndexMap; - IndexMap.reserve(KeyIndex.size()); + ZEN_ASSERT(Requests.size() > 0); CbObjectWriter BatchRequest; BatchRequest << "Method"sv @@ -635,21 +641,30 @@ namespace detail { BatchRequest.BeginObject("Params"sv); { - BatchRequest.BeginArray("CacheKeys"sv); - for (size_t Index : KeyIndex) - { - const CacheKey& Key = CacheKeys[Index]; - IndexMap.push_back(Index); + CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy(); + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy); + BatchRequest.BeginArray("Requests"sv); + for (CacheKeyRequest* Request : Requests) + { BatchRequest.BeginObject(); - BatchRequest << "Bucket"sv << Key.Bucket; - BatchRequest << "Hash"sv << Key.Hash; + { + const CacheKey& Key = Request->Key; + BatchRequest.BeginObject("Key"sv); + { + BatchRequest << "Bucket"sv << Key.Bucket; + BatchRequest << "Hash"sv << Key.Hash; + } + BatchRequest.EndObject(); + if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy) + { + BatchRequest.SetName("Policy"sv); + Request->Policy.Save(BatchRequest); + } + } BatchRequest.EndObject(); } BatchRequest.EndArray(); - - BatchRequest.SetName("Policy"sv); - Policy.Save(BatchRequest); } BatchRequest.EndObject(); @@ -667,32 +682,40 @@ namespace detail { { if (BatchResponse.TryLoad(Result.Response)) { - for (size_t LocalIndex = 0; CbFieldView Record : BatchResponse.GetObject()["Result"sv]) + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (Results.Num() != Requests.size()) { - const size_t Index = IndexMap[LocalIndex++]; - OnComplete({.Key = CacheKeys[Index], .KeyIndex = Index, .Record = Record.AsObjectView(), .Package = BatchResponse}); + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Requests from Upstream."); } + else + { + for (size_t Index = 0; CbFieldView Record : Results) + { + CacheKeyRequest* Request = Requests[Index++]; + OnComplete({.Request = *Request, .Record = Record.AsObjectView(), .Package = BatchResponse}); + } - return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } } } - for (size_t Index : KeyIndex) + for (CacheKeyRequest* Request : Requests) { - OnComplete({.Key = CacheKeys[Index], .KeyIndex = Index, .Record = CbObjectView(), .Package = CbPackage()}); + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); } return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; } - virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) override + virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& ValueContentId) override { - ZEN_TRACE_CPU("Upstream::Zen::GetSingleCachePayload"); + ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheValue"); try { ZenStructuredCacheSession Session(*m_Client); - const ZenCacheResult Result = Session.GetCachePayload(CacheKey.Bucket, CacheKey.Hash, PayloadId); + const ZenCacheResult Result = Session.GetCacheValue(CacheKey.Bucket, CacheKey.Hash, ValueContentId); m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); @@ -716,27 +739,28 @@ namespace detail { } } - virtual GetUpstreamCacheResult GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, - std::span<size_t> RequestIndex, - OnCachePayloadGetComplete&& OnComplete) override final + virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheValueGetComplete&& OnComplete) override final { - ZEN_TRACE_CPU("Upstream::Zen::GetCachePayloads"); - - std::vector<size_t> IndexMap; - IndexMap.reserve(RequestIndex.size()); + ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues"); + ZEN_ASSERT(!CacheChunkRequests.empty()); CbObjectWriter BatchRequest; BatchRequest << "Method"sv - << "GetCacheValues"; + << "GetCacheChunks"; +#if BACKWARDS_COMPATABILITY_JAN2022 + BatchRequest.AddInteger("MethodVersion"sv, 1); +#endif BatchRequest.BeginObject("Params"sv); { + CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); BatchRequest.BeginArray("ChunkRequests"sv); { - for (size_t Index : RequestIndex) + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) { - const CacheChunkRequest& Request = CacheChunkRequests[Index]; - IndexMap.push_back(Index); + const CacheChunkRequest& Request = *RequestPtr; BatchRequest.BeginObject(); { @@ -744,11 +768,26 @@ namespace detail { BatchRequest << "Bucket"sv << Request.Key.Bucket; BatchRequest << "Hash"sv << Request.Key.Hash; BatchRequest.EndObject(); - BatchRequest.AddObjectId("ValueId"sv, Request.PayloadId); - BatchRequest << "ChunkId"sv << Request.ChunkId; - BatchRequest << "RawOffset"sv << Request.RawOffset; - BatchRequest << "RawSize"sv << Request.RawSize; - BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + if (Request.ValueId) + { + BatchRequest.AddObjectId("ValueId"sv, Request.ValueId); + } + if (Request.ChunkId != Request.ChunkId.Zero) + { + BatchRequest << "ChunkId"sv << Request.ChunkId; + } + if (Request.RawOffset != 0) + { + BatchRequest << "RawOffset"sv << Request.RawOffset; + } + if (Request.RawSize != UINT64_MAX) + { + BatchRequest << "RawSize"sv << Request.RawSize; + } + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } } BatchRequest.EndObject(); } @@ -771,29 +810,56 @@ namespace detail { { if (BatchResponse.TryLoad(Result.Response)) { - for (size_t LocalIndex = 0; CbFieldView AttachmentHash : BatchResponse.GetObject()["Result"sv]) + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheChunkRequests.size() != Results.Num()) { - const size_t Index = IndexMap[LocalIndex++]; - IoBuffer Payload; - - if (const CbAttachment* Attachment = BatchResponse.FindAttachment(AttachmentHash.AsHash())) + ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Requests from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) { - if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) { - Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.GetRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } } + OnComplete({.Request = Request, .RawHash = RawHash, .RawSize = RawSize, .Value = std::move(Payload)}); } - OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Payload = std::move(Payload)}); + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; } - - return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; } } - for (size_t Index : RequestIndex) + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) { - OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Payload = IoBuffer()}); + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); } return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; @@ -801,11 +867,11 @@ namespace detail { virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, IoBuffer RecordValue, - std::span<IoBuffer const> Payloads) override + std::span<IoBuffer const> Values) override { ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord"); - ZEN_ASSERT(CacheRecord.PayloadIds.size() == Payloads.size()); + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); const int32_t MaxAttempts = 3; try @@ -820,9 +886,9 @@ namespace detail { CbPackage Package; Package.SetObject(CbObject(SharedBuffer(RecordValue))); - for (const IoBuffer& Payload : Payloads) + for (const IoBuffer& Value : Values) { - if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Payload))) + if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value))) { Package.AddAttachment(CbAttachment(AttachmentBuffer)); } @@ -848,15 +914,15 @@ namespace detail { } else { - for (size_t Idx = 0, Count = Payloads.size(); Idx < Count; Idx++) + for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++) { Result.Success = false; for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) { - Result = Session.PutCachePayload(CacheRecord.Key.Bucket, - CacheRecord.Key.Hash, - CacheRecord.PayloadIds[Idx], - Payloads[Idx]); + Result = Session.PutCacheValue(CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + CacheRecord.ValueContentIds[Idx], + Values[Idx]); } m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); @@ -866,7 +932,7 @@ namespace detail { if (!Result.Success) { - return {.Reason = "Failed to upload payload", + return {.Reason = "Failed to upload value", .Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = false}; @@ -1044,21 +1110,16 @@ public: return {}; } - virtual void GetCacheRecords(std::span<CacheKey> CacheKeys, - std::span<size_t> KeyIndex, - const CacheRecordPolicy& DownstreamPolicy, - OnCacheRecordGetComplete&& OnComplete) override final + virtual void GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) override final { ZEN_TRACE_CPU("Upstream::GetCacheRecords"); std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); - std::vector<size_t> RemainingKeys(KeyIndex.begin(), KeyIndex.end()); + std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end()); if (m_Options.ReadUpstream) { - CacheRecordPolicy UpstreamPolicy = DownstreamPolicy.ConvertToUpstream(); - for (auto& Endpoint : m_Endpoints) { if (RemainingKeys.empty()) @@ -1071,25 +1132,24 @@ public: continue; } - UpstreamEndpointStats& Stats = Endpoint->Stats(); - std::vector<size_t> Missing; - GetUpstreamCacheResult Result; + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheKeyRequest*> Missing; + GetUpstreamCacheResult Result; { metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); - Result = - Endpoint->GetCacheRecords(CacheKeys, RemainingKeys, UpstreamPolicy, [&](CacheRecordGetCompleteParams&& Params) { - if (Params.Record) - { - OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); + Result = Endpoint->GetCacheRecords(RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); - Stats.CacheHitCount.Increment(1); - } - else - { - Missing.push_back(Params.KeyIndex); - } - }); + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); } Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); @@ -1109,21 +1169,19 @@ public: } } - for (size_t Index : RemainingKeys) + for (CacheKeyRequest* Request : RemainingKeys) { - OnComplete({.Key = CacheKeys[Index], .KeyIndex = Index, .Record = CbObjectView(), .Package = CbPackage()}); + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); } } - virtual void GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, - std::span<size_t> RequestIndex, - OnCachePayloadGetComplete&& OnComplete) override final + virtual void GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, OnCacheValueGetComplete&& OnComplete) override final { - ZEN_TRACE_CPU("Upstream::GetCachePayloads"); + ZEN_TRACE_CPU("Upstream::GetCacheValues"); std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); - std::vector<size_t> RemainingKeys(RequestIndex.begin(), RequestIndex.end()); + std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end()); if (m_Options.ReadUpstream) { @@ -1139,22 +1197,22 @@ public: continue; } - UpstreamEndpointStats& Stats = Endpoint->Stats(); - std::vector<size_t> Missing; - GetUpstreamCacheResult Result; + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheChunkRequest*> Missing; + GetUpstreamCacheResult Result; { metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); - Result = Endpoint->GetCachePayloads(CacheChunkRequests, RemainingKeys, [&](CachePayloadGetCompleteParams&& Params) { - if (Params.Payload) + Result = Endpoint->GetCacheValues(RemainingKeys, [&](CacheValueGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) { - OnComplete(std::forward<CachePayloadGetCompleteParams>(Params)); + OnComplete(std::forward<CacheValueGetCompleteParams>(Params)); Stats.CacheHitCount.Increment(1); } else { - Missing.push_back(Params.RequestIndex); + Missing.push_back(&Params.Request); } }); } @@ -1166,7 +1224,7 @@ public: { Stats.CacheErrorCount.Increment(1); - ZEN_ERROR("get cache payloads(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + ZEN_ERROR("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", Endpoint->GetEndpointInfo().Url, Result.Error.Reason, Result.Error.ErrorCode); @@ -1176,15 +1234,15 @@ public: } } - for (size_t Index : RemainingKeys) + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) { - OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Payload = IoBuffer()}); + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); } } - virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) override + virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& ValueContentId) override { - ZEN_TRACE_CPU("Upstream::GetCachePayload"); + ZEN_TRACE_CPU("Upstream::GetCacheValue"); if (m_Options.ReadUpstream) { @@ -1200,7 +1258,7 @@ public: { metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); - Result = Endpoint->GetCachePayload(CacheKey, PayloadId); + Result = Endpoint->GetCacheValue(CacheKey, ValueContentId); } Stats.CacheGetCount.Increment(1); @@ -1217,7 +1275,7 @@ public: { Stats.CacheErrorCount.Increment(1); - ZEN_ERROR("get cache payload FAILED, endpoint '{}', reason '{}', error code '{}'", + ZEN_ERROR("get cache value FAILED, endpoint '{}', reason '{}', error code '{}'", Endpoint->GetEndpointInfo().Url, Result.Error.Reason, Result.Error.ErrorCode); @@ -1302,18 +1360,18 @@ private: return; } - for (const IoHash& PayloadId : CacheRecord.PayloadIds) + for (const IoHash& ValueContentId : CacheRecord.ValueContentIds) { - if (IoBuffer Payload = m_CidStore.FindChunkByCid(PayloadId)) + if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId)) { Payloads.push_back(Payload); } else { - ZEN_WARN("process upstream FAILED, '{}/{}/{}', payload doesn't exist in CAS", + ZEN_WARN("process upstream FAILED, '{}/{}/{}', ValueContentId doesn't exist in CAS", CacheRecord.Key.Bucket, CacheRecord.Key.Hash, - PayloadId); + ValueContentId); return; } } diff --git a/zenserver/upstream/upstreamcache.h b/zenserver/upstream/upstreamcache.h index 16d8c7929..994129fc4 100644 --- a/zenserver/upstream/upstreamcache.h +++ b/zenserver/upstream/upstreamcache.h @@ -2,6 +2,8 @@ #pragma once +#include <zencore/compactbinary.h> +#include <zencore/compress.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/stats.h> @@ -12,10 +14,10 @@ #include <chrono> #include <functional> #include <memory> +#include <vector> namespace zen { -class CbObjectView; class CbPackage; class CbObjectWriter; class CidStore; @@ -27,7 +29,7 @@ struct UpstreamCacheRecord { ZenContentType Type = ZenContentType::kBinary; CacheKey Key; - std::vector<IoHash> PayloadIds; + std::vector<IoHash> ValueContentIds; }; struct UpstreamCacheOptions @@ -65,22 +67,22 @@ struct PutUpstreamCacheResult struct CacheRecordGetCompleteParams { - const CacheKey& Key; - size_t KeyIndex = ~size_t(0); + CacheKeyRequest& Request; const CbObjectView& Record; const CbPackage& Package; }; using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>; -struct CachePayloadGetCompleteParams +struct CacheValueGetCompleteParams { - const CacheChunkRequest& Request; - size_t RequestIndex{~size_t(0)}; - IoBuffer Payload; + CacheChunkRequest& Request; + IoHash RawHash; + uint64_t RawSize; + IoBuffer Value; }; -using OnCachePayloadGetComplete = std::function<void(CachePayloadGetCompleteParams&&)>; +using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>; struct UpstreamEndpointStats { @@ -151,16 +153,12 @@ public: virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0; - virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys, - std::span<size_t> KeyIndex, - const CacheRecordPolicy& Policy, - OnCacheRecordGetComplete&& OnComplete) = 0; + virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) = 0; - virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) = 0; + virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& PayloadId) = 0; - virtual GetUpstreamCacheResult GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, - std::span<size_t> RequestIndex, - OnCachePayloadGetComplete&& OnComplete) = 0; + virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheValueGetComplete&& OnComplete) = 0; virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, IoBuffer RecordValue, @@ -185,16 +183,11 @@ public: virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0; - virtual void GetCacheRecords(std::span<CacheKey> CacheKeys, - std::span<size_t> KeyIndex, - const CacheRecordPolicy& RecordPolicy, - OnCacheRecordGetComplete&& OnComplete) = 0; + virtual void GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) = 0; - virtual GetUpstreamCacheResult GetCachePayload(const CacheKey& CacheKey, const IoHash& PayloadId) = 0; + virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& ValueContentId) = 0; - virtual void GetCachePayloads(std::span<CacheChunkRequest> CacheChunkRequests, - std::span<size_t> RequestIndex, - OnCachePayloadGetComplete&& OnComplete) = 0; + virtual void GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, OnCacheValueGetComplete&& OnComplete) = 0; virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0; diff --git a/zenserver/upstream/zen.cpp b/zenserver/upstream/zen.cpp index cd7f48334..a2666ac02 100644 --- a/zenserver/upstream/zen.cpp +++ b/zenserver/upstream/zen.cpp @@ -433,10 +433,10 @@ ZenStructuredCacheSession::GetCacheRecord(std::string_view BucketId, const IoHas } ZenCacheResult -ZenStructuredCacheSession::GetCachePayload(std::string_view BucketId, const IoHash& Key, const IoHash& PayloadId) +ZenStructuredCacheSession::GetCacheValue(std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId) { ExtendableStringBuilder<256> Uri; - Uri << m_Client.ServiceUrl() << "/z$/" << BucketId << "/" << Key.ToHexString() << "/" << PayloadId.ToHexString(); + Uri << m_Client.ServiceUrl() << "/z$/" << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); cpr::Session& Session = m_SessionState->GetSession(); @@ -486,10 +486,10 @@ ZenStructuredCacheSession::PutCacheRecord(std::string_view BucketId, const IoHas } ZenCacheResult -ZenStructuredCacheSession::PutCachePayload(std::string_view BucketId, const IoHash& Key, const IoHash& PayloadId, IoBuffer Payload) +ZenStructuredCacheSession::PutCacheValue(std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId, IoBuffer Payload) { ExtendableStringBuilder<256> Uri; - Uri << m_Client.ServiceUrl() << "/z$/" << BucketId << "/" << Key.ToHexString() << "/" << PayloadId.ToHexString(); + Uri << m_Client.ServiceUrl() << "/z$/" << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); cpr::Session& Session = m_SessionState->GetSession(); diff --git a/zenserver/upstream/zen.h b/zenserver/upstream/zen.h index c2be2165a..8cc4c121d 100644 --- a/zenserver/upstream/zen.h +++ b/zenserver/upstream/zen.h @@ -123,9 +123,9 @@ public: ZenCacheResult CheckHealth(); ZenCacheResult GetCacheRecord(std::string_view BucketId, const IoHash& Key, ZenContentType Type); - ZenCacheResult GetCachePayload(std::string_view BucketId, const IoHash& Key, const IoHash& PayloadId); + ZenCacheResult GetCacheValue(std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId); ZenCacheResult PutCacheRecord(std::string_view BucketId, const IoHash& Key, IoBuffer Value, ZenContentType Type); - ZenCacheResult PutCachePayload(std::string_view BucketId, const IoHash& Key, const IoHash& PayloadId, IoBuffer Payload); + ZenCacheResult PutCacheValue(std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId, IoBuffer Payload); ZenCacheResult InvokeRpc(const CbObjectView& Request); private: diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp index c6a27ec44..e246afc34 100644 --- a/zenserver/zenserver.cpp +++ b/zenserver/zenserver.cpp @@ -1044,6 +1044,10 @@ main(int argc, char* argv[]) { TraceInit(ServerOptions.TraceFile.c_str(), TraceType::File); } + else + { + TraceInit(nullptr, TraceType::None); + } #endif // ZEN_WITH_TRACE #if ZEN_PLATFORM_WINDOWS diff --git a/zenutil/cache/cachepolicy.cpp b/zenutil/cache/cachepolicy.cpp index 3bf7a0c67..f17c54aa2 100644 --- a/zenutil/cache/cachepolicy.cpp +++ b/zenutil/cache/cachepolicy.cpp @@ -10,149 +10,156 @@ #include <algorithm> #include <unordered_map> +namespace zen::Private { +class CacheRecordPolicyShared; +} + namespace zen { using namespace std::literals; -namespace detail::CachePolicyImpl { - constexpr char DelimiterChar = ','; - constexpr std::string_view None = "None"sv; - constexpr std::string_view QueryLocal = "QueryLocal"sv; - constexpr std::string_view QueryRemote = "QueryRemote"sv; - constexpr std::string_view Query = "Query"sv; - constexpr std::string_view StoreLocal = "StoreLocal"sv; - constexpr std::string_view StoreRemote = "StoreRemote"sv; - constexpr std::string_view Store = "Store"sv; - constexpr std::string_view SkipMeta = "SkipMeta"sv; - constexpr std::string_view SkipData = "SkipData"sv; - constexpr std::string_view PartialRecord = "PartialRecord"sv; - constexpr std::string_view KeepAlive = "KeepAlive"sv; - constexpr std::string_view Local = "Local"sv; - constexpr std::string_view Remote = "Remote"sv; - constexpr std::string_view Default = "Default"sv; - constexpr std::string_view Disable = "Disable"sv; +namespace DerivedData::Private { - using TextToPolicyMap = std::unordered_map<std::string_view, CachePolicy>; - const TextToPolicyMap TextToPolicy = {{None, CachePolicy::None}, - {QueryLocal, CachePolicy::QueryLocal}, - {QueryRemote, CachePolicy::QueryRemote}, - {Query, CachePolicy::Query}, - {StoreLocal, CachePolicy::StoreLocal}, - {StoreRemote, CachePolicy::StoreRemote}, - {Store, CachePolicy::Store}, - {SkipMeta, CachePolicy::SkipMeta}, - {SkipData, CachePolicy::SkipData}, - {PartialRecord, CachePolicy::PartialRecord}, - {KeepAlive, CachePolicy::KeepAlive}, - {Local, CachePolicy::Local}, - {Remote, CachePolicy::Remote}, - {Default, CachePolicy::Default}, - {Disable, CachePolicy::Disable}}; + constexpr char CachePolicyDelimiter = ','; - using PolicyTextPair = std::pair<CachePolicy, std::string_view>; - const PolicyTextPair FlagsToString[]{ - // Order of these Flags is important: we want the aliases before the atomic values, - // and the bigger aliases first, to reduce the number of tokens we add - {CachePolicy::Default, Default}, - {CachePolicy::Remote, Remote}, - {CachePolicy::Local, Local}, - {CachePolicy::Store, Store}, - {CachePolicy::Query, Query}, - - // Order of Atomics doesn't matter, so arbitrarily we list them in enum order - {CachePolicy::QueryLocal, QueryLocal}, - {CachePolicy::QueryRemote, QueryRemote}, - {CachePolicy::StoreLocal, StoreLocal}, - {CachePolicy::StoreRemote, StoreRemote}, - {CachePolicy::SkipMeta, SkipMeta}, - {CachePolicy::SkipData, SkipData}, - {CachePolicy::PartialRecord, PartialRecord}, - {CachePolicy::KeepAlive, KeepAlive}, + struct CachePolicyToTextData + { + CachePolicy Policy; + std::string_view Text; + }; - // None must come at the end of the array, to write out only if no others exist - {CachePolicy::None, None}, + const CachePolicyToTextData CachePolicyToText[]{ + // Flags with multiple bits are ordered by bit count to minimize token count in the text format. + {CachePolicy::Default, "Default"sv}, + {CachePolicy::Remote, "Remote"sv}, + {CachePolicy::Local, "Local"sv}, + {CachePolicy::Store, "Store"sv}, + {CachePolicy::Query, "Query"sv}, + // Flags with only one bit can be in any order. Match the order in CachePolicy. + {CachePolicy::QueryLocal, "QueryLocal"sv}, + {CachePolicy::QueryRemote, "QueryRemote"sv}, + {CachePolicy::StoreLocal, "StoreLocal"sv}, + {CachePolicy::StoreRemote, "StoreRemote"sv}, + {CachePolicy::SkipMeta, "SkipMeta"sv}, + {CachePolicy::SkipData, "SkipData"sv}, + {CachePolicy::PartialRecord, "PartialRecord"sv}, + {CachePolicy::KeepAlive, "KeepAlive"sv}, + // None must be last because it matches every policy. + {CachePolicy::None, "None"sv}, }; - constexpr CachePolicy KnownFlags = - CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::KeepAlive | CachePolicy::PartialRecord; -} // namespace detail::CachePolicyImpl -StringBuilderBase& -AppendToBuilderImpl(StringBuilderBase& Builder, CachePolicy Policy) -{ - // Remove any bits we don't recognize; write None if there are not any bits we recognize - Policy = Policy & detail::CachePolicyImpl::KnownFlags; - for (const detail::CachePolicyImpl::PolicyTextPair& Pair : detail::CachePolicyImpl::FlagsToString) + constexpr CachePolicy CachePolicyKnownFlags = + CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive; + + StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy) { - if (EnumHasAllFlags(Policy, Pair.first)) + // Mask out unknown flags. None will be written if no flags are known. + Policy &= CachePolicyKnownFlags; + for (const CachePolicyToTextData& Pair : CachePolicyToText) { - EnumRemoveFlags(Policy, Pair.first); - Builder << Pair.second << detail::CachePolicyImpl::DelimiterChar; - if (Policy == CachePolicy::None) + if (EnumHasAllFlags(Policy, Pair.Policy)) { - break; + EnumRemoveFlags(Policy, Pair.Policy); + Builder << Pair.Text << CachePolicyDelimiter; + if (Policy == CachePolicy::None) + { + break; + } } } + Builder.RemoveSuffix(1); + return Builder; } - Builder.RemoveSuffix(1); // Text will have been added by CachePolicy::None if not by anything else - return Builder; -} + + CachePolicy ParseCachePolicy(const std::string_view Text) + { + ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string + CachePolicy Policy = CachePolicy::None; + ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable { + const int32_t EndIndex = Index; + for (; Index < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + for (Index = 0; Index < EndIndex; ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + return true; + }); + return Policy; + } + +} // namespace DerivedData::Private + StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy) { - return AppendToBuilderImpl(Builder, Policy); + return DerivedData::Private::CachePolicyToString(Builder, Policy); } CachePolicy ParseCachePolicy(std::string_view Text) { - ZEN_ASSERT(!Text.empty()); // Empty string is not valid input to ParseCachePolicy - - CachePolicy Result = CachePolicy::None; - ForEachStrTok(Text, detail::CachePolicyImpl::DelimiterChar, [&Result](const std::string_view& Token) { - auto it = detail::CachePolicyImpl::TextToPolicy.find(Token); - if (it != detail::CachePolicyImpl::TextToPolicy.end()) - { - Result |= it->second; - } - return true; - }); - - return Result; + return DerivedData::Private::ParseCachePolicy(Text); } -namespace Private { +CachePolicy +ConvertToUpstream(CachePolicy Policy) +{ + // Set Local flags equal to downstream's Remote flags. + // Delete Skip flags if StoreLocal is true, otherwise use the downstream value. + // Use the downstream value for all other flags. + return (EnumHasAllFlags(Policy, CachePolicy::QueryRemote) ? CachePolicy::QueryLocal : CachePolicy::None) | + (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) ? CachePolicy::StoreLocal : CachePolicy::None) | + (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal) ? (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta)) + : CachePolicy::None) | + (Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta)); +} - class CacheRecordPolicyShared final : public ICacheRecordPolicyShared +class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared +{ +public: + inline void AddValuePolicy(const CacheValuePolicy& Value) final { - public: - inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; } - - inline void AddValuePolicy(const CacheValuePolicy& Policy) final { Values.push_back(Policy); } - - inline void Build() final - { - std::sort(Values.begin(), Values.end(), [](const CacheValuePolicy& A, const CacheValuePolicy& B) { return A.Id < B.Id; }); - } + ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null. + const auto Insert = + std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) { + return Existing.Id < New.Id; + }); + ZEN_ASSERT( + !(Insert < Values.end() && + Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ") + Values.insert(Insert, Value); + } - private: - std::vector<CacheValuePolicy> Values; - }; + inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; } -} // namespace Private +private: + std::vector<CacheValuePolicy> Values; +}; CachePolicy CacheRecordPolicy::GetValuePolicy(const Oid& Id) const { if (Shared) { - if (std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); !Values.empty()) + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + const auto Iter = + std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; }); + if (Iter != Values.end() && Iter->Id == Id) { - auto Iter = - std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; }); - if (Iter != Values.end() && Iter->Id == Id) - { - return Iter->Policy; - } + return Iter->Policy; } } return DefaultValuePolicy; @@ -162,46 +169,58 @@ void CacheRecordPolicy::Save(CbWriter& Writer) const { Writer.BeginObject(); + // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately. + Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy())); + if (!IsUniform()) { - // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately. - Writer << "DefaultValuePolicy"sv << WriteToString<128>(GetDefaultValuePolicy()); - if (!IsUniform()) + Writer.BeginArray("ValuePolicies"sv); + for (const CacheValuePolicy& Value : GetValuePolicies()) { - // FCacheRecordPolicyBuilder guarantees IsUniform -> non-empty GetValuePolicies. Small size penalty here if not. - Writer.BeginArray("ValuePolicies"sv); - { - for (const CacheValuePolicy& ValuePolicy : GetValuePolicies()) - { - // FCacheRecordPolicyBuilder is responsible for ensuring that each ValuePolicy != DefaultValuePolicy - // If it lets any duplicates through we will incur a small serialization size penalty here - Writer.BeginObject(); - Writer << "Id"sv << ValuePolicy.Id; - Writer << "Policy"sv << WriteToString<128>(ValuePolicy.Policy); - Writer.EndObject(); - } - } - Writer.EndArray(); + Writer.BeginObject(); + Writer.AddObjectId("Id"sv, Value.Id); + Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy)); + Writer.EndObject(); } + Writer.EndArray(); } Writer.EndObject(); } -CacheRecordPolicy -CacheRecordPolicy::Load(CbObjectView Object, CachePolicy DefaultPolicy) +OptionalCacheRecordPolicy +CacheRecordPolicy::Load(const CbObjectView Object) { - std::string_view PolicyText = Object["DefaultValuePolicy"sv].AsString(); - CachePolicy DefaultValuePolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + std::string_view BasePolicyText = Object["BasePolicy"sv].AsString(); +#if BACKWARDS_COMPATABILITY_JAN2022 + if (BasePolicyText.empty()) + { + BasePolicyText = Object["DefaultValuePolicy"sv].AsString(); + } +#endif + if (BasePolicyText.empty()) + { + return {}; + } - CacheRecordPolicyBuilder Builder(DefaultValuePolicy); - for (CbFieldView ValueObjectField : Object["ValuePolicies"sv]) + CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText)); + for (CbFieldView ValueField : Object["ValuePolicies"sv]) { - CbObjectView ValueObject = ValueObjectField.AsObjectView(); - const Oid ValueId = ValueObject["Id"sv].AsObjectId(); - PolicyText = ValueObject["Policy"sv].AsString(); - CachePolicy ValuePolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultValuePolicy; - // FCacheRecordPolicyBuilder should guarantee that FValueId(ValueId).IsValid and ValuePolicy != DefaultValuePolicy - // If it lets any through we will have unused data in the record we create. - Builder.AddValuePolicy(ValueId, ValuePolicy); + const CbObjectView Value = ValueField.AsObjectView(); + const Oid Id = Value["Id"sv].AsObjectId(); + const std::string_view PolicyText = Value["Policy"sv].AsString(); + if (!Id || PolicyText.empty()) + { + return {}; + } + CachePolicy Policy = ParseCachePolicy(PolicyText); +#if BACKWARDS_COMPATABILITY_JAN2022 + Policy = Policy & CacheValuePolicy::PolicyMask; +#else + if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask)) + { + return {}; + } +#endif + Builder.AddValuePolicy(Id, Policy); } return Builder.Build(); @@ -210,30 +229,28 @@ CacheRecordPolicy::Load(CbObjectView Object, CachePolicy DefaultPolicy) CacheRecordPolicy CacheRecordPolicy::ConvertToUpstream() const { - auto DownstreamToUpstream = [](CachePolicy P) { - // Remote|Local -> Set Remote - // Delete Skip Flags - // Maintain Remaining Flags - return (EnumHasAllFlags(P, CachePolicy::QueryRemote) ? CachePolicy::QueryLocal : CachePolicy::None) | - (EnumHasAllFlags(P, CachePolicy::StoreRemote) ? CachePolicy::StoreLocal : CachePolicy::None) | - (P & ~(CachePolicy::SkipData | CachePolicy::SkipMeta)); - }; - CacheRecordPolicyBuilder Builder(DownstreamToUpstream(GetDefaultValuePolicy())); + CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy())); for (const CacheValuePolicy& ValuePolicy : GetValuePolicies()) { - Builder.AddValuePolicy(ValuePolicy.Id, DownstreamToUpstream(ValuePolicy.Policy)); + Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy)); } return Builder.Build(); } void -CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Policy) +CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value) { + ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy, + ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s + if (Value.Policy == (BasePolicy & Value.PolicyMask)) + { + return; + } if (!Shared) { Shared = new Private::CacheRecordPolicyShared; } - Shared->AddValuePolicy(Policy); + Shared->AddValuePolicy(Value); } CacheRecordPolicy @@ -242,13 +259,14 @@ CacheRecordPolicyBuilder::Build() CacheRecordPolicy Policy(BasePolicy); if (Shared) { - Shared->Build(); - const auto PolicyOr = [](CachePolicy A, CachePolicy B) { return A | (B & ~CachePolicy::SkipData); }; - const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); - Policy.RecordPolicy = BasePolicy; + const auto Add = [](const CachePolicy A, const CachePolicy B) { + return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData); + }; + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + Policy.RecordPolicy = BasePolicy; for (const CacheValuePolicy& ValuePolicy : Values) { - Policy.RecordPolicy = PolicyOr(Policy.RecordPolicy, ValuePolicy.Policy); + Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy); } Policy.Shared = std::move(Shared); } diff --git a/zenutil/include/zenutil/cache/cachekey.h b/zenutil/include/zenutil/cache/cachekey.h index fb36c7759..aa649b4dc 100644 --- a/zenutil/include/zenutil/cache/cachekey.h +++ b/zenutil/include/zenutil/cache/cachekey.h @@ -44,12 +44,18 @@ struct CacheChunkRequest { CacheKey Key; IoHash ChunkId; - Oid PayloadId; + Oid ValueId; uint64_t RawOffset = 0ull; uint64_t RawSize = ~uint64_t(0); CachePolicy Policy = CachePolicy::Default; }; +struct CacheKeyRequest +{ + CacheKey Key; + CacheRecordPolicy Policy; +}; + inline bool operator<(const CacheChunkRequest& A, const CacheChunkRequest& B) { @@ -69,11 +75,11 @@ operator<(const CacheChunkRequest& A, const CacheChunkRequest& B) { return false; } - if (A.PayloadId < B.PayloadId) + if (A.ValueId < B.ValueId) { return true; } - if (B.PayloadId < A.PayloadId) + if (B.ValueId < A.ValueId) { return false; } diff --git a/zenutil/include/zenutil/cache/cachepolicy.h b/zenutil/include/zenutil/cache/cachepolicy.h index b3602edbd..3eb0fda66 100644 --- a/zenutil/include/zenutil/cache/cachepolicy.h +++ b/zenutil/include/zenutil/cache/cachepolicy.h @@ -12,14 +12,26 @@ #include <span> #include <unordered_map> +#define BACKWARDS_COMPATABILITY_JAN2022 1 +#if BACKWARDS_COMPATABILITY_JAN2022 +# define BACKWARDS_COMPATABILITY_JAN2022_CODE(...) __VA_ARGS__ +#else +# define BACKWARDS_COMPATABILITY_JAN2022_CODE(...) +#endif + +namespace zen::Private { +class ICacheRecordPolicyShared; +} namespace zen { class CbObjectView; class CbWriter; +class OptionalCacheRecordPolicy; + enum class CachePolicy : uint32_t { - /** A value without any flags set. */ + /** A value with no flags. Disables access to the cache unless combined with other flags. */ None = 0, /** Allow a cache request to query local caches. */ @@ -29,17 +41,26 @@ enum class CachePolicy : uint32_t /** Allow a cache request to query any caches. */ Query = QueryLocal | QueryRemote, - /** Allow cache records and values to be stored in local caches. */ + /** Allow cache requests to query and store records and values in local caches. */ StoreLocal = 1 << 2, /** Allow cache records and values to be stored in remote caches. */ StoreRemote = 1 << 3, /** Allow cache records and values to be stored in any caches. */ Store = StoreLocal | StoreRemote, - /** Skip fetching the metadata for record requests. */ - SkipMeta = 1 << 4, + /** Allow cache requests to query and store records and values in local caches. */ + Local = QueryLocal | StoreLocal, + /** Allow cache requests to query and store records and values in remote caches. */ + Remote = QueryRemote | StoreRemote, + + /** Allow cache requests to query and store records and values in any caches. */ + Default = Query | Store, + /** Skip fetching the data for values. */ - SkipData = 1 << 5, + SkipData = 1 << 4, + + /** Skip fetching the metadata for record requests. */ + SkipMeta = 1 << 5, /** * Partial output will be provided with the error status when a required value is missing. @@ -48,7 +69,7 @@ enum class CachePolicy : uint32_t * without rebuilding the whole record. The cache automatically adds this flag when there are * other cache stores that it may be able to recover missing values from. * - * Missing values will be returned in the records or chunks, but with only the hash and size. + * Missing values will be returned in the records, but with only the hash and size. * * Applying this flag for a put of a record allows a partial record to be stored. */ @@ -61,50 +82,48 @@ enum class CachePolicy : uint32_t * to be used when subsequent accesses will not tolerate a cache miss. */ KeepAlive = 1 << 7, - - /** Allow cache requests to query and store records and values in local caches. */ - Local = QueryLocal | StoreLocal, - /** Allow cache requests to query and store records and values in remote caches. */ - Remote = QueryRemote | StoreRemote, - - /** Allow cache requests to query and store records and values in any caches. */ - Default = Query | Store, - - /** Do not allow cache requests to query or store records and values in any caches. */ - Disable = None, }; gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy); -/** Serialize Policy to text and append to Builder. Appended text will not be empty. */ +/** Append a non-empty text version of the policy to the builder. */ StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy); -/** Parse text written by operator<< back into an ECachePolicy. Text must not be empty. */ +/** Parse non-empty text written by operator<< into a policy. */ CachePolicy ParseCachePolicy(std::string_view Text); +/** Return input converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */ +CachePolicy ConvertToUpstream(CachePolicy Policy); + +inline CachePolicy +Union(CachePolicy A, CachePolicy B) +{ + constexpr CachePolicy InvertedFlags = CachePolicy::SkipData | CachePolicy::SkipMeta; + return (A & ~(InvertedFlags)) | (B & ~(InvertedFlags)) | (A & B & InvertedFlags); +} /** A value ID and the cache policy to use for that value. */ struct CacheValuePolicy { Oid Id; CachePolicy Policy = CachePolicy::Default; + + /** Flags that are valid on a value policy. */ + static constexpr CachePolicy PolicyMask = CachePolicy::Default | CachePolicy::SkipData; }; -namespace Private { - /** Interface for the private implementation of the cache record policy. */ - class ICacheRecordPolicyShared : public RefCounted - { - public: - virtual ~ICacheRecordPolicyShared() = default; - virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0; - virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0; - virtual void Build() = 0; - }; -} // namespace Private +/** Interface for the private implementation of the cache record policy. */ +class Private::ICacheRecordPolicyShared : public RefCounted +{ +public: + virtual ~ICacheRecordPolicyShared() = default; + virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0; + virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0; +}; /** * Flags to control the behavior of cache record requests, with optional overrides by value. * * Examples: - * - A base policy of Disable, with value policy overrides of Default, will fetch those values if - * they exist in the record, and skip data for any other values. + * - A base policy of None with value policy overrides of Default will fetch those values if they + * exist in the record, and skip data for any other values. * - A base policy of Default, with value policy overrides of (Query | SkipData), will skip those * values, but still check if they exist, and will load any other values. */ @@ -115,34 +134,35 @@ public: CacheRecordPolicy() = default; /** Construct a cache record policy with a uniform policy for the record and every value. */ - inline CacheRecordPolicy(CachePolicy Policy) : RecordPolicy(Policy), DefaultValuePolicy(Policy) {} + inline CacheRecordPolicy(CachePolicy BasePolicy) + : RecordPolicy(BasePolicy) + , DefaultValuePolicy(BasePolicy & CacheValuePolicy::PolicyMask) + { + } /** Returns true if the record and every value use the same cache policy. */ - inline bool IsUniform() const { return !Shared && RecordPolicy == DefaultValuePolicy; } + inline bool IsUniform() const { return !Shared; } /** Returns the cache policy to use for the record. */ inline CachePolicy GetRecordPolicy() const { return RecordPolicy; } + /** Returns the base cache policy that this was constructed from. */ + inline CachePolicy GetBasePolicy() const { return DefaultValuePolicy | (RecordPolicy & ~CacheValuePolicy::PolicyMask); } + /** Returns the cache policy to use for the value. */ CachePolicy GetValuePolicy(const Oid& Id) const; - /** Returns the cache policy to use for values with no override. */ - inline CachePolicy GetDefaultValuePolicy() const { return DefaultValuePolicy; } - /** Returns the array of cache policy overrides for values, sorted by ID. */ inline std::span<const CacheValuePolicy> GetValuePolicies() const { return Shared ? Shared->GetValuePolicies() : std::span<const CacheValuePolicy>(); } - /** Save the values from *this into the given writer. */ + /** Saves the cache record policy to a compact binary object. */ void Save(CbWriter& Writer) const; - /** - * Returns a policy loaded from values on Object. - * Invalid data will result in a uniform CacheRecordPolicy with defaultValuePolicy == DefaultPolicy. - */ - static CacheRecordPolicy Load(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default); + /** Loads a cache record policy from an object. */ + static OptionalCacheRecordPolicy Load(CbObjectView Object); /** Return *this converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */ @@ -150,6 +170,7 @@ public: private: friend class CacheRecordPolicyBuilder; + friend class OptionalCacheRecordPolicy; CachePolicy RecordPolicy = CachePolicy::Default; CachePolicy DefaultValuePolicy = CachePolicy::Default; @@ -167,7 +188,7 @@ public: inline explicit CacheRecordPolicyBuilder(CachePolicy Policy) : BasePolicy(Policy) {} /** Adds a cache policy override for a value. */ - void AddValuePolicy(const CacheValuePolicy& Policy); + void AddValuePolicy(const CacheValuePolicy& Value); inline void AddValuePolicy(const Oid& Id, CachePolicy Policy) { AddValuePolicy({Id, Policy}); } /** Build a cache record policy, which makes this builder subsequently unusable. */ @@ -178,4 +199,38 @@ private: RefPtr<Private::ICacheRecordPolicyShared> Shared; }; +/** + * A cache record policy that can be null. + * + * @see CacheRecordPolicy + */ +class OptionalCacheRecordPolicy : private CacheRecordPolicy +{ +public: + inline OptionalCacheRecordPolicy() : CacheRecordPolicy(~CachePolicy::None) {} + + inline OptionalCacheRecordPolicy(CacheRecordPolicy&& InOutput) : CacheRecordPolicy(std::move(InOutput)) {} + inline OptionalCacheRecordPolicy(const CacheRecordPolicy& InOutput) : CacheRecordPolicy(InOutput) {} + inline OptionalCacheRecordPolicy& operator=(CacheRecordPolicy&& InOutput) + { + CacheRecordPolicy::operator=(std::move(InOutput)); + return *this; + } + inline OptionalCacheRecordPolicy& operator=(const CacheRecordPolicy& InOutput) + { + CacheRecordPolicy::operator=(InOutput); + return *this; + } + + /** Returns the cache record policy. The caller must check for null before using this accessor. */ + inline const CacheRecordPolicy& Get() const& { return *this; } + inline CacheRecordPolicy Get() && { return std::move(*this); } + + inline bool IsNull() const { return RecordPolicy == ~CachePolicy::None; } + inline bool IsValid() const { return !IsNull(); } + inline explicit operator bool() const { return !IsNull(); } + + inline void Reset() { *this = OptionalCacheRecordPolicy(); } +}; + } // namespace zen |