aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPer Larsson <[email protected]>2022-02-02 19:08:10 +0100
committerPer Larsson <[email protected]>2022-02-02 19:08:10 +0100
commit10ab6d8c768b54dfcd085ec94aa959dc9d1103ce (patch)
treec48a96d8a1ea8d4267906af76e72e1e95f70fc78
parentChanged OIDC token endpoint. (diff)
parentMerge branch 'main' of https://github.com/EpicGames/zen (diff)
downloadzen-10ab6d8c768b54dfcd085ec94aa959dc9d1103ce.tar.xz
zen-10ab6d8c768b54dfcd085ec94aa959dc9d1103ce.zip
Merged main.
-rw-r--r--CODING.md16
-rw-r--r--README.md19
-rw-r--r--TODO.md57
-rw-r--r--generate_projects.bat1
-rw-r--r--scripts/formatcode.py18
-rw-r--r--scripts/remote_build.py270
-rw-r--r--thirdparty/trace/trace.h2
-rw-r--r--xmake.lua2
-rw-r--r--zen/cmds/top.cpp3
-rw-r--r--zen/internalfile.cpp6
-rw-r--r--zencore/compactbinary.cpp4
-rw-r--r--zencore/crypto.cpp235
-rw-r--r--zencore/filesystem.cpp2
-rw-r--r--zencore/include/zencore/crypto.h50
-rw-r--r--zencore/include/zencore/refcount.h7
-rw-r--r--zencore/include/zencore/trace.h1
-rw-r--r--zencore/include/zencore/zencore.h15
-rw-r--r--zencore/iobuffer.cpp2
-rw-r--r--zencore/thread.cpp6
-rw-r--r--zencore/trace.cpp11
-rw-r--r--zencore/zencore.cpp2
-rw-r--r--zenhttp/httpsys.cpp4
-rw-r--r--zenserver-test/cachepolicy-tests.cpp51
-rw-r--r--zenserver-test/zenserver-test.cpp89
-rw-r--r--zenserver/cache/structuredcache.cpp1174
-rw-r--r--zenserver/cache/structuredcache.h42
-rw-r--r--zenserver/cache/structuredcachestore.h14
-rw-r--r--zenserver/upstream/jupiter.h4
-rw-r--r--zenserver/upstream/upstreamapply.h33
-rw-r--r--zenserver/upstream/upstreamcache.cpp261
-rw-r--r--zenserver/upstream/upstreamcache.h56
-rw-r--r--zenserver/upstream/zen.cpp18
-rw-r--r--zenserver/upstream/zen.h12
-rw-r--r--zenserver/zenserver.cpp6
-rw-r--r--zenutil/cache/cachepolicy.cpp326
-rw-r--r--zenutil/include/zenutil/cache/cachekey.h6
-rw-r--r--zenutil/include/zenutil/cache/cachepolicy.h143
37 files changed, 2228 insertions, 740 deletions
diff --git a/CODING.md b/CODING.md
index 3b37ca368..d94d9d665 100644
--- a/CODING.md
+++ b/CODING.md
@@ -1,5 +1,15 @@
# Naming Conventions
-* Classes/Structs - PascalCase
-* Functions - CamelCase
-* Class member variables - m_PascalCase
+The naming conventions for Zen are intended to resemble the Unreal Engine coding style, with some minor exceptions.
+
+* Classes/Structs - `PascalCase`
+* Functions - `PascalCase()`
+* Class member variables - `m_PascalCase`
+
+Those who are familiar with the UE coding standards will note that we do not require or encourage `F` prefixes on struct or classes, and we expect class members to have a `m_` member prefix.
+
+# Code formatting
+
+To ensure consistent formatting we rely on `clang-format` to automatically format source code. This leads to consistent formatting which should lead to less surprises and more straightforward merging.
+
+Formatting is triggered via `prepare_commit` which should be used ahead of commit. We do not currently reject commits which have not been formatted, but we probably should at some point in the future.
diff --git a/README.md b/README.md
index 288776636..f26f23150 100644
--- a/README.md
+++ b/README.md
@@ -39,8 +39,11 @@ currently building with the VS2022 toolchain has not been tested (please leave t
* clone the `zen` repository if you haven't already
* run `git clone https://github.com/EpicGames/zen.git`
* run `xmake project -k vsxmake2019 -a x64 -y`
-* open the `vsxmake2019\zen.sln` VS solution (NOTE: you currently need to run Visual Studio in ADMIN mode since
- http.sys requires elevation)
+* open the `vsxmake2019\zen.sln` VS solution
+ * Note: if you want full connectivity with the http.sys server implementation you currently need to run
+ Visual Studio in ADMIN mode since http.sys requires elevation to be able to listen on a non-local network socket.
+ You can start Visual Studio in admin mode by holding CTRL-SHIFT when launching Visual Studio. Alternatively
+ you can add an URL reservation (see below)
* you can now build and run `zenserver` as usual from Visual Studio
* third-party dependencies will be built the first time via the `vcpkg` integration. This is not as
fast as it could be (it does not go wide) but should only happen on the first build and will leverage
@@ -193,9 +196,19 @@ is incredibly handy. When that is installed you may enable auto-attach to child
The tests are implemented using [doctest](https://github.com/onqtam/doctest), which is similar to Catch in usage.
+# Adding a http.sys URL reservation
+
+Registering a handler for an HTTP endpoint requires either process elevation (i.e running Zen as admin) or a one-time URL reservation. An URL reservation can be added by issuing a command like
+
+`netsh http add urlacl url=http://*:1337/ user=stefan.boberg` (enable for a specific user)
+
+or
+
+`netsh http add urlacl url=http://*:1337/ sddl=D:(A;;GX;;;S-1-1-0)` (enable for any authenticated user)
+
# Coding Standards
-See [Coding.md](Coding.md)
+See [CODING.md](CODING.md)
Run `prepare_commit.bat` before committing code. It ensures all source files are formatted with
clang-format which you will need to install.
diff --git a/TODO.md b/TODO.md
index 9ea5862f2..a582bd904 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,68 +1,34 @@
-# Use-cases
-
-* Mirage cache
-* Editor Domain
-* COTF2
-* Target Domain / Build Store
-
# General
* Switch to CMake projects for cross-platform builds?
-* Should get rid of stack-dependent RefCount initialization
-* Upgrade to CPR 1.6.0 for more efficient downloads
-* Implement support for `CbFieldType::CustomById` / `CbFieldType::CustomByName`
-# Upstream Connectivity
+# Connectivity
## Jupiter
* High-performance/concurrency HTTP client (on asio)
-# Peer Connectivity
+## Peer Connectivity
-* Beacon
+* Beacon / tracker implementation
+* CID store peer fetching
-# Downstream Connectivity
+## Downstream Connectivity
-## Runtime
-* High performance HTTP client (layered on asio or UE sockets)
-## Cooker
+# Runtime
-## Editor
-
-## Mirage
+* High performance (HTTP?) client (layered on asio or UE sockets)
+ * Do we have asio in the engine yet?
# Local Features
* VFS for surfacing debugging information
-# TPS
-
-* nodejs/http_parser
-* all the rest (do we need TPS for vcpkg packages?)
-
-
-
-# Productization
-
-* Incremental cook
-* Windows feature complete
-* Mac / Linux support
-* Non-elevated execution
-* State management strategy
-
# Cache
-* M7
- * Cleanup
- * Jupiter upstream configuration
-
-# Editor Domain
-
-
-
+* Full support for content-type on simple (unstructured) values
# Daemon Notes
@@ -76,3 +42,8 @@ additional things to consider:
number to keep things simple. There should not be any tight coupling with
different engine branches (new releases should be backwards compatible) etc
and any service development should take place on a single stream
+
+# Random things
+
+* We currently have too many different paths for marshaling Compact Binary packages. They
+ need to be unified to use `FormatPackageMessageBuffer` everywhere
diff --git a/generate_projects.bat b/generate_projects.bat
new file mode 100644
index 000000000..92a9a33ae
--- /dev/null
+++ b/generate_projects.bat
@@ -0,0 +1 @@
+@xmake project --yes --kind=vsxmake2022 -m release,debug -a x64
diff --git a/scripts/formatcode.py b/scripts/formatcode.py
index 423d2b4e7..dc13ae117 100644
--- a/scripts/formatcode.py
+++ b/scripts/formatcode.py
@@ -7,6 +7,7 @@ import re
match_expressions = []
valid_extensions = []
root_dir = ''
+use_batching = True
def is_header_missing(f):
with open(f) as reader:
@@ -37,7 +38,11 @@ def scan_tree(root):
header_files.append(full_path)
args = ""
if files:
- os.system("clang-format -i " + " ".join(files))
+ if use_batching:
+ os.system("clang-format -i " + " ".join(files))
+ else:
+ for file in files:
+ os.system("clang-format -i " + file)
if header_files:
add_headers(header_files, "// Copyright Epic Games, Inc. All Rights Reserved.\n\n")
@@ -75,26 +80,31 @@ def parse_match_expressions(wildcards, matches):
try:
match_expressions.append(re.compile(regex, re.IGNORECASE))
except Exception as ex:
- print('Could not parse input filename expression \'{}\': {}'.format(wildcard, str(ex)))
+ print(f'Could not parse input filename expression \'{wildcard}\': {str(ex)}')
quit()
for regex in matches:
try:
match_expressions.append(re.compile(regex, re.IGNORECASE))
except Exception as ex:
- print('Could not parse input --match expression \'{}\': {}'.format(regex, str(ex)))
+ print(f'Could not parse input --match expression \'{regex}\': {str(ex)}')
quit()
def _main():
- global root_dir
+ global root_dir, use_batching
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help="Match text for filenames. If fullpath contains text it is a match, " +\
"* is a wildcard. Directory separators are matched by either / or \\. Case insensitive.")
parser.add_argument('--match', action='append', default=[], help="Match regular expression for filenames. " +\
"Relative path from the root zen directory must be a complete match. Directory separators are matched only by /. Case insensitive.")
+ parser.add_argument('--batch', dest='use_batching', action='store_true', help="Enable batching calls to clang-format.")
+ parser.add_argument('--no-batch', dest='use_batching', action='store_false', help="Disable batching calls to clang-format.")
+ parser.set_defaults(use_batching=True)
options = parser.parse_args()
+
parse_match_expressions(options.filenames, options.match)
root_dir = pathlib.Path(__file__).parent.parent.resolve()
+ use_batching = options.use_batching
while True:
if (os.path.isfile(".clang-format")):
diff --git a/scripts/remote_build.py b/scripts/remote_build.py
new file mode 100644
index 000000000..c5787f635
--- /dev/null
+++ b/scripts/remote_build.py
@@ -0,0 +1,270 @@
+import os
+import sys
+import argparse
+import subprocess
+from pathlib import Path
+
+# {{{1 misc --------------------------------------------------------------------
+
+# Disables output of ANSI codes if the terminal doesn't support them
+if os.name == "nt":
+ from ctypes import windll, c_int, byref
+ stdout_handle = windll.kernel32.GetStdHandle(c_int(-11))
+ mode = c_int(0)
+ windll.kernel32.GetConsoleMode(c_int(stdout_handle), byref(mode))
+ ansi_on = (mode.value & 4) != 0
+else:
+ ansi_on = True
+
+#-------------------------------------------------------------------------------
+def _header(*args, ansi=96):
+ if ansi_on:
+ print(f"\x1b[{ansi}m##", *args, end="")
+ print("\x1b[0m")
+ else:
+ print("\n##", *args)
+
+#-------------------------------------------------------------------------------
+def _run_checked(cmd, *args, **kwargs):
+ _header(cmd, *args, ansi=97)
+ ret = subprocess.run((cmd, *args), **kwargs, bufsize=0)
+ if ret.returncode:
+ raise RuntimeError("Failed running " + str(cmd))
+
+#-------------------------------------------------------------------------------
+def _get_ip():
+ import socket
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ s.connect(("172.31.255.255", 1))
+ return s.getsockname()[0]
+ except:
+ return "127.0.0.1"
+ finally:
+ s.close()
+
+#-------------------------------------------------------------------------------
+def _find_binary(name):
+ name += ".exe" if os.name == "nt" else ""
+ for prefix in os.getenv("PATH", "").split(os.pathsep):
+ path = Path(prefix) / name
+ if path.is_file():
+ return path
+ raise EnvironmentError(f"Unable to find '{name}' in the path")
+
+#-------------------------------------------------------------------------------
+class _AutoKill(object):
+ def __init__(self, proc):
+ self._proc = proc
+
+ def __del__(self):
+ self._proc.kill()
+ self._proc.wait()
+ pass
+
+
+
+# {{{1 local -------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+def _local(args):
+ # Parse arguments
+ desc = "Build Zen on a remote host"
+ parser = argparse.ArgumentParser(description=desc)
+ parser.add_argument("remotehost", help="")
+ parser.add_argument("action", default="build", nargs="?", help="")
+ parser.add_argument("--keyfile", default=None, help="SSH key file")
+ args = parser.parse_args(args)
+
+ # Find the binaries we'll need
+ _header("Finding tools")
+ git_bin = _find_binary("git")
+ print(f"Using git from '{git_bin.name}' from '{git_bin.parent}'")
+
+ def find_git_tool(git_bin, tool_name):
+ print(f"Locating {tool_name}...")
+ tool_suffix = "usr/bin/" + tool_name
+ tool_suffix += ".exe" if os.name == "nt" else ""
+ for parent in git_bin.parents:
+ tool_path = parent / tool_suffix
+ if tool_path.is_file():
+ return tool_path
+ return _find_binary(tool_name)
+
+ ssh_bin = find_git_tool(git_bin, "ssh")
+ scp_bin = find_git_tool(git_bin, "scp")
+ print(f"Using '{ssh_bin.name}' from '{ssh_bin.parent}'")
+ print(f"Using '{scp_bin.name}' from '{scp_bin.parent}'")
+
+ # Find the Zen repository root
+ for parent in Path(__file__).resolve().parents:
+ if (parent / ".git").is_dir():
+ zen_dir = parent
+ break;
+ else:
+ raise EnvironmentError("Unable to find '.git/' directory")
+
+ _header("Validating remote host and credentials")
+
+ # Validate key file. OpenSSL needs a trailing EOL, LibreSSL doesn't
+ if args.keyfile:
+ with open(args.keyfile, "rt") as key_file:
+ lines = [x for x in key_file]
+ if not lines[-1].endswith("\n"):
+ print("!! ERROR: key file must end with a new line")
+ return 1
+ identity = ("-i", args.keyfile)
+ else:
+ identity = ()
+
+ # Validate remote host
+ host = args.remotehost
+ if host == "linux": host = os.getenv("ZEN_REMOTE_HOST_LINUX", "arn-lin-12345")
+ if host == "mac": host = os.getenv("ZEN_REMOTE_HOST_MAC", "imacpro-arn.local")
+ """
+ keygen_bin = find_git_tool(git_bin, "ssh-keygen")
+ print(f"Using '{keygen_bin.name}' from '{keygen_bin.parent}'")
+ known_host = subprocess.run((keygen_bin, "-F", host)).returncode
+ if not known_host:
+ print("Adding", host, "as a known host")
+ print("ANSWER 'yes'")
+ known_host = subprocess.run((ssh_bin, *identity, "zenbuild@" + host, "uname -a")).returncode
+ raise IndexError
+ """
+ host = "zenbuild@" + host
+ print(f"Using host '{host}'")
+
+ # Start a git daemon to use as a transfer mechanism
+ _header("Starting a git daemon")
+ print("Port: 4493")
+ print("Base-path: ", zen_dir)
+ print("Host: ", _get_ip())
+ daemon = subprocess.Popen(
+ ( git_bin,
+ "daemon",
+ "--port=4493",
+ "--export-all",
+ "--reuseaddr",
+ "--verbose",
+ "--informative-errors",
+ "--base-path=" + str(zen_dir) ),
+ #stdout = daemon_log,
+ stderr = subprocess.STDOUT
+ )
+ daemon_killer = _AutoKill(daemon)
+
+ # Run this script on the remote machine
+ _header("Running SSH")
+
+ remote_zen_dir = "%s_%s" % (os.getlogin(), _get_ip())
+ print(f"Using zen '~/{remote_zen_dir}'")
+
+ print(f"Running {__file__} remotely")
+ with open(__file__, "rt") as self_file:
+ _run_checked(
+ ssh_bin,
+ *identity,
+ "-tA",
+ host,
+ f"python3 -u - !remote {_get_ip()} '{remote_zen_dir}' main '{args.action}'",
+ stdin=self_file)
+
+ # If we're bundling, collect zip files from the remote machine
+ if args.action == "bundle":
+ build_dir = zen_dir / "build"
+ build_dir.mkdir(exist_ok=True)
+ scp_args = (*identity, host + f":zen/{remote_zen_dir}/build/*.zip", build_dir)
+ _run_checked("scp", *scp_args)
+
+
+
+# {{{1 remote ------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+def _remote(args):
+ # Parse arguments
+ desc = "Build Zen on a remote host"
+ parser = argparse.ArgumentParser(description=desc)
+ parser.add_argument("ip", help="Host's IP address")
+ parser.add_argument("reponame", help="Repository name clone into and work in")
+ parser.add_argument("branch", help="Zen branch to operate on")
+ parser.add_argument("action", help="The action to do")
+ args = parser.parse_args(args)
+
+ # Homeward bound and out
+ zen_dir = Path().home() / "zen"
+ os.chdir(zen_dir)
+
+ # Mutual exclusion
+ """
+ lock_path = zen_dir / "../.remote_lock"
+ try: lock_file = open(lock_path, "xb")
+ except: raise RuntimeError("Failed to lock", lock_path)
+ """
+
+ # Check for a clone, create it, chdir to it
+ _header("REMOTE:", f"Clone/pull from {args.ip}")
+ clone_dir = zen_dir / args.reponame
+ if not clone_dir.is_dir():
+ _run_checked("git", "clone", f"git://{args.ip}:4493/", clone_dir)
+ os.chdir(clone_dir)
+
+ _run_checked("git", "checkout", args.branch)
+ _run_checked("git", "pull", "-r")
+
+ _header("REMOTE:", f"Performing action '{args.action}'")
+
+ # Find xmake
+ xmake_bin = max(x for x in (zen_dir / "xmake").glob("*"))
+ xmake_bin /= "usr/local/bin/xmake"
+
+ # Run xmake
+ xmake_env = {}
+ xmake_env["VCPKG_ROOT"] = zen_dir / "vcpkg"
+ if sys.platform == "linux":
+ xmake_env["CXX"] = "g++-11"
+ print("xmake environment:")
+ for key, value in xmake_env.items():
+ print(" ", key, "=", value)
+ xmake_env.update(os.environ)
+
+ def run_xmake(*args):
+ print("starting xmake...", end="\r")
+ _run_checked(xmake_bin, args[0], "--yes", *args[1:], env=xmake_env)
+
+ if args.action.startswith("build"):
+ mode = "debug" if args.action == "build.debug" else "release"
+ run_xmake("config", "--mode=" + mode)
+ run_xmake("build")
+
+ elif args.action == "bundle":
+ run_xmake("bundle")
+
+ elif args.action == "test":
+ run_xmake("config", "--mode=debug")
+ run_xmake("test")
+
+ elif args.action == "clean":
+ _run_checked("git", "reset")
+ _run_checked("git", "checkout", ".")
+ _run_checked("git", "clean", "-xdf")
+
+
+
+# {{{1 entry -------------------------------------------------------------------
+if __name__ == "__main__":
+ if "!remote" in sys.argv[1:2]:
+ ret = _remote(sys.argv[2:])
+ raise SystemExit(ret)
+
+ try:
+ ret = _local(sys.argv[1:])
+ raise SystemExit(ret)
+ except:
+ raise
+ finally:
+ # Roundabout way to avoid orphaned git-daemon processes
+ if os.name == "nt":
+ os.system("taskkill /f /im git-daemon.exe")
+
+# vim: expandtab foldlevel=1 foldmethod=marker
diff --git a/thirdparty/trace/trace.h b/thirdparty/trace/trace.h
index caa862ffe..d7fbbb71f 100644
--- a/thirdparty/trace/trace.h
+++ b/thirdparty/trace/trace.h
@@ -263,7 +263,7 @@ class FChannel;
EventProps_Meta const EventProps_Private = {}; \
typedef std::conditional<bIsImportant, UE::Trace::Private::FImportantLogScope, UE::Trace::Private::FLogScope>::type LogScopeType; \
explicit operator bool () const { return true; } \
- enum { EventFlags = PartialEventFlags|(EventProps_Meta::NumAuxFields ? UE::Trace::Private::FEventInfo::Flag_MaybeHasAux : 0), }; \
+ enum { EventFlags = PartialEventFlags|((EventProps_Meta::NumAuxFields != 0) ? UE::Trace::Private::FEventInfo::Flag_MaybeHasAux : 0), }; \
static_assert( \
!bIsImportant || (uint32(EventFlags) & uint32(UE::Trace::Private::FEventInfo::Flag_NoSync)), \
"Trace events flagged as Important events must be marked NoSync" \
diff --git a/xmake.lua b/xmake.lua
index c5cee1bec..be01b9fb6 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -130,7 +130,7 @@ option_end()
add_define_by_config("ZEN_ENABLE_MESH", "zenmesh")
option("zentrace")
- set_default(false)
+ set_default(true)
set_showmenu(true)
set_description("Enable UE's Trace support")
option_end()
diff --git a/zen/cmds/top.cpp b/zen/cmds/top.cpp
index f5b9d654a..21e4dc60e 100644
--- a/zen/cmds/top.cpp
+++ b/zen/cmds/top.cpp
@@ -78,7 +78,8 @@ PsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
return 0;
}
- State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { ZEN_CONSOLE("Port {} : pid {}", Entry.EffectiveListenPort, Entry.Pid); });
+ State.Snapshot(
+ [&](const ZenServerState::ZenServerEntry& Entry) { ZEN_CONSOLE("Port {} : pid {}", Entry.EffectiveListenPort, Entry.Pid); });
return 0;
}
diff --git a/zen/internalfile.cpp b/zen/internalfile.cpp
index c1ffe6f5f..804375ce2 100644
--- a/zen/internalfile.cpp
+++ b/zen/internalfile.cpp
@@ -258,9 +258,9 @@ InternalFile::Read(void* Data, uint64_t Size, uint64_t Offset)
HRESULT hRes = m_File.Read(Data, gsl::narrow<DWORD>(Size), &ovl);
Success = SUCCEEDED(hRes);
#else
- int Fd = int(intptr_t(m_File));
- int BytesRead = pread(Fd, Data, Size, Offset);
- Success = (BytesRead > 0);
+ int Fd = int(intptr_t(m_File));
+ int BytesRead = pread(Fd, Data, Size, Offset);
+ Success = (BytesRead > 0);
#endif
if (Success)
diff --git a/zencore/compactbinary.cpp b/zencore/compactbinary.cpp
index cded378a1..902ec26c8 100644
--- a/zencore/compactbinary.cpp
+++ b/zencore/compactbinary.cpp
@@ -54,7 +54,7 @@ GetPlatformToDateTimeBiasInSeconds()
#if ZEN_PLATFORM_WINDOWS
const uint64_t PlatformEpochYear = 1601;
#else
- const uint64_t PlatformEpochYear = 1970;
+ const uint64_t PlatformEpochYear = 1970;
#endif
const uint64_t DateTimeEpochYear = 1;
return uint64_t(double(PlatformEpochYear - DateTimeEpochYear) * 365.2425) * 86400;
@@ -71,7 +71,7 @@ DateTime::Now()
GetSystemTimeAsFileTime(&SysTime);
return DateTime{(EpochBias * SecsTo100nsTicks) + (uint64_t(SysTime.dwHighDateTime) << 32) | SysTime.dwLowDateTime};
#else
- int64_t SecondsSinceUnixEpoch = time(nullptr);
+ int64_t SecondsSinceUnixEpoch = time(nullptr);
return DateTime{(EpochBias + SecondsSinceUnixEpoch) * SecsTo100nsTicks};
#endif
}
diff --git a/zencore/crypto.cpp b/zencore/crypto.cpp
new file mode 100644
index 000000000..880d7b495
--- /dev/null
+++ b/zencore/crypto.cpp
@@ -0,0 +1,235 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/crypto.h>
+#include <zencore/intmath.h>
+#include <zencore/testing.h>
+
+#include <openssl/conf.h>
+#include <openssl/err.h>
+#include <openssl/evp.h>
+
+#include <string>
+#include <string_view>
+
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "crypt32.lib")
+# pragma comment(lib, "ws2_32.lib")
+#endif
+
+namespace zen {
+
+class NullCipher final : public SymmetricCipher
+{
+public:
+ NullCipher() = default;
+ virtual ~NullCipher() = default;
+
+ virtual bool Initialize(MemoryView, MemoryView) override final { return true; }
+
+ virtual CipherSettings Settings() override final { return {}; }
+
+ virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView) override final { return Data; }
+
+ virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView) override final { return Data; }
+};
+
+std::unique_ptr<SymmetricCipher>
+MakeNullCipher()
+{
+ return std::make_unique<NullCipher>();
+}
+
+#if ZEN_PLATFORM_WINDOWS
+class Aes final : public SymmetricCipher
+{
+public:
+ Aes(const EVP_CIPHER* Cipher = EVP_aes_256_cbc()) : m_Cipher(Cipher)
+ {
+ ZEN_ASSERT(Cipher);
+ m_KeySize = static_cast<size_t>(EVP_CIPHER_key_length(m_Cipher));
+ m_InitVectorSize = static_cast<size_t>(EVP_CIPHER_iv_length(m_Cipher));
+ m_BlockSize = static_cast<size_t>(EVP_CIPHER_block_size(m_Cipher));
+ }
+
+ virtual ~Aes()
+ {
+ if (m_EncryptionCtx)
+ {
+ EVP_CIPHER_CTX_free(m_EncryptionCtx);
+ }
+
+ if (m_DecryptionCtx)
+ {
+ EVP_CIPHER_CTX_free(m_DecryptionCtx);
+ }
+ }
+
+ virtual bool Initialize(MemoryView Key, MemoryView InitVector) override final
+ {
+ ZEN_ASSERT(m_EncryptionCtx == nullptr && m_DecryptionCtx == nullptr);
+ ZEN_ASSERT(Key.GetSize() == m_KeySize);
+ ZEN_ASSERT(InitVector.GetSize() == m_InitVectorSize);
+
+ m_EncryptionCtx = EVP_CIPHER_CTX_new();
+ m_DecryptionCtx = EVP_CIPHER_CTX_new();
+
+ if (int ErrorCode = EVP_EncryptInit_ex(m_EncryptionCtx,
+ m_Cipher,
+ nullptr,
+ reinterpret_cast<const unsigned char*>(Key.GetData()),
+ reinterpret_cast<const unsigned char*>(InitVector.GetData()));
+ ErrorCode != 1)
+ {
+ return false;
+ }
+
+ if (int ErrorCode = EVP_DecryptInit_ex(m_DecryptionCtx,
+ m_Cipher,
+ nullptr,
+ reinterpret_cast<const unsigned char*>(Key.GetData()),
+ reinterpret_cast<const unsigned char*>(InitVector.GetData()));
+ ErrorCode != 1)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ virtual CipherSettings Settings() override final
+ {
+ return {.KeySize = m_KeySize, .InitVectorSize = m_InitVectorSize, .BlockSize = m_BlockSize};
+ }
+
+ virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer)
+ {
+ ZEN_ASSERT(m_EncryptionCtx);
+
+ const uint64_t InputSize = Data.GetSize();
+ const uint64_t NeededSize = RoundUp(InputSize, m_BlockSize);
+
+ if (NeededSize > EncryptionBuffer.GetSize())
+ {
+ return MemoryView();
+ }
+
+ int TotalSize = 0;
+ int EncryptedSize = 0;
+ int ErrorCode = EVP_EncryptUpdate(m_EncryptionCtx,
+ reinterpret_cast<unsigned char*>(EncryptionBuffer.GetData()),
+ &EncryptedSize,
+ reinterpret_cast<const unsigned char*>(Data.GetData()),
+ static_cast<int>(Data.GetSize()));
+
+ if (ErrorCode != 1)
+ {
+ return MemoryView();
+ }
+
+ TotalSize = EncryptedSize;
+ MutableMemoryView Remaining = EncryptionBuffer.RightChop(uint64_t(EncryptedSize));
+
+ ErrorCode = EVP_EncryptFinal_ex(m_EncryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedSize);
+
+ if (ErrorCode != 1)
+ {
+ return MemoryView();
+ }
+
+ TotalSize += EncryptedSize;
+
+ return EncryptionBuffer.Left(uint64_t(TotalSize));
+ }
+
+ virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) override final
+ {
+ ZEN_ASSERT(m_DecryptionCtx);
+
+ int TotalSize = 0;
+ int DecryptedSize = 0;
+ int ErrorCode = EVP_DecryptUpdate(m_DecryptionCtx,
+ reinterpret_cast<unsigned char*>(DecryptionBuffer.GetData()),
+ &DecryptedSize,
+ reinterpret_cast<const unsigned char*>(Data.GetData()),
+ static_cast<int>(Data.GetSize()));
+
+ if (ErrorCode != 1)
+ {
+ return MemoryView();
+ }
+
+ TotalSize = DecryptedSize;
+ MutableMemoryView Remaining = DecryptionBuffer.RightChop(uint64_t(DecryptedSize));
+
+ ErrorCode = EVP_DecryptFinal_ex(m_DecryptionCtx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &DecryptedSize);
+
+ TotalSize += DecryptedSize;
+
+ return DecryptionBuffer.Left(uint64_t(TotalSize));
+ }
+
+private:
+ const EVP_CIPHER* m_Cipher = nullptr;
+ EVP_CIPHER_CTX* m_EncryptionCtx = nullptr;
+ EVP_CIPHER_CTX* m_DecryptionCtx = nullptr;
+ size_t m_BlockSize = 0;
+ size_t m_KeySize = 0;
+ size_t m_InitVectorSize = 0;
+};
+
+std::unique_ptr<SymmetricCipher>
+MakeAesCipher()
+{
+ return std::make_unique<Aes>();
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
+
+#if ZEN_WITH_TESTS
+
+using namespace std::literals;
+
+void
+crypto_forcelink()
+{
+}
+
+TEST_CASE("crypto.aes")
+{
+ SUBCASE("basic")
+ {
+# if ZEN_PLATFORM_WINDOWS
+ auto Cipher = std::make_unique<Aes>();
+
+ std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv;
+
+ std::vector<uint8_t> Key = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
+ std::vector<uint8_t> Seed = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+ std::vector<uint8_t> EncryptionBuffer;
+ std::vector<uint8_t> DecryptionBuffer;
+
+ bool Ok = Cipher->Initialize(MakeMemoryView(Key), MakeMemoryView(Seed));
+ CHECK(Ok);
+
+ EncryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize);
+ DecryptionBuffer.resize(PlainText.size() + Cipher->Settings().BlockSize);
+
+ MemoryView EncryptedView = Cipher->Encrypt(MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer));
+ CHECK(EncryptedView.IsEmpty() == false);
+
+ MemoryView DecryptedView = Cipher->Decrypt(EncryptedView, MakeMutableMemoryView(DecryptionBuffer));
+ CHECK(DecryptedView.IsEmpty() == false);
+
+ std::string_view EncryptedDecryptedText =
+ std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize());
+
+ CHECK(EncryptedDecryptedText == PlainText);
+ }
+# endif
+}
+
+#endif
+
+} // namespace zen
diff --git a/zencore/filesystem.cpp b/zencore/filesystem.cpp
index ee49aa474..ab606301c 100644
--- a/zencore/filesystem.cpp
+++ b/zencore/filesystem.cpp
@@ -703,7 +703,7 @@ ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<voi
ProcessFunc(ReadBuffer.data(), dwBytesRead);
}
#else
- int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC);
+ int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC);
if (Fd < 0)
{
return false;
diff --git a/zencore/include/zencore/crypto.h b/zencore/include/zencore/crypto.h
new file mode 100644
index 000000000..4d6ddba47
--- /dev/null
+++ b/zencore/include/zencore/crypto.h
@@ -0,0 +1,50 @@
+
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/memory.h>
+#include <zencore/zencore.h>
+
+#include <memory>
+
+namespace zen {
+
+/**
+ * Experimental interface for a symmetric encryption/decryption algorithm.
+ * Currenlty only AES 256 bit CBC is supported using OpenSSL.
+ */
+class SymmetricCipher
+{
+public:
+ virtual ~SymmetricCipher() = default;
+
+ virtual bool Initialize(MemoryView Key, MemoryView InitVector) = 0;
+
+ struct CipherSettings
+ {
+ size_t KeySize = 0;
+ size_t InitVectorSize = 0;
+ size_t BlockSize = 0;
+ };
+
+ virtual CipherSettings Settings() = 0;
+
+ virtual MemoryView Encrypt(MemoryView Data, MutableMemoryView EncryptionBuffer) = 0;
+
+ virtual MemoryView Decrypt(MemoryView Data, MutableMemoryView DecryptionBuffer) = 0;
+};
+
+std::unique_ptr<SymmetricCipher> MakeNullCipher();
+
+#if ZEN_PLATFORM_WINDOWS
+/**
+ * Create a new instance of a 256 bit AES CBC symmetric cipher.
+ * NOTE: Currenlty only tested on Windows
+ */
+std::unique_ptr<SymmetricCipher> MakeAesCipher();
+#endif
+
+void crypto_forcelink();
+
+} // namespace zen
diff --git a/zencore/include/zencore/refcount.h b/zencore/include/zencore/refcount.h
index 254a22db5..7befbb338 100644
--- a/zencore/include/zencore/refcount.h
+++ b/zencore/include/zencore/refcount.h
@@ -114,7 +114,7 @@ public:
private:
T* m_Ref = nullptr;
- template <typename U>
+ template<typename U>
friend class RefPtr;
};
@@ -135,8 +135,9 @@ public:
inline ~Ref() { m_Ref && m_Ref->Release(); }
template<typename DerivedType>
- requires DerivedFrom<DerivedType, T>
- inline Ref(const Ref<DerivedType>& Rhs) : Ref(Rhs.m_Ref) {}
+ requires DerivedFrom<DerivedType, T> inline Ref(const Ref<DerivedType>& Rhs) : Ref(Rhs.m_Ref)
+ {
+ }
[[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; }
inline explicit operator bool() const { return m_Ref != nullptr; }
diff --git a/zencore/include/zencore/trace.h b/zencore/include/zencore/trace.h
index f28fdeeaf..0af490f23 100644
--- a/zencore/include/zencore/trace.h
+++ b/zencore/include/zencore/trace.h
@@ -22,6 +22,7 @@ enum class TraceType
{
File,
Network,
+ None
};
void TraceInit(const char* HostOrPath, TraceType Type);
diff --git a/zencore/include/zencore/zencore.h b/zencore/include/zencore/zencore.h
index 023b237cd..bd5e5a531 100644
--- a/zencore/include/zencore/zencore.h
+++ b/zencore/include/zencore/zencore.h
@@ -153,7 +153,7 @@ concept DerivedFrom = std::derived_from<D, B>;
template<class T>
concept Integral = std::is_integral_v<T>;
template<class T>
-concept SignedIntegral = Integral<T> && std::is_signed_v<T>;
+concept SignedIntegral = Integral<T>&& std::is_signed_v<T>;
template<class T>
concept UnsignedIntegral = Integral<T> && !std::is_signed_v<T>;
template<class F, class... A>
@@ -162,7 +162,7 @@ concept Invocable = requires(F&& f, A&&... a)
std::invoke(std::forward<F>(f), std::forward<A>(a)...);
};
template<class D, class B>
-concept DerivedFrom = std::is_base_of_v<B, D> && std::is_convertible_v<const volatile D*, const volatile B*>;
+concept DerivedFrom = std::is_base_of_v<B, D>&& std::is_convertible_v<const volatile D*, const volatile B*>;
#endif
#if defined(__cpp_lib_ranges)
@@ -239,13 +239,12 @@ static_assert(sizeof(wchar_t) == 2, "wchar_t is expected to be two bytes in size
# define ZEN_DEBUG_SECTION ZEN_CODE_SECTION(".zcold")
#endif
-namespace zen
+namespace zen {
+class AssertException : public std::logic_error
{
- class AssertException : public std::logic_error
- {
- public:
- AssertException(const char* Msg) : std::logic_error(Msg) {}
- };
+public:
+ AssertException(const char* Msg) : std::logic_error(Msg) {}
+};
} // namespace zen
diff --git a/zencore/iobuffer.cpp b/zencore/iobuffer.cpp
index e2aaa3169..57abbfb48 100644
--- a/zencore/iobuffer.cpp
+++ b/zencore/iobuffer.cpp
@@ -513,7 +513,7 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName)
Handle = DataFile.Detach();
#else
- int Fd = open(FileName.native().c_str(), O_RDONLY);
+ int Fd = open(FileName.native().c_str(), O_RDONLY);
if (Fd < 0)
{
return {};
diff --git a/zencore/thread.cpp b/zencore/thread.cpp
index 2cc4d8a96..a123eec82 100644
--- a/zencore/thread.cpp
+++ b/zencore/thread.cpp
@@ -163,7 +163,7 @@ Event::Set()
#if ZEN_PLATFORM_WINDOWS
SetEvent(m_EventHandle);
#else
- auto* Inner = (EventInner*)m_EventHandle;
+ auto* Inner = (EventInner*)m_EventHandle;
{
std::unique_lock Lock(Inner->Mutex);
Inner->bSet = true;
@@ -316,7 +316,7 @@ NamedEvent::Close()
#if ZEN_PLATFORM_WINDOWS
CloseHandle(m_EventHandle);
#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
- int Fd = int(intptr_t(m_EventHandle) & 0xffff'ffff);
+ int Fd = int(intptr_t(m_EventHandle) & 0xffff'ffff);
if (flock(Fd, LOCK_EX | LOCK_NB) == 0)
{
@@ -580,7 +580,7 @@ ProcessHandle::Terminate(int ExitCode)
bSuccess = (WaitResult != WAIT_OBJECT_0);
#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
ZEN_UNUSED(ExitCode);
- bSuccess = (kill(m_Pid, SIGKILL) == 0);
+ bSuccess = (kill(m_Pid, SIGKILL) == 0);
#endif
if (!bSuccess)
diff --git a/zencore/trace.cpp b/zencore/trace.cpp
index 6a35571e6..788dcec07 100644
--- a/zencore/trace.cpp
+++ b/zencore/trace.cpp
@@ -12,6 +12,8 @@
void
TraceInit(const char* HostOrPath, TraceType Type)
{
+ bool EnableEvents = true;
+
switch (Type)
{
case TraceType::Network:
@@ -21,6 +23,10 @@ TraceInit(const char* HostOrPath, TraceType Type)
case TraceType::File:
trace::WriteTo(HostOrPath);
break;
+
+ case TraceType::None:
+ EnableEvents = false;
+ break;
}
trace::FInitializeDesc Desc = {
@@ -28,7 +34,10 @@ TraceInit(const char* HostOrPath, TraceType Type)
};
trace::Initialize(Desc);
- trace::ToggleChannel("cpu", true);
+ if (EnableEvents)
+ {
+ trace::ToggleChannel("cpu", true);
+ }
}
#endif // ZEN_WITH_TRACE
diff --git a/zencore/zencore.cpp b/zencore/zencore.cpp
index 19acdd1f5..8b45d273d 100644
--- a/zencore/zencore.cpp
+++ b/zencore/zencore.cpp
@@ -16,6 +16,7 @@
#include <zencore/compactbinarypackage.h>
#include <zencore/compositebuffer.h>
#include <zencore/compress.h>
+#include <zencore/crypto.h>
#include <zencore/filesystem.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
@@ -117,6 +118,7 @@ zencore_forcelinktests()
zen::uson_forcelink();
zen::usonbuilder_forcelink();
zen::usonpackage_forcelink();
+ zen::crypto_forcelink();
}
#endif
diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp
index 3c57f7ce3..8e898eb18 100644
--- a/zenhttp/httpsys.cpp
+++ b/zenhttp/httpsys.cpp
@@ -802,6 +802,8 @@ HttpSysServer::InitializeServer(int BasePort)
// port for the current user. eg:
// netsh http add urlacl url=http://*:1337/ user=<some_user>
+ ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath));
+
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
ULONG InternalResult = ERROR_SHARING_VIOLATION;
@@ -818,6 +820,8 @@ HttpSysServer::InitializeServer(int BasePort)
if (InternalResult == NO_ERROR)
{
+ ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath));
+
m_BaseUris.push_back(LocalUrlPath.c_str());
}
else
diff --git a/zenserver-test/cachepolicy-tests.cpp b/zenserver-test/cachepolicy-tests.cpp
index 686ff818c..79d78e522 100644
--- a/zenserver-test/cachepolicy-tests.cpp
+++ b/zenserver-test/cachepolicy-tests.cpp
@@ -23,8 +23,7 @@ TEST_CASE("cachepolicy")
CachePolicy::QueryLocal,
CachePolicy::StoreRemote,
CachePolicy::SkipData,
- CachePolicy::KeepAlive,
- CachePolicy::Disable};
+ CachePolicy::KeepAlive};
for (CachePolicy Atomic : SomeAtomics)
{
CHECK(ParseCachePolicy(WriteToString<128>(Atomic)) == Atomic);
@@ -73,7 +72,8 @@ TEST_CASE("cacherecordpolicy")
{
SUBCASE("policy with no values")
{
- CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal;
+ CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal | CachePolicy::PartialRecord;
+ CachePolicy ValuePolicy = Policy & CacheValuePolicy::PolicyMask;
CacheRecordPolicy RecordPolicy;
CacheRecordPolicyBuilder Builder(Policy);
RecordPolicy = Builder.Build();
@@ -81,8 +81,8 @@ TEST_CASE("cacherecordpolicy")
{
CHECK(RecordPolicy.IsUniform());
CHECK(RecordPolicy.GetRecordPolicy() == Policy);
- CHECK(RecordPolicy.GetDefaultValuePolicy() == Policy);
- CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == Policy);
+ CHECK(RecordPolicy.GetBasePolicy() == Policy);
+ CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == ValuePolicy);
CHECK(RecordPolicy.GetValuePolicies().size() == 0);
}
SUBCASE("saveload")
@@ -90,21 +90,22 @@ TEST_CASE("cacherecordpolicy")
CbWriter Writer;
RecordPolicy.Save(Writer);
CbObject Saved = Writer.Save()->AsObject();
- CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved);
+ CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get();
CHECK(Loaded.IsUniform());
CHECK(Loaded.GetRecordPolicy() == Policy);
- CHECK(Loaded.GetDefaultValuePolicy() == Policy);
- CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == Policy);
+ CHECK(Loaded.GetBasePolicy() == Policy);
+ CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == ValuePolicy);
CHECK(Loaded.GetValuePolicies().size() == 0);
}
}
SUBCASE("policy with values")
{
- CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal;
- CachePolicy PartialOverlap = CachePolicy::StoreRemote;
- CachePolicy NoOverlap = CachePolicy::QueryRemote;
- CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap;
+ CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal | CachePolicy::PartialRecord;
+ CachePolicy DefaultValuePolicy = DefaultPolicy & CacheValuePolicy::PolicyMask;
+ CachePolicy PartialOverlap = CachePolicy::StoreRemote;
+ CachePolicy NoOverlap = CachePolicy::QueryRemote;
+ CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap | CachePolicy::PartialRecord;
CacheRecordPolicy RecordPolicy;
CacheRecordPolicyBuilder Builder(DefaultPolicy);
@@ -118,10 +119,10 @@ TEST_CASE("cacherecordpolicy")
{
CHECK(!RecordPolicy.IsUniform());
CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy);
- CHECK(RecordPolicy.GetDefaultValuePolicy() == DefaultPolicy);
+ CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy);
CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap);
CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap);
- CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultPolicy);
+ CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy);
CHECK(RecordPolicy.GetValuePolicies().size() == 2);
}
SUBCASE("saveload")
@@ -129,33 +130,21 @@ TEST_CASE("cacherecordpolicy")
CbWriter Writer;
RecordPolicy.Save(Writer);
CbObject Saved = Writer.Save()->AsObject();
- CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved);
+ CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get();
CHECK(!RecordPolicy.IsUniform());
CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy);
- CHECK(RecordPolicy.GetDefaultValuePolicy() == DefaultPolicy);
+ CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy);
CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap);
CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap);
- CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultPolicy);
+ CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy);
CHECK(RecordPolicy.GetValuePolicies().size() == 2);
}
}
SUBCASE("parsing invalid text")
{
- CacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject());
- CHECK(Loaded.IsUniform());
- CHECK(Loaded.GetRecordPolicy() == CachePolicy::Default);
- CHECK(Loaded.GetDefaultValuePolicy() == CachePolicy::Default);
- CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == CachePolicy::Default);
- CHECK(Loaded.GetValuePolicies().size() == 0);
-
- CachePolicy Policy = CachePolicy::SkipData;
- Loaded = CacheRecordPolicy::Load(CbObject(), Policy);
- CHECK(Loaded.IsUniform());
- CHECK(Loaded.GetRecordPolicy() == Policy);
- CHECK(Loaded.GetDefaultValuePolicy() == Policy);
- CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == Policy);
- CHECK(Loaded.GetValuePolicies().size() == 0);
+ OptionalCacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject());
+ CHECK(Loaded.IsNull());
}
}
diff --git a/zenserver-test/zenserver-test.cpp b/zenserver-test/zenserver-test.cpp
index 85393aed2..425f43946 100644
--- a/zenserver-test/zenserver-test.cpp
+++ b/zenserver-test/zenserver-test.cpp
@@ -1827,7 +1827,8 @@ TEST_CASE("zcache.rpc")
Data[Idx] = Idx % 255;
}
- zen::CbAttachment Attachment(zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size())));
+ CompressedBuffer Value = zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()));
+ zen::CbAttachment Attachment(Value);
Writer.BeginObject();
{
@@ -1838,7 +1839,17 @@ TEST_CASE("zcache.rpc")
Writer << "Bucket"sv << CacheKey.Bucket << "Hash"sv << CacheKey.Hash;
}
Writer.EndObject();
- Writer << "Data"sv << Attachment;
+ Writer.BeginArray("Values"sv);
+ {
+ Writer.BeginObject();
+ {
+ Writer.AddObjectId("Id"sv, Oid::NewOid());
+ Writer.AddBinaryAttachment("RawHash"sv, IoHash::FromBLAKE3(Value.GetRawHash()));
+ Writer.AddInteger("RawSize"sv, Value.GetRawSize());
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
}
Writer.EndObject();
Writer.SetName("Policy"sv);
@@ -1862,7 +1873,7 @@ TEST_CASE("zcache.rpc")
for (uint32_t Key = 1; Key <= Num; ++Key)
{
- zen::IoHash KeyHash;
+ zen::IoHash KeyHash;
((uint32_t*)(KeyHash.Hash))[0] = Key;
const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, KeyHash);
CbPackage Package;
@@ -1907,27 +1918,30 @@ TEST_CASE("zcache.rpc")
bool Success;
};
- auto GetCacheRecords =
- [](std::string_view BaseUri, std::span<zen::CacheKey> Keys, const zen::CacheRecordPolicy& Policy) -> GetCacheRecordResult {
+ auto GetCacheRecords = [](std::string_view BaseUri, std::span<zen::CacheKey> Keys, zen::CachePolicy Policy) -> GetCacheRecordResult {
using namespace zen;
CbObjectWriter Request;
Request << "Method"sv
<< "GetCacheRecords"sv;
Request.BeginObject("Params"sv);
-
- Request.BeginArray("CacheKeys"sv);
- for (const CacheKey& Key : Keys)
{
- Request.BeginObject();
- Request << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash;
- Request.EndObject();
+ Request << "DefaultPolicy"sv << WriteToString<128>(Policy);
+ Request.BeginArray("Requests"sv);
+ for (const CacheKey& Key : Keys)
+ {
+ Request.BeginObject();
+ {
+ Request.BeginObject("Key"sv);
+ {
+ Request << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash;
+ }
+ Request.EndObject();
+ }
+ Request.EndObject();
+ }
+ Request.EndArray();
}
- Request.EndArray();
-
- Request.SetName("Policy"sv);
- Policy.Save(Request);
-
Request.EndObject();
BinaryWriter Body;
@@ -1978,7 +1992,7 @@ TEST_CASE("zcache.rpc")
Inst.SpawnServer(PortNumber);
Inst.WaitUntilReady();
- CacheRecordPolicy Policy;
+ CachePolicy Policy = CachePolicy::Default;
std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, "mastodon"sv, 128);
GetCacheRecordResult Result = GetCacheRecords(BaseUri, Keys, Policy);
@@ -1988,11 +2002,18 @@ TEST_CASE("zcache.rpc")
{
const CacheKey& ExpectedKey = Keys[Index++];
- CbObjectView RecordObj = RecordView.AsObjectView();
- CbObjectView KeyObj = RecordObj["Key"sv].AsObjectView();
- const CacheKey Key = CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash());
- const IoHash AttachmentHash = RecordObj["Data"sv].AsHash();
- const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash);
+ CbObjectView RecordObj = RecordView.AsObjectView();
+ CbObjectView KeyObj = RecordObj["Key"sv].AsObjectView();
+ const CacheKey Key = CacheKey::Create(KeyObj["Bucket"sv].AsString(), KeyObj["Hash"].AsHash());
+ IoHash AttachmentHash;
+ size_t NumValues = 0;
+ for (CbFieldView Value : RecordObj["Values"sv])
+ {
+ AttachmentHash = Value.AsObjectView()["RawHash"sv].AsHash();
+ ++NumValues;
+ }
+ CHECK(NumValues == 1);
+ const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash);
CHECK(Key == ExpectedKey);
CHECK(Attachment != nullptr);
@@ -2010,7 +2031,7 @@ TEST_CASE("zcache.rpc")
Inst.SpawnServer(PortNumber);
Inst.WaitUntilReady();
- CacheRecordPolicy Policy;
+ CachePolicy Policy = CachePolicy::Default;
std::vector<zen::CacheKey> ExistingKeys = PutCacheRecords(BaseUri, "mastodon"sv, 128);
std::vector<zen::CacheKey> Keys;
@@ -2035,12 +2056,18 @@ TEST_CASE("zcache.rpc")
}
else
{
- const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++];
- CbObjectView RecordObj = RecordView.AsObjectView();
- zen::CacheKey Key = LoadKey(RecordObj["Key"sv]);
- const IoHash AttachmentHash = RecordObj["Data"sv].AsHash();
- const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash);
-
+ const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++];
+ CbObjectView RecordObj = RecordView.AsObjectView();
+ zen::CacheKey Key = LoadKey(RecordObj["Key"sv]);
+ IoHash AttachmentHash;
+ size_t NumValues = 0;
+ for (CbFieldView Value : RecordObj["Values"sv])
+ {
+ AttachmentHash = Value.AsObjectView()["RawHash"sv].AsHash();
+ ++NumValues;
+ }
+ CHECK(NumValues == 1);
+ const CbAttachment* Attachment = Result.Response.FindAttachment(AttachmentHash);
CHECK(Key == ExpectedKey);
CHECK(Attachment != nullptr);
}
@@ -2061,7 +2088,7 @@ TEST_CASE("zcache.rpc")
std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "mastodon"sv, 4);
- CacheRecordPolicy Policy(CachePolicy::QueryLocal);
+ CachePolicy Policy = CachePolicy::QueryLocal;
GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, Keys, Policy);
CHECK(Result.Records.size() == Keys.size());
@@ -2086,7 +2113,7 @@ TEST_CASE("zcache.rpc")
std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "mastodon"sv, 4);
- CacheRecordPolicy Policy(CachePolicy::QueryLocal | CachePolicy::QueryRemote);
+ CachePolicy Policy = (CachePolicy::QueryLocal | CachePolicy::QueryRemote);
GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, Keys, Policy);
CHECK(Result.Records.size() == Keys.size());
diff --git a/zenserver/cache/structuredcache.cpp b/zenserver/cache/structuredcache.cpp
index d39b95a1e..49e5896d1 100644
--- a/zenserver/cache/structuredcache.cpp
+++ b/zenserver/cache/structuredcache.cpp
@@ -48,6 +48,13 @@ ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams)
return !PolicyText.empty() ? zen::ParseCachePolicy(PolicyText) : CachePolicy::Default;
}
+CacheRecordPolicy
+LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default)
+{
+ OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object);
+ return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy);
+}
+
struct AttachmentCount
{
uint32_t New = 0;
@@ -460,14 +467,14 @@ HttpStructuredCacheService::HandlePutCacheRecord(zen::HttpServerRequest& Request
Body.SetContentType(ContentType);
- if (ContentType == HttpContentType::kBinary)
+ if (ContentType == HttpContentType::kBinary || ContentType == HttpContentType::kCompressedBinary)
{
ZEN_DEBUG("PUT - '{}/{}' {} '{}'", Ref.BucketSegment, Ref.HashKey, NiceBytes(Body.Size()), ToString(ContentType));
m_CacheStore.Put(Ref.BucketSegment, Ref.HashKey, {.Value = Body});
if (EnumHasAllFlags(PolicyFromURL, CachePolicy::StoreRemote))
{
- m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kBinary, .Key = {Ref.BucketSegment, Ref.HashKey}});
+ m_UpstreamCache.EnqueueUpstream({.Type = ContentType, .Key = {Ref.BucketSegment, Ref.HashKey}});
}
Request.WriteResponse(HttpResponseCode::Created);
@@ -829,10 +836,18 @@ HttpStructuredCacheService::HandleRpcRequest(zen::HttpServerRequest& Request)
{
HandleRpcGetCacheRecords(AsyncRequest, Object);
}
+ else if (Method == "PutCacheValues"sv)
+ {
+ HandleRpcPutCacheValues(AsyncRequest, Package);
+ }
else if (Method == "GetCacheValues"sv)
{
HandleRpcGetCacheValues(AsyncRequest, Object);
}
+ else if (Method == "GetCacheChunks"sv)
+ {
+ HandleRpcGetCacheChunks(AsyncRequest, Object);
+ }
else
{
AsyncRequest.WriteResponse(HttpResponseCode::BadRequest);
@@ -872,7 +887,7 @@ HttpStructuredCacheService::HandleRpcPutCacheRecords(zen::HttpServerRequest& Req
{
return Request.WriteResponse(HttpResponseCode::BadRequest);
}
- CacheRecordPolicy Policy = CacheRecordPolicy::Load(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
PutRequestData PutRequest{std::move(Key), RecordObject, std::move(Policy)};
PutResult Result = PutCacheRecord(PutRequest, &BatchRequest);
@@ -966,7 +981,7 @@ HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPack
Count.Total);
ZenCacheValue CacheValue;
- CacheValue.Value = IoBuffer(Record.GetSize());
+ CacheValue.Value = IoBuffer(Record.GetSize());
Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize()));
CacheValue.Value.SetContentType(ZenContentType::kCbObject);
m_CacheStore.Put(Request.Key.Bucket, Request.Key.Hash, CacheValue);
@@ -981,14 +996,15 @@ HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPack
return PutResult::Success;
}
+#if BACKWARDS_COMPATABILITY_JAN2022
void
-HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView RpcRequest)
+HttpStructuredCacheService::HandleRpcGetCacheRecordsLegacy(zen::HttpServerRequest& Request, CbObjectView RpcRequest)
{
ZEN_TRACE_CPU("Z$::RpcGetCacheRecords");
CbPackage RpcResponse;
CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
- CacheRecordPolicy BatchPolicy = CacheRecordPolicy::Load(Params["Policy"sv].AsObjectView());
+ CacheRecordPolicy BatchPolicy = LoadCacheRecordPolicy(Params["Policy"sv].AsObjectView());
std::vector<CacheKey> CacheKeys;
std::vector<IoBuffer> CacheValues;
std::vector<size_t> UpstreamRequests;
@@ -1014,7 +1030,8 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req
uint32_t MissingCount = 0;
uint32_t MissingReadFromUpstreamCount = 0;
- if (EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::QueryLocal) && m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue))
+ if (EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::QueryLocal) && m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue) &&
+ CacheValue.Value.GetContentType() == ZenContentType::kCbObject)
{
CbObjectView CacheRecord(CacheValue.Value.Data());
CacheRecord.IterateAttachments(
@@ -1060,12 +1077,8 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req
});
}
- if ((!CacheValue.Value && EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::QueryRemote)) ||
- MissingReadFromUpstreamCount != 0)
- {
- UpstreamRequests.push_back(KeyIndex);
- }
- else if (CacheValue.Value && (MissingCount == 0 || EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)))
+ // Searching upstream is not implemented in this legacy support function
+ if (CacheValue.Value && (MissingCount == 0 || EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)))
{
ZEN_DEBUG("HIT - '{}/{}' {} '{}' (LOCAL) {}",
Key.Bucket,
@@ -1094,116 +1107,445 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req
++KeyIndex;
}
- if (!UpstreamRequests.empty())
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginArray("Result"sv);
+ for (const IoBuffer& Value : CacheValues)
{
- const auto OnCacheRecordGetComplete = [this, &CacheValues, &RpcResponse, &BatchPolicy](CacheRecordGetCompleteParams&& Params) {
- ZEN_ASSERT(Params.KeyIndex < CacheValues.size());
+ if (Value)
+ {
+ CbObjectView Record(Value.Data());
+ ResponseObject << Record;
+ }
+ else
+ {
+ ResponseObject.AddNull();
+ }
+ }
+ ResponseObject.EndArray();
- IoBuffer CacheValue;
- AttachmentCount Count;
+ RpcResponse.SetObject(ResponseObject.Save());
- if (Params.Record)
+ BinaryWriter MemStream;
+ RpcResponse.Save(MemStream);
+
+ Request.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kCbPackage,
+ IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+}
+#endif
+
+void
+HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& HttpRequest, CbObjectView RpcRequest)
+{
+#if BACKWARDS_COMPATABILITY_JAN2022
+ // Backwards compatability;
+ if (RpcRequest["Params"sv].AsObjectView()["CacheKeys"sv])
+ {
+ return HandleRpcGetCacheRecordsLegacy(HttpRequest, RpcRequest);
+ }
+#endif
+ ZEN_TRACE_CPU("Z$::RpcGetCacheRecords");
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv);
+
+ struct ValueRequestData
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ CompressedBuffer Payload;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool ReadFromUpstream = false;
+ };
+ struct RecordRequestData
+ {
+ CacheKeyRequest Upstream;
+ CbObjectView RecordObject;
+ IoBuffer RecordCacheValue;
+ CacheRecordPolicy DownstreamPolicy;
+ std::vector<ValueRequestData> Values;
+ bool Complete = false;
+ bool UsedUpstream = false;
+ };
+
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::vector<RecordRequestData> Requests;
+ std::vector<size_t> UpstreamIndexes;
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+
+ auto ParseValues = [](RecordRequestData& Request) {
+ CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView();
+ Request.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ Request.Values.push_back({ValueId, RawHash});
+ Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId);
+ }
+ }
+ };
+
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ RecordRequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CbFieldView BucketField = KeyObject["Bucket"sv];
+ CbFieldView HashField = KeyObject["Hash"sv];
+ CacheKey& Key = Request.Upstream.Key;
+ Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash());
+ if (HashField.HasError() || Key.Bucket.empty())
+ {
+ return HttpRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ const CacheRecordPolicy& Policy = Request.DownstreamPolicy;
+
+ ZenCacheValue CacheValue;
+ bool NeedUpstreamAttachment = false;
+ bool FoundLocalInvalid = false;
+ ZenCacheValue RecordCacheValue;
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) && m_CacheStore.Get(Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ Request.RecordCacheValue = std::move(RecordCacheValue.Value);
+ if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject)
{
- Params.Record.IterateAttachments([this, &RpcResponse, &Params, &Count, &BatchPolicy](CbFieldView HashView) {
- CachePolicy ValuePolicy = BatchPolicy.GetRecordPolicy();
- bool FoundInUpstream = false;
- if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ FoundLocalInvalid = true;
+ }
+ else
+ {
+ Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData());
+ ParseValues(Request);
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal))
{
- if (const CbAttachment* Attachment = Params.Package.FindAttachment(HashView.AsHash()))
+ // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we
+ // didn't ask for it and thus the record is complete in its absence.
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
{
- FoundInUpstream = true;
- if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
- {
- FoundInUpstream = true;
- if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
- {
- auto InsertResult = m_CidStore.AddChunk(Compressed);
- if (InsertResult.New)
- {
- Count.New++;
- }
- }
- Count.Valid++;
-
- if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
- {
- RpcResponse.AddAttachment(CbAttachment(Compressed));
- }
- }
- else
+ Value.Exists = true;
+ }
+ else
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ Request.Complete = false;
+ }
+ }
+ else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ if (m_CidStore.ContainsChunk(Value.ContentId))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
{
- ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}'",
- HashView.AsHash(),
- Params.Key.Bucket,
- Params.Key.Hash);
- Count.Invalid++;
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
}
+ Request.Complete = false;
}
}
- if (!FoundInUpstream && EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal) &&
- m_CidStore.ContainsChunk(HashView.AsHash()))
+ else
{
- // We added the attachment for this Value in the local loop before calling m_UpstreamCache
- Count.Valid++;
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId))
+ {
+ ZEN_ASSERT(Chunk.GetSize() > 0);
+ Value.Payload = CompressedBuffer::FromCompressed(SharedBuffer(Chunk));
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
}
- Count.Total++;
- });
+ }
+ }
+ }
+ if (!Request.Complete)
+ {
+ bool NeedUpstreamRecord =
+ !Request.RecordObject && !FoundLocalInvalid && EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote);
+ if (NeedUpstreamRecord || NeedUpstreamAttachment)
+ {
+ UpstreamIndexes.push_back(Requests.size() - 1);
+ }
+ }
+ }
+ if (Requests.empty())
+ {
+ return HttpRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ if (!UpstreamIndexes.empty())
+ {
+ std::vector<CacheKeyRequest*> UpstreamRequests;
+ UpstreamRequests.reserve(UpstreamIndexes.size());
+ for (size_t Index : UpstreamIndexes)
+ {
+ RecordRequestData& Request = Requests[Index];
+ UpstreamRequests.push_back(&Request.Upstream);
- if ((Count.Valid == Count.Total) || EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))
+ if (Request.Values.size())
+ {
+ // We will be returning the local object and know all the value Ids that exist in it
+ // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have.
+ CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta;
+ CacheRecordPolicyBuilder Builder(UpstreamBasePolicy);
+ for (ValueRequestData& Value : Request.Values)
{
- CacheValue = CbObject::Clone(Params.Record).GetBuffer().AsIoBuffer();
+ CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy);
+ UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None;
+ Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy);
}
+ Request.Upstream.Policy = Builder.Build();
+ }
+ else
+ {
+ // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants,
+ // and convert the CacheRecordPolicy to an upstream policy
+ Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream();
}
+ }
- if (CacheValue)
+ const auto OnCacheRecordGetComplete = [this, &ParseValues](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
{
- ZEN_DEBUG("HIT - '{}/{}' {} '{}' attachments '{}/{}/{}' (new/valid/total) (UPSTREAM)",
- Params.Key.Bucket,
- Params.Key.Hash,
- NiceBytes(CacheValue.GetSize()),
- ToString(HttpContentType::kCbPackage),
- Count.New,
- Count.Valid,
- Count.Total);
-
- CacheValue.SetContentType(ZenContentType::kCbObject);
- CacheValues[Params.KeyIndex] = CacheValue;
- if (EnumHasAllFlags(BatchPolicy.GetRecordPolicy(), CachePolicy::StoreLocal))
+ return;
+ }
+
+ RecordRequestData& Request =
+ *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream));
+ const CacheKey& Key = Request.Upstream.Key;
+ if (!Request.RecordObject)
+ {
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject);
+ Request.RecordObject = ObjectBuffer;
+ if (EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal))
{
- m_CacheStore.Put(Params.Key.Bucket, Params.Key.Hash, {.Value = CacheValue});
+ m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}});
}
-
- m_CacheStats.HitCount++;
- m_CacheStats.UpstreamHitCount++;
+ ParseValues(Request);
+ Request.UsedUpstream = true;
}
- else
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
{
- const bool IsPartial = Count.Valid != Count.Total;
- ZEN_DEBUG("MISS - '{}/{}' {}", Params.Key.Bucket, Params.Key.Hash, IsPartial ? "(partial)"sv : ""sv);
- m_CacheStats.MissCount++;
+ if (Value.Exists)
+ {
+ continue;
+ }
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Request.Complete = false;
+ continue;
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId))
+ {
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
+ {
+ Request.UsedUpstream = true;
+ Value.Exists = true;
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ m_CidStore.AddChunk(Compressed);
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Value.Payload = Compressed;
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}'", Value.ContentId, Key.Bucket, Key.Hash);
+ }
+ }
+ if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Request.Complete = false;
+ }
+ // Request.Complete does not need to be set to false for upstream SkipData attachments.
+ // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment
+ // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of
+ // any missing SkipData attachments.
+ }
}
};
- m_UpstreamCache.GetCacheRecords(CacheKeys, UpstreamRequests, BatchPolicy, std::move(OnCacheRecordGetComplete));
+ m_UpstreamCache.GetCacheRecords(UpstreamRequests, std::move(OnCacheRecordGetComplete));
}
+ CbPackage ResponsePackage;
CbObjectWriter ResponseObject;
ResponseObject.BeginArray("Result"sv);
- for (const IoBuffer& Value : CacheValues)
+ for (RecordRequestData& Request : Requests)
{
- if (Value)
+ const CacheKey& Key = Request.Upstream.Key;
+ if (Request.Complete ||
+ (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)))
{
- CbObjectView Record(Value.Data());
- ResponseObject << Record;
+ ResponseObject << Request.RecordObject;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload)
+ {
+ ResponsePackage.AddAttachment(CbAttachment(Value.Payload));
+ }
+ }
+
+ ZEN_DEBUG("HIT - '{}/{}' {}{}{}",
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(Request.RecordCacheValue.Size()),
+ Request.Complete ? ""sv : " (PARTIAL)"sv,
+ Request.UsedUpstream ? " (UPSTREAM)"sv : ""sv);
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount += Request.UsedUpstream ? 1 : 0;
}
else
{
ResponseObject.AddNull();
+
+ if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("DISABLEDQUERY - '{}/{}'", Key.Bucket, Key.Hash);
+ }
+ else
+ {
+ ZEN_DEBUG("MISS - '{}/{}' {}", Key.Bucket, Key.Hash, Request.RecordObject ? ""sv : "(PARTIAL)"sv);
+ m_CacheStats.MissCount++;
+ }
}
}
ResponseObject.EndArray();
+ ResponsePackage.SetObject(ResponseObject.Save());
+
+ BinaryWriter MemStream;
+ ResponsePackage.Save(MemStream);
+
+ HttpRequest.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kCbPackage,
+ IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+}
+void
+HttpStructuredCacheService::HandleRpcPutCacheValues(zen::HttpServerRequest& Request, const CbPackage& BatchRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcPutCacheValues");
+ CbObjectView BatchObject = BatchRequest.GetObject();
+
+ CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+
+ ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheValues"sv);
+
+ std::string_view PolicyText = Params["DefaultPolicy"].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::vector<bool> Results;
+ for (CbFieldView RequestField : Params["Requests"sv])
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyView = RequestObject["Key"sv].AsObjectView();
+ CbFieldView BucketField = KeyView["Bucket"sv];
+ CbFieldView HashField = KeyView["Hash"sv];
+ CacheKey Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash());
+ if (BucketField.HasError() || HashField.HasError() || Key.Bucket.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ PolicyText = RequestObject["Policy"sv].AsString();
+ CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+ IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment();
+ bool Succeeded = false;
+ uint64_t TransferredSize = 0;
+
+ if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ // TODO: Implement upstream puts of CacheValues with StoreLocal == false.
+ // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream.
+ Policy |= CachePolicy::StoreLocal;
+ }
+
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
+ {
+ IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ Value.SetContentType(ZenContentType::kCompressedBinary);
+ m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = Value});
+ TransferredSize = Chunk.GetCompressedSize();
+ }
+ Succeeded = true;
+ }
+ else
+ {
+ ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}' FAILED, value is not compressed", Key.Bucket, Key.Hash, RawHash);
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ }
+ else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue ExistingValue;
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, ExistingValue) && IsCompressedBinary(ExistingValue.Value.GetContentType()))
+ {
+ Succeeded = true;
+ }
+ }
+ // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put.
+ // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream.
+
+ if (Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kBinary, .Key = Key});
+ }
+ Results.push_back(Succeeded);
+ ZEN_DEBUG("PUTCACHEVALUES - '{}/{}' {}, '{}'", Key.Bucket, Key.Hash, NiceBytes(TransferredSize), Succeeded ? "Added"sv : "Invalid");
+ }
+ if (Results.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (bool Value : Results)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ CbPackage RpcResponse;
RpcResponse.SetObject(ResponseObject.Save());
BinaryWriter MemStream;
@@ -1215,216 +1557,610 @@ HttpStructuredCacheService::HandleRpcGetCacheRecords(zen::HttpServerRequest& Req
}
void
-HttpStructuredCacheService::HandleRpcGetCacheValues(zen::HttpServerRequest& Request, CbObjectView RpcRequest)
+HttpStructuredCacheService::HandleRpcGetCacheValues(zen::HttpServerRequest& HttpRequest, CbObjectView RpcRequest)
{
- ZEN_TRACE_CPU("Z$::RpcGetCacheValues");
-
- ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv);
+#if BACKWARDS_COMPATABILITY_JAN2022
+ if (RpcRequest["Params"sv].AsObjectView()["ChunkRequests"])
+ {
+ return HandleRpcGetCacheChunks(HttpRequest, RpcRequest);
+ }
+#endif
- std::vector<CacheChunkRequest> ChunkRequests;
- std::vector<size_t> UpstreamRequests;
- std::vector<IoBuffer> Chunks;
- CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ ZEN_TRACE_CPU("Z$::RpcGetCacheValues");
- for (CbFieldView RequestView : Params["ChunkRequests"sv])
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ struct RequestData
{
- CbObjectView RequestObject = RequestView.AsObjectView();
- CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
- const CacheKey Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), KeyObject["Hash"sv].AsHash());
- const IoHash ChunkId = RequestObject["ChunkId"sv].AsHash();
- const Oid ValueId = RequestObject["ValueId"sv].AsObjectId();
- const uint64_t RawOffset = RequestObject["RawOffset"sv].AsUInt64();
- const uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64();
- std::string_view PolicyText = RequestObject["Policy"sv].AsString();
- const CachePolicy ChunkPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ CacheKey Key;
+ CachePolicy Policy;
+ CompressedBuffer Result;
+ };
+ std::vector<RequestData> Requests;
- // Note we could use emplace_back here but [Apple] LLVM-12's C++ library
- // can't infer a constructor like other platforms (or can't handle an
- // initializer list like others do).
- ChunkRequests.push_back({Key, ChunkId, ValueId, RawOffset, RawSize, ChunkPolicy});
- }
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv);
- if (ChunkRequests.empty())
+ for (CbFieldView RequestField : Params["Requests"sv])
{
- return Request.WriteResponse(HttpResponseCode::BadRequest);
- }
+ RequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CbFieldView BucketField = KeyObject["Bucket"sv];
+ CbFieldView HashField = KeyObject["Hash"sv];
+ Request.Key = CacheKey::Create(BucketField.AsString(), HashField.AsHash());
+ if (BucketField.HasError() || HashField.HasError() || Request.Key.Bucket.empty())
+ {
+ return HttpRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ PolicyText = RequestObject["Policy"sv].AsString();
+ Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
- Chunks.resize(ChunkRequests.size());
+ CacheKey& Key = Request.Key;
+ CachePolicy Policy = Request.Policy;
+ CompressedBuffer& Result = Request.Result;
- // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId)
- // is missing, load the cache record and try to find the raw hash from the ValueId.
- {
- const auto GetChunkIdFromValueId = [](CbObjectView Record, const Oid& ValueId) -> IoHash {
- if (ValueId)
+ ZenCacheValue CacheValue;
+ std::string_view Source;
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ if (m_CacheStore.Get(Key.Bucket, Key.Hash, CacheValue) && IsCompressedBinary(CacheValue.Value.GetContentType()))
{
- // A valid ValueId indicates that the caller is searching for a Value in a Record
- // that was Put with ICacheStore::Put
- for (CbFieldView ValueView : Record["Values"sv])
+ Result = CompressedBuffer::FromCompressed(SharedBuffer(CacheValue.Value));
+ if (Result)
{
- CbObjectView ValueObject = ValueView.AsObjectView();
- const Oid Id = ValueObject["Id"sv].AsObjectId();
-
- if (Id == ValueId)
- {
- return ValueObject["RawHash"sv].AsHash();
- }
+ Source = "LOCAL"sv;
}
-
- // Legacy fields from previous version of CacheRecord serialization:
- if (CbObjectView ValueObject = Record["Value"sv].AsObjectView())
+ }
+ }
+ if (!Result && EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ GetUpstreamCacheResult UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kBinary);
+ if (UpstreamResult.Success && IsCompressedBinary(UpstreamResult.Value.GetContentType()))
+ {
+ Result = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value));
+ if (Result)
{
- const Oid Id = ValueObject["Id"sv].AsObjectId();
- if (Id == ValueId)
+ UpstreamResult.Value.SetContentType(ZenContentType::kCompressedBinary);
+ Source = "UPSTREAM"sv;
+ // TODO: Respect the StoreLocal flag once we have upstream existence-only checks. For now the requirement
+ // that we copy data from upstream even when SkipData and !StoreLocal are true means that it is too expensive
+ // for us to keep the data only on the upstream server.
+ // if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
{
- return ValueObject["RawHash"sv].AsHash();
+ m_CacheStore.Put(Key.Bucket, Key.Hash, ZenCacheValue{UpstreamResult.Value});
}
}
+ }
+ }
- for (CbFieldView AttachmentView : Record["Attachments"sv])
- {
- CbObjectView AttachmentObject = AttachmentView.AsObjectView();
- const Oid Id = AttachmentObject["Id"sv].AsObjectId();
+ if (Result)
+ {
+ ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}' {} ({})", Key.Bucket, Key.Hash, NiceBytes(Result.GetCompressed().GetSize()), Source);
+ m_CacheStats.HitCount++;
+ }
+ else if (!EnumHasAnyFlags(Policy, CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}'", Key.Bucket, Key.Hash);
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}'", Key.Bucket, Key.Hash);
+ m_CacheStats.MissCount++;
+ }
+ }
+ if (Requests.empty())
+ {
+ return HttpRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
- if (Id == ValueId)
- {
- return AttachmentObject["RawHash"sv].AsHash();
- }
- }
- return IoHash::Zero;
- }
- else
+ CbPackage RpcResponse;
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (const RequestData& Request : Requests)
+ {
+ ResponseObject.BeginObject();
+ {
+ const CompressedBuffer& Result = Request.Result;
+ if (Result)
{
- // An invalid ValueId indicates that the caller is requesting a Value that
- // was Put with ICacheStore::PutValue
- return Record["RawHash"sv].AsHash();
+ ResponseObject.AddHash("RawHash"sv, IoHash::FromBLAKE3(Result.GetRawHash()));
+ if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData))
+ {
+ RpcResponse.AddAttachment(CbAttachment(Result));
+ }
+ else
+ {
+ ResponseObject.AddInteger("RawSize"sv, Result.GetRawSize());
+ }
}
- };
+ }
+ ResponseObject.EndObject();
+ }
+ ResponseObject.EndArray();
- CacheKey CurrentKey = CacheKey::Empty;
- IoBuffer CurrentRecordBuffer;
+ RpcResponse.SetObject(ResponseObject.Save());
+
+ BinaryWriter MemStream;
+ RpcResponse.Save(MemStream);
+
+ HttpRequest.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kCbPackage,
+ IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+}
+
+namespace GetCacheChunks::detail {
+
+ struct ValueData
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ uint64_t RawSize;
+ };
+ struct KeyRequestData
+ {
+ CacheKeyRequest Upstream;
+ IoBuffer CacheValue;
+ std::vector<ValueData> Values;
+ CachePolicy DownstreamRecordPolicy;
+ CachePolicy DownstreamPolicy;
+ std::string_view Source;
+ bool Exists = false;
+ bool HasRequest = false;
+ bool HasRecordRequest = false;
+ bool HasValueRequest = false;
+ bool ValuesRead = false;
+ };
+ struct ChunkRequestData
+ {
+ CacheChunkRequest Upstream;
+ KeyRequestData* KeyRequest;
+ size_t KeyRequestIndex;
+ CachePolicy DownstreamPolicy;
+ CompressedBuffer Value;
+ std::string_view Source;
+ uint64_t TotalSize = 0;
+ bool Exists = false;
+ bool IsRecordRequest = false;
+ bool TotalSizeKnown = false;
+ };
+
+} // namespace GetCacheChunks::detail
+
+void
+HttpStructuredCacheService::HandleRpcGetCacheChunks(zen::HttpServerRequest& HttpRequest, CbObjectView RpcRequest)
+{
+ using namespace GetCacheChunks::detail;
+
+ ZEN_TRACE_CPU("Z$::RpcGetCacheChunks");
+
+ std::vector<KeyRequestData> KeyRequests;
+ std::vector<ChunkRequestData> Chunks;
+ BACKWARDS_COMPATABILITY_JAN2022_CODE(bool SendValueOnly = false;)
+ if (!TryGetCacheChunks_Parse(KeyRequests, Chunks BACKWARDS_COMPATABILITY_JAN2022_CODE(, SendValueOnly), RpcRequest))
+ {
+ return HttpRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ GetCacheChunks_LoadKeys(KeyRequests);
+ GetCacheChunks_LoadChunks(Chunks);
+ GetCacheChunks_SendResults(Chunks, HttpRequest BACKWARDS_COMPATABILITY_JAN2022_CODE(, SendValueOnly));
+}
+
+bool
+HttpStructuredCacheService::TryGetCacheChunks_Parse(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests,
+ std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks,
+ BACKWARDS_COMPATABILITY_JAN2022_CODE(bool& SendValueOnly, ) CbObjectView RpcRequest)
+{
+ using namespace GetCacheChunks::detail;
+
+#if BACKWARDS_COMPATABILITY_JAN2022
+ SendValueOnly = RpcRequest["MethodVersion"sv].AsInt32() < 1;
+#else
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv);
+#endif
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default;
+
+ KeyRequestData* PreviousKeyRequest = nullptr;
+ CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView();
+ Chunks.reserve(ChunkRequestsArray.Num());
+ for (CbFieldView RequestView : ChunkRequestsArray)
+ {
+ ChunkRequestData& Chunk = Chunks.emplace_back();
+ CbObjectView RequestObject = RequestView.AsObjectView();
+
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+ CbFieldView HashField = KeyObject["Hash"sv];
+ Chunk.Upstream.Key = CacheKey::Create(KeyObject["Bucket"sv].AsString(), HashField.AsHash());
+ if (Chunk.Upstream.Key.Bucket.empty() || HashField.HasError())
+ {
+ ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest.");
+ return false;
+ }
- for (CacheChunkRequest& ChunkRequest : ChunkRequests)
+ KeyRequestData* KeyRequest = nullptr;
+ if (!PreviousKeyRequest || PreviousKeyRequest->Upstream.Key < Chunk.Upstream.Key)
{
- if (ChunkRequest.ChunkId != IoHash::Zero)
+ KeyRequest = &KeyRequests.emplace_back();
+ KeyRequest->Upstream.Key = Chunk.Upstream.Key;
+ PreviousKeyRequest = KeyRequest;
+ }
+ else if (!(Chunk.Upstream.Key < PreviousKeyRequest->Upstream.Key))
+ {
+ KeyRequest = PreviousKeyRequest;
+ }
+ else
+ {
+ ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.",
+ Chunk.Upstream.Key.Bucket,
+ Chunk.Upstream.Key.Hash,
+ PreviousKeyRequest->Upstream.Key.Bucket,
+ PreviousKeyRequest->Upstream.Key.Hash);
+ return false;
+ }
+ Chunk.KeyRequestIndex = std::distance(KeyRequests.data(), KeyRequest);
+
+ Chunk.Upstream.ChunkId = RequestObject["ChunkId"sv].AsHash();
+ Chunk.Upstream.ValueId = RequestObject["ValueId"sv].AsObjectId();
+ Chunk.Upstream.RawOffset = RequestObject["RawOffset"sv].AsUInt64();
+ Chunk.Upstream.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX);
+ std::string_view PolicyText = RequestObject["Policy"sv].AsString();
+ Chunk.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+#if BACKWARDS_COMPATABILITY_JAN2022
+ if (SendValueOnly)
+ {
+ Chunk.DownstreamPolicy = Chunk.DownstreamPolicy & (~CachePolicy::SkipData);
+ }
+#endif
+ Chunk.IsRecordRequest = (bool)Chunk.Upstream.ValueId;
+
+ if (!Chunk.IsRecordRequest || Chunk.Upstream.ChunkId == IoHash::Zero)
+ {
+ KeyRequest->DownstreamPolicy =
+ KeyRequest->HasRequest ? Union(KeyRequest->DownstreamPolicy, Chunk.DownstreamPolicy) : Chunk.DownstreamPolicy;
+ KeyRequest->HasRequest = true;
+ (Chunk.IsRecordRequest ? KeyRequest->HasRecordRequest : KeyRequest->HasValueRequest) = true;
+ }
+ }
+ if (Chunks.empty())
+ {
+ return false;
+ }
+ for (ChunkRequestData& Chunk : Chunks)
+ {
+ Chunk.KeyRequest = &KeyRequests[Chunk.KeyRequestIndex];
+ }
+ return true;
+}
+
+void
+HttpStructuredCacheService::GetCacheChunks_LoadKeys(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests)
+{
+ using namespace GetCacheChunks::detail;
+
+ std::vector<CacheKeyRequest*> UpstreamRecordRequests;
+ std::vector<KeyRequestData*> UpstreamValueRequests;
+ for (KeyRequestData& KeyRequest : KeyRequests)
+ {
+ if (KeyRequest.HasRequest)
+ {
+ if (KeyRequest.HasRecordRequest)
{
- continue;
+ KeyRequest.DownstreamRecordPolicy = KeyRequest.DownstreamPolicy | CachePolicy::SkipData | CachePolicy::SkipMeta;
}
- if (ChunkRequest.Key != CurrentKey)
+ if (!KeyRequest.Exists && EnumHasAllFlags(KeyRequest.DownstreamPolicy, CachePolicy::QueryLocal))
{
- CurrentKey = ChunkRequest.Key;
-
+ // There's currently no interface for checking only whether a CacheValue exists without loading it,
+ // so we load it here even if SkipData is true and its a CacheValue request.
ZenCacheValue CacheValue;
- if (m_CacheStore.Get(CurrentKey.Bucket, CurrentKey.Hash, CacheValue))
+ if (m_CacheStore.Get(KeyRequest.Upstream.Key.Bucket, KeyRequest.Upstream.Key.Hash, CacheValue))
{
- CurrentRecordBuffer = CacheValue.Value;
+ KeyRequest.Exists = true;
+ KeyRequest.CacheValue = std::move(CacheValue.Value);
+ KeyRequest.Source = "LOCAL"sv;
}
}
-
- if (CurrentRecordBuffer)
+ if (!KeyRequest.Exists)
{
- ChunkRequest.ChunkId = GetChunkIdFromValueId(CbObjectView(CurrentRecordBuffer.GetData()), ChunkRequest.ValueId);
+ // At most one of RecordRequest or ValueRequest will succeed for the upstream request of the key a given key, but we don't
+ // know which,
+ // and if the requests (from arbitrary Unreal Class code) includes both types of request for a key, we want to ask for both
+ // kinds and pass the request that uses the one that succeeds.
+ if (KeyRequest.HasRecordRequest && EnumHasAllFlags(KeyRequest.DownstreamRecordPolicy, CachePolicy::QueryRemote))
+ {
+ KeyRequest.Upstream.Policy = CacheRecordPolicy(ConvertToUpstream(KeyRequest.DownstreamRecordPolicy));
+ UpstreamRecordRequests.push_back(&KeyRequest.Upstream);
+ }
+ if (KeyRequest.HasValueRequest && EnumHasAllFlags(KeyRequest.DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ UpstreamValueRequests.push_back(&KeyRequest);
+ }
}
}
}
- for (size_t RequestIndex = 0; const CacheChunkRequest& ChunkRequest : ChunkRequests)
+ if (!UpstreamRecordRequests.empty())
{
- const bool QueryLocal = EnumHasAllFlags(ChunkRequest.Policy, CachePolicy::QueryLocal);
- const bool QueryRemote = EnumHasAllFlags(ChunkRequest.Policy, CachePolicy::QueryRemote);
+ const auto OnCacheRecordGetComplete = [this](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
- if (QueryLocal)
+ KeyRequestData& KeyRequest =
+ *reinterpret_cast<KeyRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(KeyRequestData, Upstream));
+ const CacheKey& Key = KeyRequest.Upstream.Key;
+ KeyRequest.Exists = true;
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ KeyRequest.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ KeyRequest.CacheValue.SetContentType(ZenContentType::kCbObject);
+ KeyRequest.Source = "UPSTREAM"sv;
+
+ if (EnumHasAllFlags(KeyRequest.DownstreamPolicy, CachePolicy::StoreLocal))
+ {
+ m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = KeyRequest.CacheValue});
+ }
+ };
+ m_UpstreamCache.GetCacheRecords(UpstreamRecordRequests, std::move(OnCacheRecordGetComplete));
+ }
+
+ if (!UpstreamValueRequests.empty())
+ {
+ for (KeyRequestData* KeyRequestPtr : UpstreamValueRequests)
{
- if (IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkRequest.ChunkId))
+ KeyRequestData& KeyRequest = *KeyRequestPtr;
+ CacheKey& Key = KeyRequest.Upstream.Key;
+ GetUpstreamCacheResult UpstreamResult = m_UpstreamCache.GetCacheRecord({Key.Bucket, Key.Hash}, ZenContentType::kBinary);
+ if (UpstreamResult.Success && IsCompressedBinary(UpstreamResult.Value.GetContentType()))
{
- ZEN_ASSERT(Chunk.GetSize() > 0);
-
- ZEN_DEBUG("HIT - '{}/{}/{}' {} '{}' ({})",
- ChunkRequest.Key.Bucket,
- ChunkRequest.Key.Hash,
- ChunkRequest.ChunkId,
- NiceBytes(Chunk.Size()),
- ToString(Chunk.GetContentType()),
- "LOCAL");
-
- Chunks[RequestIndex] = Chunk;
- m_CacheStats.HitCount++;
+ CompressedBuffer Result = CompressedBuffer::FromCompressed(SharedBuffer(UpstreamResult.Value));
+ if (Result)
+ {
+ KeyRequest.CacheValue = std::move(UpstreamResult.Value);
+ KeyRequest.CacheValue.SetContentType(ZenContentType::kCompressedBinary);
+ KeyRequest.Exists = true;
+ KeyRequest.Source = "UPSTREAM"sv;
+ // TODO: Respect the StoreLocal flag once we have upstream existence-only checks. For now the requirement
+ // that we copy data from upstream even when SkipData and !StoreLocal are true means that it is too expensive
+ // for us to keep the data only on the upstream server.
+ // if (EnumHasAllFlags(KeyRequest->DownstreamValuePolicy, CachePolicy::StoreLocal))
+ {
+ m_CacheStore.Put(Key.Bucket, Key.Hash, {.Value = KeyRequest.CacheValue});
+ }
+ }
}
- else if (QueryRemote)
+ }
+ }
+}
+
+void
+HttpStructuredCacheService::GetCacheChunks_LoadChunks(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks)
+{
+ using namespace GetCacheChunks::detail;
+
+ std::vector<CacheChunkRequest*> UpstreamPayloadRequests;
+ for (ChunkRequestData& Chunk : Chunks)
+ {
+ if (Chunk.IsRecordRequest)
+ {
+ if (Chunk.Upstream.ChunkId == IoHash::Zero)
{
- UpstreamRequests.push_back(RequestIndex);
+ // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId)
+ // is missing, parse the cache record and try to find the raw hash from the ValueId.
+ KeyRequestData& KeyRequest = *Chunk.KeyRequest;
+ if (!KeyRequest.ValuesRead)
+ {
+ KeyRequest.ValuesRead = true;
+ if (KeyRequest.CacheValue && KeyRequest.CacheValue.GetContentType() == ZenContentType::kCbObject)
+ {
+ CbObjectView RecordObject = CbObjectView(KeyRequest.CacheValue.GetData());
+ CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView();
+ KeyRequest.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ KeyRequest.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()});
+ }
+ }
+ }
+ }
+
+ for (const ValueData& Value : KeyRequest.Values)
+ {
+ if (Value.ValueId == Chunk.Upstream.ValueId)
+ {
+ Chunk.Upstream.ChunkId = Value.ContentId;
+ Chunk.TotalSize = Value.RawSize;
+ Chunk.TotalSizeKnown = true;
+ break;
+ }
+ }
}
- else
+
+ // Now load the ContentId from the local ContentIdStore or from the upstream
+ if (Chunk.Upstream.ChunkId != IoHash::Zero)
{
- ZEN_DEBUG("MISS - '{}/{}/{}'", ChunkRequest.Key.Bucket, ChunkRequest.Key.Hash, ChunkRequest.ChunkId);
- m_CacheStats.MissCount++;
+ if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData) && Chunk.TotalSizeKnown)
+ {
+ if (m_CidStore.ContainsChunk(Chunk.Upstream.ChunkId))
+ {
+ Chunk.Exists = true;
+ Chunk.Source = "LOCAL"sv;
+ }
+ }
+ else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Chunk.Upstream.ChunkId))
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Payload));
+ if (Compressed)
+ {
+ if (!EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Chunk.Value = Compressed;
+ }
+ Chunk.Exists = true;
+ Chunk.TotalSize = Compressed.GetRawSize();
+ Chunk.TotalSizeKnown = true;
+ Chunk.Source = "LOCAL"sv;
+ }
+ }
+ }
+ if (!Chunk.Exists && EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ Chunk.Upstream.Policy = ConvertToUpstream(Chunk.DownstreamPolicy);
+ UpstreamPayloadRequests.push_back(&Chunk.Upstream);
+ }
}
}
else
{
- ZEN_DEBUG("SKIP - '{}/{}/{}'", ChunkRequest.Key.Bucket, ChunkRequest.Key.Hash, ChunkRequest.ChunkId);
+ if (Chunk.KeyRequest->Exists)
+ {
+ if (Chunk.KeyRequest->CacheValue && IsCompressedBinary(Chunk.KeyRequest->CacheValue.GetContentType()))
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk.KeyRequest->CacheValue));
+ if (Compressed)
+ {
+ if (!EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Chunk.Value = Compressed;
+ }
+ Chunk.Exists = true;
+ Chunk.TotalSize = Compressed.GetRawSize();
+ Chunk.TotalSizeKnown = true;
+ Chunk.Source = Chunk.KeyRequest->Source;
+ Chunk.Upstream.ChunkId = IoHash::FromBLAKE3(Compressed.GetRawHash());
+ }
+ }
+ }
}
-
- ++RequestIndex;
}
- if (!UpstreamRequests.empty())
+ if (!UpstreamPayloadRequests.empty())
{
- const auto OnCacheValueGetComplete = [this, &Chunks](CacheValueGetCompleteParams&& Params) {
- if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Params.Value)))
+ const auto OnCacheValueGetComplete = [this](CacheValueGetCompleteParams&& Params) {
+ if (Params.RawHash == Params.RawHash.Zero)
{
- m_CidStore.AddChunk(Compressed);
-
- ZEN_DEBUG("HIT - '{}/{}/{}' {} ({})",
- Params.Request.Key.Bucket,
- Params.Request.Key.Hash,
- Params.Request.ChunkId,
- NiceBytes(Params.Value.GetSize()),
- "UPSTREAM");
-
- ZEN_ASSERT(Params.RequestIndex < Chunks.size());
- Chunks[Params.RequestIndex] = std::move(Params.Value);
-
- m_CacheStats.HitCount++;
- m_CacheStats.UpstreamHitCount++;
+ return;
}
- else
+
+ ChunkRequestData& Chunk =
+ *reinterpret_cast<ChunkRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(ChunkRequestData, Upstream));
+ if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::StoreLocal) ||
+ !EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData))
{
- ZEN_DEBUG("MISS - '{}/{}/{}'", Params.Request.Key.Bucket, Params.Request.Key.Hash, Params.Request.ChunkId);
- m_CacheStats.MissCount++;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Params.Value));
+ if (!Compressed || Compressed.GetRawSize() != Params.RawSize)
+ {
+ return;
+ }
+
+ if (EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::StoreLocal))
+ {
+ m_CidStore.AddChunk(Compressed);
+ }
+ if (!EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Chunk.Value = std::move(Compressed);
+ }
}
+ Chunk.Exists = true;
+ Chunk.TotalSize = Params.RawSize;
+ Chunk.TotalSizeKnown = true;
+ Chunk.Source = "UPSTREAM"sv;
+
+ m_CacheStats.UpstreamHitCount++;
};
- m_UpstreamCache.GetCacheValues(ChunkRequests, UpstreamRequests, std::move(OnCacheValueGetComplete));
+ m_UpstreamCache.GetCacheValues(UpstreamPayloadRequests, std::move(OnCacheValueGetComplete));
}
+}
- CbPackage RpcResponse;
- CbObjectWriter ResponseObject;
+void
+HttpStructuredCacheService::GetCacheChunks_SendResults(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks,
+ zen::HttpServerRequest& HttpRequest
+ BACKWARDS_COMPATABILITY_JAN2022_CODE(, bool SendValueOnly))
+{
+ using namespace GetCacheChunks::detail;
- ResponseObject.BeginArray("Result"sv);
+ CbPackage RpcResponse;
+ CbObjectWriter Writer;
- for (size_t ChunkIndex = 0; ChunkIndex < Chunks.size(); ++ChunkIndex)
+ Writer.BeginArray("Result"sv);
+ for (ChunkRequestData& Chunk : Chunks)
{
- if (Chunks[ChunkIndex])
+#if BACKWARDS_COMPATABILITY_JAN2022
+ if (SendValueOnly)
{
- ResponseObject << ChunkRequests[ChunkIndex].ChunkId;
- RpcResponse.AddAttachment(CbAttachment(CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunks[ChunkIndex])))));
+ if (Chunk.Value)
+ {
+ Writer << Chunk.Upstream.ChunkId;
+ RpcResponse.AddAttachment(CbAttachment(Chunk.Value));
+ }
+ else
+ {
+ Writer << IoHash::Zero;
+ }
}
else
+#endif
{
- ResponseObject << IoHash::Zero;
+ Writer.BeginObject();
+ {
+ if (Chunk.Exists)
+ {
+ Writer.AddHash("RawHash"sv, Chunk.Upstream.ChunkId);
+ if (Chunk.Value && !EnumHasAllFlags(Chunk.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ RpcResponse.AddAttachment(CbAttachment(Chunk.Value));
+ }
+ else
+ {
+ Writer.AddInteger("RawSize"sv, Chunk.TotalSize);
+ }
+
+ ZEN_DEBUG("CHUNKHIT - '{}/{}/{}' {} '{}' ({})",
+ Chunk.Upstream.Key.Bucket,
+ Chunk.Upstream.Key.Hash,
+ Chunk.Upstream.ValueId,
+ NiceBytes(Chunk.TotalSize),
+ Chunk.IsRecordRequest ? "Record"sv : "Value"sv,
+ Chunk.Source);
+ m_CacheStats.HitCount++;
+ }
+ else if (!EnumHasAnyFlags(Chunk.DownstreamPolicy, CachePolicy::Query))
+ {
+ ZEN_DEBUG("CHUNKSKIP - '{}/{}/{}'", Chunk.Upstream.Key.Bucket, Chunk.Upstream.Key.Hash, Chunk.Upstream.ValueId);
+ }
+ else
+ {
+ ZEN_DEBUG("MISS - '{}/{}/{}'", Chunk.Upstream.Key.Bucket, Chunk.Upstream.Key.Hash, Chunk.Upstream.ValueId);
+ m_CacheStats.MissCount++;
+ }
+ }
+ Writer.EndObject();
}
}
- ResponseObject.EndArray();
+ Writer.EndArray();
- RpcResponse.SetObject(ResponseObject.Save());
+ RpcResponse.SetObject(Writer.Save());
BinaryWriter MemStream;
RpcResponse.Save(MemStream);
- Request.WriteResponse(HttpResponseCode::OK,
- HttpContentType::kCbPackage,
- IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+ HttpRequest.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kCbPackage,
+ IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
}
void
diff --git a/zenserver/cache/structuredcache.h b/zenserver/cache/structuredcache.h
index 88bf6cda1..14b001e48 100644
--- a/zenserver/cache/structuredcache.h
+++ b/zenserver/cache/structuredcache.h
@@ -9,6 +9,10 @@
#include "monitoring/httpstatus.h"
#include <memory>
+#include <vector>
+
+// Include the define for BACKWARDS_COMPATABILITY_JAN2022
+#include <zenutil/cache/cachepolicy.h>
namespace spdlog {
class logger;
@@ -25,6 +29,11 @@ class UpstreamCache;
class ZenCacheStore;
enum class CachePolicy : uint32_t;
+namespace GetCacheChunks::detail {
+ struct KeyRequestData;
+ struct ChunkRequestData;
+} // namespace GetCacheChunks::detail
+
/**
* Structured cache service. Imposes constraints on keys, supports blobs and
* structured values
@@ -99,12 +108,25 @@ private:
void HandlePutCacheValue(zen::HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromURL);
void HandleRpcRequest(zen::HttpServerRequest& Request);
void HandleRpcPutCacheRecords(zen::HttpServerRequest& Request, const CbPackage& BatchRequest);
- void HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView BatchRequest);
- void HandleRpcGetCacheValues(zen::HttpServerRequest& Request, CbObjectView BatchRequest);
- void HandleCacheBucketRequest(zen::HttpServerRequest& Request, std::string_view Bucket);
- virtual void HandleStatsRequest(zen::HttpServerRequest& Request) override;
- virtual void HandleStatusRequest(zen::HttpServerRequest& Request) override;
- PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package);
+#if BACKWARDS_COMPATABILITY_JAN2022
+ void HandleRpcGetCacheRecordsLegacy(zen::HttpServerRequest& Request, CbObjectView BatchRequest);
+#endif
+ void HandleRpcGetCacheRecords(zen::HttpServerRequest& Request, CbObjectView BatchRequest);
+ void HandleRpcPutCacheValues(zen::HttpServerRequest& Request, const CbPackage& BatchRequest);
+ void HandleRpcGetCacheValues(zen::HttpServerRequest& Request, CbObjectView BatchRequest);
+ void HandleRpcGetCacheChunks(zen::HttpServerRequest& Request, CbObjectView BatchRequest);
+ void HandleCacheBucketRequest(zen::HttpServerRequest& Request, std::string_view Bucket);
+ virtual void HandleStatsRequest(zen::HttpServerRequest& Request) override;
+ virtual void HandleStatusRequest(zen::HttpServerRequest& Request) override;
+ PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package);
+
+ bool TryGetCacheChunks_Parse(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests,
+ std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks,
+ BACKWARDS_COMPATABILITY_JAN2022_CODE(bool& SendValueOnly, ) CbObjectView RpcRequest);
+ void GetCacheChunks_LoadKeys(std::vector<GetCacheChunks::detail::KeyRequestData>& KeyRequests);
+ void GetCacheChunks_LoadChunks(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks);
+ void GetCacheChunks_SendResults(std::vector<GetCacheChunks::detail::ChunkRequestData>& Chunks,
+ zen::HttpServerRequest& HttpRequest BACKWARDS_COMPATABILITY_JAN2022_CODE(, bool SendValueOnly));
spdlog::logger& Log() { return m_Log; }
spdlog::logger& m_Log;
@@ -119,4 +141,12 @@ private:
CacheStats m_CacheStats;
};
+/** Recognize both kBinary and kCompressedBinary as kCompressedBinary for structured cache value keys.
+ * We need this until the content type is preserved for kCompressedBinary when passing to and from upstream servers. */
+inline bool
+IsCompressedBinary(ZenContentType Type)
+{
+ return Type == ZenContentType::kBinary || Type == ZenContentType::kCompressedBinary;
+}
+
} // namespace zen
diff --git a/zenserver/cache/structuredcachestore.h b/zenserver/cache/structuredcachestore.h
index 4f162c0ba..9ffc06b28 100644
--- a/zenserver/cache/structuredcachestore.h
+++ b/zenserver/cache/structuredcachestore.h
@@ -84,12 +84,12 @@ struct DiskLocation
}
static const uint64_t kOffsetMask = 0x0000'ffFF'ffFF'ffFFull;
- static const uint64_t kSizeMask = 0x00FF'0000'0000'0000ull; // Most significant bits of value size (lower 32 bits in LowerSize)
+ static const uint64_t kSizeMask = 0x00FF'0000'0000'0000ull; // Most significant bits of value size (lower 32 bits in LowerSize)
static const uint64_t kFlagsMask = 0xff00'0000'0000'0000ull;
- static const uint64_t kStandaloneFile = 0x8000'0000'0000'0000ull; // Stored as a separate file
- static const uint64_t kStructured = 0x4000'0000'0000'0000ull; // Serialized as compact binary
- static const uint64_t kTombStone = 0x2000'0000'0000'0000ull; // Represents a deleted key/value
- static const uint64_t kCompressed = 0x1000'0000'0000'0000ull; // Stored in compressed buffer format
+ static const uint64_t kStandaloneFile = 0x8000'0000'0000'0000ull; // Stored as a separate file
+ static const uint64_t kStructured = 0x4000'0000'0000'0000ull; // Serialized as compact binary
+ static const uint64_t kTombStone = 0x2000'0000'0000'0000ull; // Represents a deleted key/value
+ static const uint64_t kCompressed = 0x1000'0000'0000'0000ull; // Stored in compressed buffer format
static uint64_t CombineOffsetAndFlags(uint64_t Offset, uint64_t Flags) { return Offset | Flags; }
@@ -104,11 +104,11 @@ struct DiskLocation
{
ContentType = ZenContentType::kCbObject;
}
-
+
if (IsFlagSet(DiskLocation::kCompressed))
{
ContentType = ZenContentType::kCompressedBinary;
- }
+ }
return ContentType;
}
diff --git a/zenserver/upstream/jupiter.h b/zenserver/upstream/jupiter.h
index 47fdc4e17..f90ad26ed 100644
--- a/zenserver/upstream/jupiter.h
+++ b/zenserver/upstream/jupiter.h
@@ -26,10 +26,10 @@ namespace detail {
struct CloudCacheSessionState;
}
-class IoBuffer;
+class CbObjectView;
class CloudCacheClient;
+class IoBuffer;
struct IoHash;
-class CbObjectView;
/**
* Cached access token, for use with `Authorization:` header
diff --git a/zenserver/upstream/upstreamapply.h b/zenserver/upstream/upstreamapply.h
index e48b67c61..c56a22ac3 100644
--- a/zenserver/upstream/upstreamapply.h
+++ b/zenserver/upstream/upstreamapply.h
@@ -115,26 +115,20 @@ struct UpstreamApplyEndpointStats
};
/**
- * The upstream apply endpont is responsible for handling remote execution.
+ * The upstream apply endpoint is responsible for handling remote execution.
*/
class UpstreamApplyEndpoint
{
public:
virtual ~UpstreamApplyEndpoint() = default;
- virtual UpstreamEndpointHealth Initialize() = 0;
-
- virtual bool IsHealthy() const = 0;
-
- virtual UpstreamEndpointHealth CheckHealth() = 0;
-
- virtual std::string_view DisplayName() const = 0;
-
- virtual PostUpstreamApplyResult PostApply(const UpstreamApplyRecord& ApplyRecord) = 0;
-
- virtual GetUpstreamApplyUpdatesResult GetUpdates() = 0;
-
- virtual UpstreamApplyEndpointStats& Stats() = 0;
+ virtual UpstreamEndpointHealth Initialize() = 0;
+ virtual bool IsHealthy() const = 0;
+ virtual UpstreamEndpointHealth CheckHealth() = 0;
+ virtual std::string_view DisplayName() const = 0;
+ virtual PostUpstreamApplyResult PostApply(const UpstreamApplyRecord& ApplyRecord) = 0;
+ virtual GetUpstreamApplyUpdatesResult GetUpdates() = 0;
+ virtual UpstreamApplyEndpointStats& Stats() = 0;
};
/**
@@ -145,8 +139,7 @@ class UpstreamApply
public:
virtual ~UpstreamApply() = default;
- virtual bool Initialize() = 0;
-
+ virtual bool Initialize() = 0;
virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0;
struct EnqueueResult
@@ -161,11 +154,9 @@ public:
bool Success = false;
};
- virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0;
-
- virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0;
-
- virtual void GetStatus(CbObjectWriter& CbO) = 0;
+ virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0;
+ virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0;
+ virtual void GetStatus(CbObjectWriter& CbO) = 0;
};
std::unique_ptr<UpstreamApply> MakeUpstreamApply(const UpstreamApplyOptions& Options, CasStore& CasStore, CidStore& CidStore);
diff --git a/zenserver/upstream/upstreamcache.cpp b/zenserver/upstream/upstreamcache.cpp
index d83542701..7466af1d2 100644
--- a/zenserver/upstream/upstreamcache.cpp
+++ b/zenserver/upstream/upstreamcache.cpp
@@ -18,6 +18,7 @@
#include <zenstore/cas.h>
#include <zenstore/cidstore.h>
+#include "cache/structuredcache.h"
#include <auth/authmgr.h>
#include "cache/structuredcachestore.h"
#include "diag/logging.h"
@@ -240,21 +241,16 @@ namespace detail {
}
}
- virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys,
- std::span<size_t> KeyIndex,
- const CacheRecordPolicy& Policy,
- OnCacheRecordGetComplete&& OnComplete) override
+ virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) override
{
ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords");
- ZEN_UNUSED(Policy);
-
CloudCacheSession Session(m_Client);
GetUpstreamCacheResult Result;
- for (size_t Index : KeyIndex)
+ for (CacheKeyRequest* Request : Requests)
{
- const CacheKey& CacheKey = CacheKeys[Index];
+ const CacheKey& CacheKey = Request->Key;
CbPackage Package;
CbObject Record;
@@ -289,7 +285,7 @@ namespace detail {
}
}
- OnComplete({.Key = CacheKey, .KeyIndex = Index, .Record = Record, .Package = Package});
+ OnComplete({.Request = *Request, .Record = Record, .Package = Package});
}
return Result;
@@ -326,20 +322,20 @@ namespace detail {
}
}
- virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest> CacheChunkRequests,
- std::span<size_t> RequestIndex,
- OnCacheValueGetComplete&& OnComplete) override final
+ virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
{
ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues");
CloudCacheSession Session(m_Client);
GetUpstreamCacheResult Result;
- for (size_t Index : RequestIndex)
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
{
- const CacheChunkRequest& Request = CacheChunkRequests[Index];
- IoBuffer Payload;
+ CacheChunkRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ CompressedBuffer Compressed;
if (!Result.Error)
{
const CloudCacheResult BlobResult = Session.GetCompressedBlob(Request.ChunkId);
@@ -348,9 +344,23 @@ namespace detail {
AppendResult(BlobResult, Result);
m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload && IsCompressedBinary(Payload.GetContentType()))
+ {
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Payload));
+ }
}
- OnComplete({.Request = Request, .RequestIndex = Index, .Value = Payload});
+ if (Compressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = IoHash::FromBLAKE3(Compressed.GetRawHash()),
+ .RawSize = Compressed.GetRawSize(),
+ .Value = Payload});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
}
return Result;
@@ -646,15 +656,10 @@ namespace detail {
}
}
- virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys,
- std::span<size_t> KeyIndex,
- const CacheRecordPolicy& Policy,
- OnCacheRecordGetComplete&& OnComplete) override
+ virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) override
{
ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords");
-
- std::vector<size_t> IndexMap;
- IndexMap.reserve(KeyIndex.size());
+ ZEN_ASSERT(Requests.size() > 0);
CbObjectWriter BatchRequest;
BatchRequest << "Method"sv
@@ -662,21 +667,30 @@ namespace detail {
BatchRequest.BeginObject("Params"sv);
{
- BatchRequest.BeginArray("CacheKeys"sv);
- for (size_t Index : KeyIndex)
- {
- const CacheKey& Key = CacheKeys[Index];
- IndexMap.push_back(Index);
+ CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy();
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy);
+ BatchRequest.BeginArray("Requests"sv);
+ for (CacheKeyRequest* Request : Requests)
+ {
BatchRequest.BeginObject();
- BatchRequest << "Bucket"sv << Key.Bucket;
- BatchRequest << "Hash"sv << Key.Hash;
+ {
+ const CacheKey& Key = Request->Key;
+ BatchRequest.BeginObject("Key"sv);
+ {
+ BatchRequest << "Bucket"sv << Key.Bucket;
+ BatchRequest << "Hash"sv << Key.Hash;
+ }
+ BatchRequest.EndObject();
+ if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy)
+ {
+ BatchRequest.SetName("Policy"sv);
+ Request->Policy.Save(BatchRequest);
+ }
+ }
BatchRequest.EndObject();
}
BatchRequest.EndArray();
-
- BatchRequest.SetName("Policy"sv);
- Policy.Save(BatchRequest);
}
BatchRequest.EndObject();
@@ -694,19 +708,27 @@ namespace detail {
{
if (BatchResponse.TryLoad(Result.Response))
{
- for (size_t LocalIndex = 0; CbFieldView Record : BatchResponse.GetObject()["Result"sv])
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (Results.Num() != Requests.size())
{
- const size_t Index = IndexMap[LocalIndex++];
- OnComplete({.Key = CacheKeys[Index], .KeyIndex = Index, .Record = Record.AsObjectView(), .Package = BatchResponse});
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Requests from Upstream.");
}
+ else
+ {
+ for (size_t Index = 0; CbFieldView Record : Results)
+ {
+ CacheKeyRequest* Request = Requests[Index++];
+ OnComplete({.Request = *Request, .Record = Record.AsObjectView(), .Package = BatchResponse});
+ }
- return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
}
}
- for (size_t Index : KeyIndex)
+ for (CacheKeyRequest* Request : Requests)
{
- OnComplete({.Key = CacheKeys[Index], .KeyIndex = Index, .Record = CbObjectView(), .Package = CbPackage()});
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
}
return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
@@ -743,27 +765,28 @@ namespace detail {
}
}
- virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest> CacheChunkRequests,
- std::span<size_t> RequestIndex,
- OnCacheValueGetComplete&& OnComplete) override final
+ virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
{
ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues");
-
- std::vector<size_t> IndexMap;
- IndexMap.reserve(RequestIndex.size());
+ ZEN_ASSERT(!CacheChunkRequests.empty());
CbObjectWriter BatchRequest;
BatchRequest << "Method"sv
- << "GetCacheValues";
+ << "GetCacheChunks";
+#if BACKWARDS_COMPATABILITY_JAN2022
+ BatchRequest.AddInteger("MethodVersion"sv, 1);
+#endif
BatchRequest.BeginObject("Params"sv);
{
+ CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
BatchRequest.BeginArray("ChunkRequests"sv);
{
- for (size_t Index : RequestIndex)
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
{
- const CacheChunkRequest& Request = CacheChunkRequests[Index];
- IndexMap.push_back(Index);
+ const CacheChunkRequest& Request = *RequestPtr;
BatchRequest.BeginObject();
{
@@ -771,11 +794,26 @@ namespace detail {
BatchRequest << "Bucket"sv << Request.Key.Bucket;
BatchRequest << "Hash"sv << Request.Key.Hash;
BatchRequest.EndObject();
- BatchRequest.AddObjectId("ValueId"sv, Request.ValueId);
- BatchRequest << "ChunkId"sv << Request.ChunkId;
- BatchRequest << "RawOffset"sv << Request.RawOffset;
- BatchRequest << "RawSize"sv << Request.RawSize;
- BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ if (Request.ValueId)
+ {
+ BatchRequest.AddObjectId("ValueId"sv, Request.ValueId);
+ }
+ if (Request.ChunkId != Request.ChunkId.Zero)
+ {
+ BatchRequest << "ChunkId"sv << Request.ChunkId;
+ }
+ if (Request.RawOffset != 0)
+ {
+ BatchRequest << "RawOffset"sv << Request.RawOffset;
+ }
+ if (Request.RawSize != UINT64_MAX)
+ {
+ BatchRequest << "RawSize"sv << Request.RawSize;
+ }
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
}
BatchRequest.EndObject();
}
@@ -798,29 +836,56 @@ namespace detail {
{
if (BatchResponse.TryLoad(Result.Response))
{
- for (size_t LocalIndex = 0; CbFieldView AttachmentHash : BatchResponse.GetObject()["Result"sv])
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheChunkRequests.size() != Results.Num())
{
- const size_t Index = IndexMap[LocalIndex++];
- IoBuffer Payload;
-
- if (const CbAttachment* Attachment = BatchResponse.FindAttachment(AttachmentHash.AsHash()))
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Requests from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
{
- if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
{
- Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.GetRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
}
+ OnComplete({.Request = Request, .RawHash = RawHash, .RawSize = RawSize, .Value = std::move(Payload)});
}
- OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Value = std::move(Payload)});
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
}
-
- return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
}
}
- for (size_t Index : RequestIndex)
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
{
- OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Value = IoBuffer()});
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
}
return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
@@ -1071,21 +1136,16 @@ public:
return {};
}
- virtual void GetCacheRecords(std::span<CacheKey> CacheKeys,
- std::span<size_t> KeyIndex,
- const CacheRecordPolicy& DownstreamPolicy,
- OnCacheRecordGetComplete&& OnComplete) override final
+ virtual void GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) override final
{
ZEN_TRACE_CPU("Upstream::GetCacheRecords");
std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
- std::vector<size_t> RemainingKeys(KeyIndex.begin(), KeyIndex.end());
+ std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end());
if (m_Options.ReadUpstream)
{
- CacheRecordPolicy UpstreamPolicy = DownstreamPolicy.ConvertToUpstream();
-
for (auto& Endpoint : m_Endpoints)
{
if (RemainingKeys.empty())
@@ -1098,25 +1158,24 @@ public:
continue;
}
- UpstreamEndpointStats& Stats = Endpoint->Stats();
- std::vector<size_t> Missing;
- GetUpstreamCacheResult Result;
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheKeyRequest*> Missing;
+ GetUpstreamCacheResult Result;
{
metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
- Result =
- Endpoint->GetCacheRecords(CacheKeys, RemainingKeys, UpstreamPolicy, [&](CacheRecordGetCompleteParams&& Params) {
- if (Params.Record)
- {
- OnComplete(std::forward<CacheRecordGetCompleteParams>(Params));
+ Result = Endpoint->GetCacheRecords(RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) {
+ if (Params.Record)
+ {
+ OnComplete(std::forward<CacheRecordGetCompleteParams>(Params));
- Stats.CacheHitCount.Increment(1);
- }
- else
- {
- Missing.push_back(Params.KeyIndex);
- }
- });
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
}
Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
@@ -1136,21 +1195,19 @@ public:
}
}
- for (size_t Index : RemainingKeys)
+ for (CacheKeyRequest* Request : RemainingKeys)
{
- OnComplete({.Key = CacheKeys[Index], .KeyIndex = Index, .Record = CbObjectView(), .Package = CbPackage()});
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
}
}
- virtual void GetCacheValues(std::span<CacheChunkRequest> CacheChunkRequests,
- std::span<size_t> RequestIndex,
- OnCacheValueGetComplete&& OnComplete) override final
+ virtual void GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, OnCacheValueGetComplete&& OnComplete) override final
{
ZEN_TRACE_CPU("Upstream::GetCacheValues");
std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
- std::vector<size_t> RemainingKeys(RequestIndex.begin(), RequestIndex.end());
+ std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end());
if (m_Options.ReadUpstream)
{
@@ -1166,14 +1223,14 @@ public:
continue;
}
- UpstreamEndpointStats& Stats = Endpoint->Stats();
- std::vector<size_t> Missing;
- GetUpstreamCacheResult Result;
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheChunkRequest*> Missing;
+ GetUpstreamCacheResult Result;
{
metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
- Result = Endpoint->GetCacheValues(CacheChunkRequests, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) {
- if (Params.Value)
+ Result = Endpoint->GetCacheValues(RemainingKeys, [&](CacheValueGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
{
OnComplete(std::forward<CacheValueGetCompleteParams>(Params));
@@ -1181,7 +1238,7 @@ public:
}
else
{
- Missing.push_back(Params.RequestIndex);
+ Missing.push_back(&Params.Request);
}
});
}
@@ -1203,9 +1260,9 @@ public:
}
}
- for (size_t Index : RemainingKeys)
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
{
- OnComplete({.Request = CacheChunkRequests[Index], .RequestIndex = Index, .Value = IoBuffer()});
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
}
}
diff --git a/zenserver/upstream/upstreamcache.h b/zenserver/upstream/upstreamcache.h
index 48601c879..4ccc56f79 100644
--- a/zenserver/upstream/upstreamcache.h
+++ b/zenserver/upstream/upstreamcache.h
@@ -2,6 +2,8 @@
#pragma once
+#include <zencore/compactbinary.h>
+#include <zencore/compress.h>
#include <zencore/iobuffer.h>
#include <zencore/iohash.h>
#include <zencore/stats.h>
@@ -12,9 +14,11 @@
#include <chrono>
#include <functional>
#include <memory>
+#include <vector>
namespace zen {
+class CbObjectView;
class AuthMgr;
class CbObjectView;
class CbPackage;
@@ -67,8 +71,7 @@ struct PutUpstreamCacheResult
struct CacheRecordGetCompleteParams
{
- const CacheKey& Key;
- size_t KeyIndex = ~size_t(0);
+ CacheKeyRequest& Request;
const CbObjectView& Record;
const CbPackage& Package;
};
@@ -77,9 +80,10 @@ using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams
struct CacheValueGetCompleteParams
{
- const CacheChunkRequest& Request;
- size_t RequestIndex{~size_t(0)};
- IoBuffer Value;
+ CacheChunkRequest& Request;
+ IoHash RawHash;
+ uint64_t RawSize;
+ IoBuffer Value;
};
using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>;
@@ -145,33 +149,26 @@ struct UpstreamEndpointInfo
};
/**
- * The upstream endpont is responsible for handling upload/downloading of cache records.
+ * The upstream endpoint is responsible for handling upload/downloading of cache records.
*/
class UpstreamEndpoint
{
public:
virtual ~UpstreamEndpoint() = default;
- virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0;
-
virtual UpstreamEndpointStatus Initialize() = 0;
- virtual UpstreamEndpointState GetState() = 0;
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0;
+ virtual UpstreamEndpointState GetState() = 0;
virtual UpstreamEndpointStatus GetStatus() = 0;
- virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0;
-
- virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKey> CacheKeys,
- std::span<size_t> KeyIndex,
- const CacheRecordPolicy& Policy,
- OnCacheRecordGetComplete&& OnComplete) = 0;
+ virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0;
+ virtual GetUpstreamCacheResult GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) = 0;
virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& PayloadId) = 0;
-
- virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest> CacheChunkRequests,
- std::span<size_t> RequestIndex,
- OnCacheValueGetComplete&& OnComplete) = 0;
+ virtual GetUpstreamCacheResult GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheValueGetComplete&& OnComplete) = 0;
virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
IoBuffer RecordValue,
@@ -190,22 +187,14 @@ public:
virtual void Initialize() = 0;
- virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0;
-
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0;
virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0;
- virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0;
-
- virtual void GetCacheRecords(std::span<CacheKey> CacheKeys,
- std::span<size_t> KeyIndex,
- const CacheRecordPolicy& RecordPolicy,
- OnCacheRecordGetComplete&& OnComplete) = 0;
+ virtual GetUpstreamCacheResult GetCacheRecord(CacheKey CacheKey, ZenContentType Type) = 0;
+ virtual void GetCacheRecords(std::span<CacheKeyRequest*> Requests, OnCacheRecordGetComplete&& OnComplete) = 0;
- virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& ValueContentId) = 0;
-
- virtual void GetCacheValues(std::span<CacheChunkRequest> CacheChunkRequests,
- std::span<size_t> RequestIndex,
- OnCacheValueGetComplete&& OnComplete) = 0;
+ virtual GetUpstreamCacheResult GetCacheValue(const CacheKey& CacheKey, const IoHash& ValueContentId) = 0;
+ virtual void GetCacheValues(std::span<CacheChunkRequest*> CacheChunkRequests, OnCacheValueGetComplete&& OnComplete) = 0;
virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0;
@@ -214,6 +203,9 @@ public:
std::unique_ptr<UpstreamCache> MakeUpstreamCache(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore);
+std::unique_ptr<UpstreamEndpoint> MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options);
+
+std::unique_ptr<UpstreamEndpoint> MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options);
std::unique_ptr<UpstreamEndpoint> MakeJupiterUpstreamEndpoint(const CloudCacheClientOptions& Options,
const UpstreamAuthConfig& AuthConfig,
AuthMgr& Mgr);
diff --git a/zenserver/upstream/zen.cpp b/zenserver/upstream/zen.cpp
index a2666ac02..0570dd316 100644
--- a/zenserver/upstream/zen.cpp
+++ b/zenserver/upstream/zen.cpp
@@ -66,17 +66,17 @@ namespace detail {
// Note that currently this just implements an UDP echo service for testing purposes
-Mesh::Mesh(asio::io_context& IoContext) : m_Log(logging::Get("mesh")), m_IoContext(IoContext), m_SessionId(GetSessionId())
+MeshTracker::MeshTracker(asio::io_context& IoContext) : m_Log(logging::Get("mesh")), m_IoContext(IoContext), m_SessionId(GetSessionId())
{
}
-Mesh::~Mesh()
+MeshTracker::~MeshTracker()
{
Stop();
}
void
-Mesh::Start(uint16_t Port)
+MeshTracker::Start(uint16_t Port)
{
ZEN_ASSERT(Port);
ZEN_ASSERT(m_Port == 0);
@@ -87,7 +87,7 @@ Mesh::Start(uint16_t Port)
};
void
-Mesh::Stop()
+MeshTracker::Stop()
{
using namespace std::literals;
@@ -118,7 +118,7 @@ Mesh::Stop()
}
void
-Mesh::EnqueueTick()
+MeshTracker::EnqueueTick()
{
m_Timer.expires_after(std::chrono::seconds(10));
@@ -138,7 +138,7 @@ Mesh::EnqueueTick()
}
void
-Mesh::OnTick()
+MeshTracker::OnTick()
{
using namespace std::literals;
@@ -156,7 +156,7 @@ Mesh::OnTick()
}
void
-Mesh::BroadcastPacket(CbObjectWriter& Obj)
+MeshTracker::BroadcastPacket(CbObjectWriter& Obj)
{
std::error_code ErrorCode;
@@ -201,7 +201,7 @@ Mesh::BroadcastPacket(CbObjectWriter& Obj)
}
void
-Mesh::Run()
+MeshTracker::Run()
{
m_State = kRunning;
@@ -212,7 +212,7 @@ Mesh::Run()
}
void
-Mesh::IssueReceive()
+MeshTracker::IssueReceive()
{
using namespace std::literals;
diff --git a/zenserver/upstream/zen.h b/zenserver/upstream/zen.h
index 8cc4c121d..bc8fd3c56 100644
--- a/zenserver/upstream/zen.h
+++ b/zenserver/upstream/zen.h
@@ -34,12 +34,16 @@ class ZenStructuredCacheClient;
/** Zen mesh tracker
*
* Discovers and tracks local peers
+ *
+ * NOTE: This is currently experimental, and not very useful yet
+ *
*/
-class Mesh
+
+class MeshTracker
{
public:
- Mesh(asio::io_context& IoContext);
- ~Mesh();
+ MeshTracker(asio::io_context& IoContext);
+ ~MeshTracker();
void Start(uint16_t Port);
void Stop();
@@ -86,6 +90,8 @@ private:
tsl::robin_map<Oid, PeerInfo, Oid::Hasher> m_KnownPeers;
};
+//////////////////////////////////////////////////////////////////////////
+
namespace detail {
struct ZenCacheSessionState;
}
diff --git a/zenserver/zenserver.cpp b/zenserver/zenserver.cpp
index c8ee024fb..6b178ee0c 100644
--- a/zenserver/zenserver.cpp
+++ b/zenserver/zenserver.cpp
@@ -555,7 +555,7 @@ private:
std::unique_ptr<zen::HttpStructuredCacheService> m_StructuredCacheService;
zen::HttpAdminService m_AdminService{m_GcScheduler};
zen::HttpHealthService m_HealthService;
- zen::Mesh m_ZenMesh{m_IoContext};
+ zen::MeshTracker m_ZenMesh{m_IoContext};
#if ZEN_WITH_COMPUTE_SERVICES
std::unique_ptr<zen::HttpLaunchService> m_HttpLaunchService;
std::unique_ptr<zen::HttpFunctionService> m_HttpFunctionService;
@@ -1018,6 +1018,10 @@ main(int argc, char* argv[])
{
TraceInit(ServerOptions.TraceFile.c_str(), TraceType::File);
}
+ else
+ {
+ TraceInit(nullptr, TraceType::None);
+ }
#endif // ZEN_WITH_TRACE
#if ZEN_PLATFORM_WINDOWS
diff --git a/zenutil/cache/cachepolicy.cpp b/zenutil/cache/cachepolicy.cpp
index 3bf7a0c67..f17c54aa2 100644
--- a/zenutil/cache/cachepolicy.cpp
+++ b/zenutil/cache/cachepolicy.cpp
@@ -10,149 +10,156 @@
#include <algorithm>
#include <unordered_map>
+namespace zen::Private {
+class CacheRecordPolicyShared;
+}
+
namespace zen {
using namespace std::literals;
-namespace detail::CachePolicyImpl {
- constexpr char DelimiterChar = ',';
- constexpr std::string_view None = "None"sv;
- constexpr std::string_view QueryLocal = "QueryLocal"sv;
- constexpr std::string_view QueryRemote = "QueryRemote"sv;
- constexpr std::string_view Query = "Query"sv;
- constexpr std::string_view StoreLocal = "StoreLocal"sv;
- constexpr std::string_view StoreRemote = "StoreRemote"sv;
- constexpr std::string_view Store = "Store"sv;
- constexpr std::string_view SkipMeta = "SkipMeta"sv;
- constexpr std::string_view SkipData = "SkipData"sv;
- constexpr std::string_view PartialRecord = "PartialRecord"sv;
- constexpr std::string_view KeepAlive = "KeepAlive"sv;
- constexpr std::string_view Local = "Local"sv;
- constexpr std::string_view Remote = "Remote"sv;
- constexpr std::string_view Default = "Default"sv;
- constexpr std::string_view Disable = "Disable"sv;
+namespace DerivedData::Private {
- using TextToPolicyMap = std::unordered_map<std::string_view, CachePolicy>;
- const TextToPolicyMap TextToPolicy = {{None, CachePolicy::None},
- {QueryLocal, CachePolicy::QueryLocal},
- {QueryRemote, CachePolicy::QueryRemote},
- {Query, CachePolicy::Query},
- {StoreLocal, CachePolicy::StoreLocal},
- {StoreRemote, CachePolicy::StoreRemote},
- {Store, CachePolicy::Store},
- {SkipMeta, CachePolicy::SkipMeta},
- {SkipData, CachePolicy::SkipData},
- {PartialRecord, CachePolicy::PartialRecord},
- {KeepAlive, CachePolicy::KeepAlive},
- {Local, CachePolicy::Local},
- {Remote, CachePolicy::Remote},
- {Default, CachePolicy::Default},
- {Disable, CachePolicy::Disable}};
+ constexpr char CachePolicyDelimiter = ',';
- using PolicyTextPair = std::pair<CachePolicy, std::string_view>;
- const PolicyTextPair FlagsToString[]{
- // Order of these Flags is important: we want the aliases before the atomic values,
- // and the bigger aliases first, to reduce the number of tokens we add
- {CachePolicy::Default, Default},
- {CachePolicy::Remote, Remote},
- {CachePolicy::Local, Local},
- {CachePolicy::Store, Store},
- {CachePolicy::Query, Query},
-
- // Order of Atomics doesn't matter, so arbitrarily we list them in enum order
- {CachePolicy::QueryLocal, QueryLocal},
- {CachePolicy::QueryRemote, QueryRemote},
- {CachePolicy::StoreLocal, StoreLocal},
- {CachePolicy::StoreRemote, StoreRemote},
- {CachePolicy::SkipMeta, SkipMeta},
- {CachePolicy::SkipData, SkipData},
- {CachePolicy::PartialRecord, PartialRecord},
- {CachePolicy::KeepAlive, KeepAlive},
+ struct CachePolicyToTextData
+ {
+ CachePolicy Policy;
+ std::string_view Text;
+ };
- // None must come at the end of the array, to write out only if no others exist
- {CachePolicy::None, None},
+ const CachePolicyToTextData CachePolicyToText[]{
+ // Flags with multiple bits are ordered by bit count to minimize token count in the text format.
+ {CachePolicy::Default, "Default"sv},
+ {CachePolicy::Remote, "Remote"sv},
+ {CachePolicy::Local, "Local"sv},
+ {CachePolicy::Store, "Store"sv},
+ {CachePolicy::Query, "Query"sv},
+ // Flags with only one bit can be in any order. Match the order in CachePolicy.
+ {CachePolicy::QueryLocal, "QueryLocal"sv},
+ {CachePolicy::QueryRemote, "QueryRemote"sv},
+ {CachePolicy::StoreLocal, "StoreLocal"sv},
+ {CachePolicy::StoreRemote, "StoreRemote"sv},
+ {CachePolicy::SkipMeta, "SkipMeta"sv},
+ {CachePolicy::SkipData, "SkipData"sv},
+ {CachePolicy::PartialRecord, "PartialRecord"sv},
+ {CachePolicy::KeepAlive, "KeepAlive"sv},
+ // None must be last because it matches every policy.
+ {CachePolicy::None, "None"sv},
};
- constexpr CachePolicy KnownFlags =
- CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::KeepAlive | CachePolicy::PartialRecord;
-} // namespace detail::CachePolicyImpl
-StringBuilderBase&
-AppendToBuilderImpl(StringBuilderBase& Builder, CachePolicy Policy)
-{
- // Remove any bits we don't recognize; write None if there are not any bits we recognize
- Policy = Policy & detail::CachePolicyImpl::KnownFlags;
- for (const detail::CachePolicyImpl::PolicyTextPair& Pair : detail::CachePolicyImpl::FlagsToString)
+ constexpr CachePolicy CachePolicyKnownFlags =
+ CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive;
+
+ StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy)
{
- if (EnumHasAllFlags(Policy, Pair.first))
+ // Mask out unknown flags. None will be written if no flags are known.
+ Policy &= CachePolicyKnownFlags;
+ for (const CachePolicyToTextData& Pair : CachePolicyToText)
{
- EnumRemoveFlags(Policy, Pair.first);
- Builder << Pair.second << detail::CachePolicyImpl::DelimiterChar;
- if (Policy == CachePolicy::None)
+ if (EnumHasAllFlags(Policy, Pair.Policy))
{
- break;
+ EnumRemoveFlags(Policy, Pair.Policy);
+ Builder << Pair.Text << CachePolicyDelimiter;
+ if (Policy == CachePolicy::None)
+ {
+ break;
+ }
}
}
+ Builder.RemoveSuffix(1);
+ return Builder;
}
- Builder.RemoveSuffix(1); // Text will have been added by CachePolicy::None if not by anything else
- return Builder;
-}
+
+ CachePolicy ParseCachePolicy(const std::string_view Text)
+ {
+ ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string
+ CachePolicy Policy = CachePolicy::None;
+ ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable {
+ const int32_t EndIndex = Index;
+ for (; Index < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index)
+ {
+ if (CachePolicyToText[Index].Text == Token)
+ {
+ Policy |= CachePolicyToText[Index].Policy;
+ ++Index;
+ return true;
+ }
+ }
+ for (Index = 0; Index < EndIndex; ++Index)
+ {
+ if (CachePolicyToText[Index].Text == Token)
+ {
+ Policy |= CachePolicyToText[Index].Policy;
+ ++Index;
+ return true;
+ }
+ }
+ return true;
+ });
+ return Policy;
+ }
+
+} // namespace DerivedData::Private
+
StringBuilderBase&
operator<<(StringBuilderBase& Builder, CachePolicy Policy)
{
- return AppendToBuilderImpl(Builder, Policy);
+ return DerivedData::Private::CachePolicyToString(Builder, Policy);
}
CachePolicy
ParseCachePolicy(std::string_view Text)
{
- ZEN_ASSERT(!Text.empty()); // Empty string is not valid input to ParseCachePolicy
-
- CachePolicy Result = CachePolicy::None;
- ForEachStrTok(Text, detail::CachePolicyImpl::DelimiterChar, [&Result](const std::string_view& Token) {
- auto it = detail::CachePolicyImpl::TextToPolicy.find(Token);
- if (it != detail::CachePolicyImpl::TextToPolicy.end())
- {
- Result |= it->second;
- }
- return true;
- });
-
- return Result;
+ return DerivedData::Private::ParseCachePolicy(Text);
}
-namespace Private {
+CachePolicy
+ConvertToUpstream(CachePolicy Policy)
+{
+ // Set Local flags equal to downstream's Remote flags.
+ // Delete Skip flags if StoreLocal is true, otherwise use the downstream value.
+ // Use the downstream value for all other flags.
+ return (EnumHasAllFlags(Policy, CachePolicy::QueryRemote) ? CachePolicy::QueryLocal : CachePolicy::None) |
+ (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) ? CachePolicy::StoreLocal : CachePolicy::None) |
+ (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal) ? (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta))
+ : CachePolicy::None) |
+ (Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta));
+}
- class CacheRecordPolicyShared final : public ICacheRecordPolicyShared
+class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared
+{
+public:
+ inline void AddValuePolicy(const CacheValuePolicy& Value) final
{
- public:
- inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; }
-
- inline void AddValuePolicy(const CacheValuePolicy& Policy) final { Values.push_back(Policy); }
-
- inline void Build() final
- {
- std::sort(Values.begin(), Values.end(), [](const CacheValuePolicy& A, const CacheValuePolicy& B) { return A.Id < B.Id; });
- }
+ ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null.
+ const auto Insert =
+ std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) {
+ return Existing.Id < New.Id;
+ });
+ ZEN_ASSERT(
+ !(Insert < Values.end() &&
+ Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ")
+ Values.insert(Insert, Value);
+ }
- private:
- std::vector<CacheValuePolicy> Values;
- };
+ inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; }
-} // namespace Private
+private:
+ std::vector<CacheValuePolicy> Values;
+};
CachePolicy
CacheRecordPolicy::GetValuePolicy(const Oid& Id) const
{
if (Shared)
{
- if (std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); !Values.empty())
+ const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
+ const auto Iter =
+ std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; });
+ if (Iter != Values.end() && Iter->Id == Id)
{
- auto Iter =
- std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; });
- if (Iter != Values.end() && Iter->Id == Id)
- {
- return Iter->Policy;
- }
+ return Iter->Policy;
}
}
return DefaultValuePolicy;
@@ -162,46 +169,58 @@ void
CacheRecordPolicy::Save(CbWriter& Writer) const
{
Writer.BeginObject();
+ // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately.
+ Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy()));
+ if (!IsUniform())
{
- // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately.
- Writer << "DefaultValuePolicy"sv << WriteToString<128>(GetDefaultValuePolicy());
- if (!IsUniform())
+ Writer.BeginArray("ValuePolicies"sv);
+ for (const CacheValuePolicy& Value : GetValuePolicies())
{
- // FCacheRecordPolicyBuilder guarantees IsUniform -> non-empty GetValuePolicies. Small size penalty here if not.
- Writer.BeginArray("ValuePolicies"sv);
- {
- for (const CacheValuePolicy& ValuePolicy : GetValuePolicies())
- {
- // FCacheRecordPolicyBuilder is responsible for ensuring that each ValuePolicy != DefaultValuePolicy
- // If it lets any duplicates through we will incur a small serialization size penalty here
- Writer.BeginObject();
- Writer << "Id"sv << ValuePolicy.Id;
- Writer << "Policy"sv << WriteToString<128>(ValuePolicy.Policy);
- Writer.EndObject();
- }
- }
- Writer.EndArray();
+ Writer.BeginObject();
+ Writer.AddObjectId("Id"sv, Value.Id);
+ Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy));
+ Writer.EndObject();
}
+ Writer.EndArray();
}
Writer.EndObject();
}
-CacheRecordPolicy
-CacheRecordPolicy::Load(CbObjectView Object, CachePolicy DefaultPolicy)
+OptionalCacheRecordPolicy
+CacheRecordPolicy::Load(const CbObjectView Object)
{
- std::string_view PolicyText = Object["DefaultValuePolicy"sv].AsString();
- CachePolicy DefaultValuePolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+ std::string_view BasePolicyText = Object["BasePolicy"sv].AsString();
+#if BACKWARDS_COMPATABILITY_JAN2022
+ if (BasePolicyText.empty())
+ {
+ BasePolicyText = Object["DefaultValuePolicy"sv].AsString();
+ }
+#endif
+ if (BasePolicyText.empty())
+ {
+ return {};
+ }
- CacheRecordPolicyBuilder Builder(DefaultValuePolicy);
- for (CbFieldView ValueObjectField : Object["ValuePolicies"sv])
+ CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText));
+ for (CbFieldView ValueField : Object["ValuePolicies"sv])
{
- CbObjectView ValueObject = ValueObjectField.AsObjectView();
- const Oid ValueId = ValueObject["Id"sv].AsObjectId();
- PolicyText = ValueObject["Policy"sv].AsString();
- CachePolicy ValuePolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultValuePolicy;
- // FCacheRecordPolicyBuilder should guarantee that FValueId(ValueId).IsValid and ValuePolicy != DefaultValuePolicy
- // If it lets any through we will have unused data in the record we create.
- Builder.AddValuePolicy(ValueId, ValuePolicy);
+ const CbObjectView Value = ValueField.AsObjectView();
+ const Oid Id = Value["Id"sv].AsObjectId();
+ const std::string_view PolicyText = Value["Policy"sv].AsString();
+ if (!Id || PolicyText.empty())
+ {
+ return {};
+ }
+ CachePolicy Policy = ParseCachePolicy(PolicyText);
+#if BACKWARDS_COMPATABILITY_JAN2022
+ Policy = Policy & CacheValuePolicy::PolicyMask;
+#else
+ if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask))
+ {
+ return {};
+ }
+#endif
+ Builder.AddValuePolicy(Id, Policy);
}
return Builder.Build();
@@ -210,30 +229,28 @@ CacheRecordPolicy::Load(CbObjectView Object, CachePolicy DefaultPolicy)
CacheRecordPolicy
CacheRecordPolicy::ConvertToUpstream() const
{
- auto DownstreamToUpstream = [](CachePolicy P) {
- // Remote|Local -> Set Remote
- // Delete Skip Flags
- // Maintain Remaining Flags
- return (EnumHasAllFlags(P, CachePolicy::QueryRemote) ? CachePolicy::QueryLocal : CachePolicy::None) |
- (EnumHasAllFlags(P, CachePolicy::StoreRemote) ? CachePolicy::StoreLocal : CachePolicy::None) |
- (P & ~(CachePolicy::SkipData | CachePolicy::SkipMeta));
- };
- CacheRecordPolicyBuilder Builder(DownstreamToUpstream(GetDefaultValuePolicy()));
+ CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy()));
for (const CacheValuePolicy& ValuePolicy : GetValuePolicies())
{
- Builder.AddValuePolicy(ValuePolicy.Id, DownstreamToUpstream(ValuePolicy.Policy));
+ Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy));
}
return Builder.Build();
}
void
-CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Policy)
+CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value)
{
+ ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy,
+ ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s
+ if (Value.Policy == (BasePolicy & Value.PolicyMask))
+ {
+ return;
+ }
if (!Shared)
{
Shared = new Private::CacheRecordPolicyShared;
}
- Shared->AddValuePolicy(Policy);
+ Shared->AddValuePolicy(Value);
}
CacheRecordPolicy
@@ -242,13 +259,14 @@ CacheRecordPolicyBuilder::Build()
CacheRecordPolicy Policy(BasePolicy);
if (Shared)
{
- Shared->Build();
- const auto PolicyOr = [](CachePolicy A, CachePolicy B) { return A | (B & ~CachePolicy::SkipData); };
- const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
- Policy.RecordPolicy = BasePolicy;
+ const auto Add = [](const CachePolicy A, const CachePolicy B) {
+ return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData);
+ };
+ const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
+ Policy.RecordPolicy = BasePolicy;
for (const CacheValuePolicy& ValuePolicy : Values)
{
- Policy.RecordPolicy = PolicyOr(Policy.RecordPolicy, ValuePolicy.Policy);
+ Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy);
}
Policy.Shared = std::move(Shared);
}
diff --git a/zenutil/include/zenutil/cache/cachekey.h b/zenutil/include/zenutil/cache/cachekey.h
index a0a83a883..aa649b4dc 100644
--- a/zenutil/include/zenutil/cache/cachekey.h
+++ b/zenutil/include/zenutil/cache/cachekey.h
@@ -50,6 +50,12 @@ struct CacheChunkRequest
CachePolicy Policy = CachePolicy::Default;
};
+struct CacheKeyRequest
+{
+ CacheKey Key;
+ CacheRecordPolicy Policy;
+};
+
inline bool
operator<(const CacheChunkRequest& A, const CacheChunkRequest& B)
{
diff --git a/zenutil/include/zenutil/cache/cachepolicy.h b/zenutil/include/zenutil/cache/cachepolicy.h
index b3602edbd..3eb0fda66 100644
--- a/zenutil/include/zenutil/cache/cachepolicy.h
+++ b/zenutil/include/zenutil/cache/cachepolicy.h
@@ -12,14 +12,26 @@
#include <span>
#include <unordered_map>
+#define BACKWARDS_COMPATABILITY_JAN2022 1
+#if BACKWARDS_COMPATABILITY_JAN2022
+# define BACKWARDS_COMPATABILITY_JAN2022_CODE(...) __VA_ARGS__
+#else
+# define BACKWARDS_COMPATABILITY_JAN2022_CODE(...)
+#endif
+
+namespace zen::Private {
+class ICacheRecordPolicyShared;
+}
namespace zen {
class CbObjectView;
class CbWriter;
+class OptionalCacheRecordPolicy;
+
enum class CachePolicy : uint32_t
{
- /** A value without any flags set. */
+ /** A value with no flags. Disables access to the cache unless combined with other flags. */
None = 0,
/** Allow a cache request to query local caches. */
@@ -29,17 +41,26 @@ enum class CachePolicy : uint32_t
/** Allow a cache request to query any caches. */
Query = QueryLocal | QueryRemote,
- /** Allow cache records and values to be stored in local caches. */
+ /** Allow cache requests to query and store records and values in local caches. */
StoreLocal = 1 << 2,
/** Allow cache records and values to be stored in remote caches. */
StoreRemote = 1 << 3,
/** Allow cache records and values to be stored in any caches. */
Store = StoreLocal | StoreRemote,
- /** Skip fetching the metadata for record requests. */
- SkipMeta = 1 << 4,
+ /** Allow cache requests to query and store records and values in local caches. */
+ Local = QueryLocal | StoreLocal,
+ /** Allow cache requests to query and store records and values in remote caches. */
+ Remote = QueryRemote | StoreRemote,
+
+ /** Allow cache requests to query and store records and values in any caches. */
+ Default = Query | Store,
+
/** Skip fetching the data for values. */
- SkipData = 1 << 5,
+ SkipData = 1 << 4,
+
+ /** Skip fetching the metadata for record requests. */
+ SkipMeta = 1 << 5,
/**
* Partial output will be provided with the error status when a required value is missing.
@@ -48,7 +69,7 @@ enum class CachePolicy : uint32_t
* without rebuilding the whole record. The cache automatically adds this flag when there are
* other cache stores that it may be able to recover missing values from.
*
- * Missing values will be returned in the records or chunks, but with only the hash and size.
+ * Missing values will be returned in the records, but with only the hash and size.
*
* Applying this flag for a put of a record allows a partial record to be stored.
*/
@@ -61,50 +82,48 @@ enum class CachePolicy : uint32_t
* to be used when subsequent accesses will not tolerate a cache miss.
*/
KeepAlive = 1 << 7,
-
- /** Allow cache requests to query and store records and values in local caches. */
- Local = QueryLocal | StoreLocal,
- /** Allow cache requests to query and store records and values in remote caches. */
- Remote = QueryRemote | StoreRemote,
-
- /** Allow cache requests to query and store records and values in any caches. */
- Default = Query | Store,
-
- /** Do not allow cache requests to query or store records and values in any caches. */
- Disable = None,
};
gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy);
-/** Serialize Policy to text and append to Builder. Appended text will not be empty. */
+/** Append a non-empty text version of the policy to the builder. */
StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy);
-/** Parse text written by operator<< back into an ECachePolicy. Text must not be empty. */
+/** Parse non-empty text written by operator<< into a policy. */
CachePolicy ParseCachePolicy(std::string_view Text);
+/** Return input converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */
+CachePolicy ConvertToUpstream(CachePolicy Policy);
+
+inline CachePolicy
+Union(CachePolicy A, CachePolicy B)
+{
+ constexpr CachePolicy InvertedFlags = CachePolicy::SkipData | CachePolicy::SkipMeta;
+ return (A & ~(InvertedFlags)) | (B & ~(InvertedFlags)) | (A & B & InvertedFlags);
+}
/** A value ID and the cache policy to use for that value. */
struct CacheValuePolicy
{
Oid Id;
CachePolicy Policy = CachePolicy::Default;
+
+ /** Flags that are valid on a value policy. */
+ static constexpr CachePolicy PolicyMask = CachePolicy::Default | CachePolicy::SkipData;
};
-namespace Private {
- /** Interface for the private implementation of the cache record policy. */
- class ICacheRecordPolicyShared : public RefCounted
- {
- public:
- virtual ~ICacheRecordPolicyShared() = default;
- virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0;
- virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0;
- virtual void Build() = 0;
- };
-} // namespace Private
+/** Interface for the private implementation of the cache record policy. */
+class Private::ICacheRecordPolicyShared : public RefCounted
+{
+public:
+ virtual ~ICacheRecordPolicyShared() = default;
+ virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0;
+ virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0;
+};
/**
* Flags to control the behavior of cache record requests, with optional overrides by value.
*
* Examples:
- * - A base policy of Disable, with value policy overrides of Default, will fetch those values if
- * they exist in the record, and skip data for any other values.
+ * - A base policy of None with value policy overrides of Default will fetch those values if they
+ * exist in the record, and skip data for any other values.
* - A base policy of Default, with value policy overrides of (Query | SkipData), will skip those
* values, but still check if they exist, and will load any other values.
*/
@@ -115,34 +134,35 @@ public:
CacheRecordPolicy() = default;
/** Construct a cache record policy with a uniform policy for the record and every value. */
- inline CacheRecordPolicy(CachePolicy Policy) : RecordPolicy(Policy), DefaultValuePolicy(Policy) {}
+ inline CacheRecordPolicy(CachePolicy BasePolicy)
+ : RecordPolicy(BasePolicy)
+ , DefaultValuePolicy(BasePolicy & CacheValuePolicy::PolicyMask)
+ {
+ }
/** Returns true if the record and every value use the same cache policy. */
- inline bool IsUniform() const { return !Shared && RecordPolicy == DefaultValuePolicy; }
+ inline bool IsUniform() const { return !Shared; }
/** Returns the cache policy to use for the record. */
inline CachePolicy GetRecordPolicy() const { return RecordPolicy; }
+ /** Returns the base cache policy that this was constructed from. */
+ inline CachePolicy GetBasePolicy() const { return DefaultValuePolicy | (RecordPolicy & ~CacheValuePolicy::PolicyMask); }
+
/** Returns the cache policy to use for the value. */
CachePolicy GetValuePolicy(const Oid& Id) const;
- /** Returns the cache policy to use for values with no override. */
- inline CachePolicy GetDefaultValuePolicy() const { return DefaultValuePolicy; }
-
/** Returns the array of cache policy overrides for values, sorted by ID. */
inline std::span<const CacheValuePolicy> GetValuePolicies() const
{
return Shared ? Shared->GetValuePolicies() : std::span<const CacheValuePolicy>();
}
- /** Save the values from *this into the given writer. */
+ /** Saves the cache record policy to a compact binary object. */
void Save(CbWriter& Writer) const;
- /**
- * Returns a policy loaded from values on Object.
- * Invalid data will result in a uniform CacheRecordPolicy with defaultValuePolicy == DefaultPolicy.
- */
- static CacheRecordPolicy Load(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default);
+ /** Loads a cache record policy from an object. */
+ static OptionalCacheRecordPolicy Load(CbObjectView Object);
/** Return *this converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server.
*/
@@ -150,6 +170,7 @@ public:
private:
friend class CacheRecordPolicyBuilder;
+ friend class OptionalCacheRecordPolicy;
CachePolicy RecordPolicy = CachePolicy::Default;
CachePolicy DefaultValuePolicy = CachePolicy::Default;
@@ -167,7 +188,7 @@ public:
inline explicit CacheRecordPolicyBuilder(CachePolicy Policy) : BasePolicy(Policy) {}
/** Adds a cache policy override for a value. */
- void AddValuePolicy(const CacheValuePolicy& Policy);
+ void AddValuePolicy(const CacheValuePolicy& Value);
inline void AddValuePolicy(const Oid& Id, CachePolicy Policy) { AddValuePolicy({Id, Policy}); }
/** Build a cache record policy, which makes this builder subsequently unusable. */
@@ -178,4 +199,38 @@ private:
RefPtr<Private::ICacheRecordPolicyShared> Shared;
};
+/**
+ * A cache record policy that can be null.
+ *
+ * @see CacheRecordPolicy
+ */
+class OptionalCacheRecordPolicy : private CacheRecordPolicy
+{
+public:
+ inline OptionalCacheRecordPolicy() : CacheRecordPolicy(~CachePolicy::None) {}
+
+ inline OptionalCacheRecordPolicy(CacheRecordPolicy&& InOutput) : CacheRecordPolicy(std::move(InOutput)) {}
+ inline OptionalCacheRecordPolicy(const CacheRecordPolicy& InOutput) : CacheRecordPolicy(InOutput) {}
+ inline OptionalCacheRecordPolicy& operator=(CacheRecordPolicy&& InOutput)
+ {
+ CacheRecordPolicy::operator=(std::move(InOutput));
+ return *this;
+ }
+ inline OptionalCacheRecordPolicy& operator=(const CacheRecordPolicy& InOutput)
+ {
+ CacheRecordPolicy::operator=(InOutput);
+ return *this;
+ }
+
+ /** Returns the cache record policy. The caller must check for null before using this accessor. */
+ inline const CacheRecordPolicy& Get() const& { return *this; }
+ inline CacheRecordPolicy Get() && { return std::move(*this); }
+
+ inline bool IsNull() const { return RecordPolicy == ~CachePolicy::None; }
+ inline bool IsValid() const { return !IsNull(); }
+ inline explicit operator bool() const { return !IsNull(); }
+
+ inline void Reset() { *this = OptionalCacheRecordPolicy(); }
+};
+
} // namespace zen