aboutsummaryrefslogtreecommitdiff
path: root/src/zenutil
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src/zenutil
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src/zenutil')
-rw-r--r--src/zenutil/basicfile.cpp575
-rw-r--r--src/zenutil/cache/cachekey.cpp9
-rw-r--r--src/zenutil/cache/cachepolicy.cpp282
-rw-r--r--src/zenutil/cache/cacherequests.cpp1643
-rw-r--r--src/zenutil/cache/rpcrecording.cpp210
-rw-r--r--src/zenutil/include/zenutil/basicfile.h113
-rw-r--r--src/zenutil/include/zenutil/cache/cache.h6
-rw-r--r--src/zenutil/include/zenutil/cache/cachekey.h86
-rw-r--r--src/zenutil/include/zenutil/cache/cachepolicy.h227
-rw-r--r--src/zenutil/include/zenutil/cache/cacherequests.h279
-rw-r--r--src/zenutil/include/zenutil/cache/rpcrecording.h29
-rw-r--r--src/zenutil/include/zenutil/zenserverprocess.h141
-rw-r--r--src/zenutil/xmake.lua9
-rw-r--r--src/zenutil/zenserverprocess.cpp677
14 files changed, 4286 insertions, 0 deletions
diff --git a/src/zenutil/basicfile.cpp b/src/zenutil/basicfile.cpp
new file mode 100644
index 000000000..1e6043d7e
--- /dev/null
+++ b/src/zenutil/basicfile.cpp
@@ -0,0 +1,575 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenutil/basicfile.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#else
+# include <fcntl.h>
+# include <sys/file.h>
+# include <sys/stat.h>
+# include <unistd.h>
+#endif
+
+#include <fmt/format.h>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+BasicFile::~BasicFile()
+{
+ Close();
+}
+
+void
+BasicFile::Open(const std::filesystem::path& FileName, Mode Mode)
+{
+ std::error_code Ec;
+ Open(FileName, Mode, Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("failed to open file '{}'", FileName));
+ }
+}
+
+void
+BasicFile::Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec)
+{
+ Ec.clear();
+
+#if ZEN_PLATFORM_WINDOWS
+ DWORD dwCreationDisposition = 0;
+ DWORD dwDesiredAccess = 0;
+ switch (Mode)
+ {
+ case Mode::kRead:
+ dwCreationDisposition |= OPEN_EXISTING;
+ dwDesiredAccess |= GENERIC_READ;
+ break;
+ case Mode::kWrite:
+ dwCreationDisposition |= OPEN_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE);
+ break;
+ case Mode::kDelete:
+ dwCreationDisposition |= OPEN_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE);
+ break;
+ case Mode::kTruncate:
+ dwCreationDisposition |= CREATE_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE);
+ break;
+ case Mode::kTruncateDelete:
+ dwCreationDisposition |= CREATE_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE);
+ break;
+ }
+
+ const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
+ const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL;
+ HANDLE hTemplateFile = nullptr;
+
+ HANDLE FileHandle = CreateFile(FileName.c_str(),
+ dwDesiredAccess,
+ dwShareMode,
+ /* lpSecurityAttributes */ nullptr,
+ dwCreationDisposition,
+ dwFlagsAndAttributes,
+ hTemplateFile);
+
+ if (FileHandle == INVALID_HANDLE_VALUE)
+ {
+ Ec = MakeErrorCodeFromLastError();
+
+ return;
+ }
+#else
+ int OpenFlags = O_CLOEXEC;
+ switch (Mode)
+ {
+ case Mode::kRead:
+ OpenFlags |= O_RDONLY;
+ break;
+ case Mode::kWrite:
+ case Mode::kDelete:
+ OpenFlags |= (O_RDWR | O_CREAT);
+ break;
+ case Mode::kTruncate:
+ case Mode::kTruncateDelete:
+ OpenFlags |= (O_RDWR | O_CREAT | O_TRUNC);
+ break;
+ }
+
+ int Fd = open(FileName.c_str(), OpenFlags, 0666);
+ if (Fd < 0)
+ {
+ Ec = MakeErrorCodeFromLastError();
+ return;
+ }
+ if (Mode != Mode::kRead)
+ {
+ fchmod(Fd, 0666);
+ }
+
+ void* FileHandle = (void*)(uintptr_t(Fd));
+#endif
+
+ m_FileHandle = FileHandle;
+}
+
+void
+BasicFile::Close()
+{
+ if (m_FileHandle)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ ::CloseHandle(m_FileHandle);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ close(Fd);
+#endif
+ m_FileHandle = nullptr;
+ }
+}
+
+void
+BasicFile::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset)
+{
+ const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024;
+
+ while (BytesToRead)
+ {
+ const uint64_t NumberOfBytesToRead = Min(BytesToRead, MaxChunkSize);
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(FileOffset >> 32);
+
+ DWORD dwNumberOfBytesRead = 0;
+ BOOL Success = ::ReadFile(m_FileHandle, Data, DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl);
+
+ ZEN_ASSERT(dwNumberOfBytesRead == NumberOfBytesToRead);
+#else
+ static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
+ int Fd = int(uintptr_t(m_FileHandle));
+ int BytesRead = pread(Fd, Data, NumberOfBytesToRead, FileOffset);
+ bool Success = (BytesRead > 0);
+#endif
+
+ if (!Success)
+ {
+ ThrowLastError(fmt::format("Failed to read from file '{}'", zen::PathFromHandle(m_FileHandle)));
+ }
+
+ BytesToRead -= NumberOfBytesToRead;
+ FileOffset += NumberOfBytesToRead;
+ Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead;
+ }
+}
+
+IoBuffer
+BasicFile::ReadAll()
+{
+ IoBuffer Buffer(FileSize());
+ Read(Buffer.MutableData(), Buffer.Size(), 0);
+ return Buffer;
+}
+
+void
+BasicFile::StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun)
+{
+ StreamByteRange(0, FileSize(), std::move(ChunkFun));
+}
+
+void
+BasicFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun)
+{
+ const uint64_t ChunkSize = 128 * 1024;
+ IoBuffer ReadBuffer{ChunkSize};
+ void* BufferPtr = ReadBuffer.MutableData();
+
+ uint64_t RemainBytes = Size;
+ uint64_t CurrentOffset = FileOffset;
+
+ while (RemainBytes)
+ {
+ const uint64_t ThisChunkBytes = zen::Min(ChunkSize, RemainBytes);
+
+ Read(BufferPtr, ThisChunkBytes, CurrentOffset);
+
+ ChunkFun(BufferPtr, ThisChunkBytes);
+
+ CurrentOffset += ThisChunkBytes;
+ RemainBytes -= ThisChunkBytes;
+ }
+}
+
+void
+BasicFile::Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec)
+{
+ Write(Data.GetData(), Data.GetSize(), FileOffset, Ec);
+}
+
+void
+BasicFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec)
+{
+ Ec.clear();
+
+ const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024;
+
+ while (Size)
+ {
+ const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize);
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(FileOffset >> 32);
+
+ DWORD dwNumberOfBytesWritten = 0;
+
+ BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl);
+#else
+ static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
+ int Fd = int(uintptr_t(m_FileHandle));
+ int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, FileOffset);
+ bool Success = (BytesWritten > 0);
+#endif
+
+ if (!Success)
+ {
+ Ec = MakeErrorCodeFromLastError();
+
+ return;
+ }
+
+ Size -= NumberOfBytesToWrite;
+ FileOffset += NumberOfBytesToWrite;
+ Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesToWrite;
+ }
+}
+
+void
+BasicFile::Write(MemoryView Data, uint64_t FileOffset)
+{
+ Write(Data.GetData(), Data.GetSize(), FileOffset);
+}
+
+void
+BasicFile::Write(const void* Data, uint64_t Size, uint64_t Offset)
+{
+ std::error_code Ec;
+ Write(Data, Size, Offset, Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_FileHandle)));
+ }
+}
+
+void
+BasicFile::WriteAll(IoBuffer Data, std::error_code& Ec)
+{
+ Write(Data.Data(), Data.Size(), 0, Ec);
+}
+
+void
+BasicFile::Flush()
+{
+#if ZEN_PLATFORM_WINDOWS
+ FlushFileBuffers(m_FileHandle);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ fsync(Fd);
+#endif
+}
+
+uint64_t
+BasicFile::FileSize()
+{
+#if ZEN_PLATFORM_WINDOWS
+ ULARGE_INTEGER liFileSize;
+ liFileSize.LowPart = ::GetFileSize(m_FileHandle, &liFileSize.HighPart);
+ if (liFileSize.LowPart == INVALID_FILE_SIZE)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to get file size from file '{}'", PathFromHandle(m_FileHandle)));
+ }
+ }
+ return uint64_t(liFileSize.QuadPart);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
+ struct stat Stat;
+ fstat(Fd, &Stat);
+ return uint64_t(Stat.st_size);
+#endif
+}
+
+void
+BasicFile::SetFileSize(uint64_t FileSize)
+{
+#if ZEN_PLATFORM_WINDOWS
+ LARGE_INTEGER liFileSize;
+ liFileSize.QuadPart = FileSize;
+ BOOL OK = ::SetFilePointerEx(m_FileHandle, liFileSize, 0, FILE_BEGIN);
+ if (OK == FALSE)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set file pointer to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+ OK = ::SetEndOfFile(m_FileHandle);
+ if (OK == FALSE)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set end of file to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+#elif ZEN_PLATFORM_MAC
+ int Fd = int(intptr_t(m_FileHandle));
+ if (ftruncate(Fd, (off_t)FileSize) < 0)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+#else
+ int Fd = int(intptr_t(m_FileHandle));
+ if (ftruncate64(Fd, (off64_t)FileSize) < 0)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+ if (FileSize > 0)
+ {
+ int Error = posix_fallocate64(Fd, 0, (off64_t)FileSize);
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to allocate space of {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+#endif
+}
+
+void*
+BasicFile::Detach()
+{
+ void* FileHandle = m_FileHandle;
+ m_FileHandle = 0;
+ return FileHandle;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+TemporaryFile::~TemporaryFile()
+{
+ Close();
+}
+
+void
+TemporaryFile::Close()
+{
+ if (m_FileHandle)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ // Mark file for deletion when final handle is closed
+
+ FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE};
+
+ SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
+#else
+ std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle);
+ unlink(FilePath.c_str());
+#endif
+
+ BasicFile::Close();
+ }
+}
+
+void
+TemporaryFile::CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec)
+{
+ StringBuilder<64> TempName;
+ Oid::NewOid().ToString(TempName);
+
+ m_TempPath = TempDirName / TempName.c_str();
+
+ Open(m_TempPath, BasicFile::Mode::kTruncateDelete, Ec);
+}
+
+void
+TemporaryFile::MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec)
+{
+ // We intentionally call the base class Close() since otherwise we'll end up
+ // deleting the temporary file
+ BasicFile::Close();
+
+ std::filesystem::rename(m_TempPath, FinalFileName, Ec);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+LockFile::LockFile()
+{
+}
+
+LockFile::~LockFile()
+{
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Fd = int(intptr_t(m_FileHandle));
+ flock(Fd, LOCK_UN | LOCK_NB);
+#endif
+}
+
+void
+LockFile::Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec)
+{
+#if ZEN_PLATFORM_WINDOWS
+ Ec.clear();
+
+ const DWORD dwCreationDisposition = CREATE_ALWAYS;
+ DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE | DELETE;
+ const DWORD dwShareMode = FILE_SHARE_READ;
+ const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE;
+ HANDLE hTemplateFile = nullptr;
+
+ HANDLE FileHandle = CreateFile(FileName.c_str(),
+ dwDesiredAccess,
+ dwShareMode,
+ /* lpSecurityAttributes */ nullptr,
+ dwCreationDisposition,
+ dwFlagsAndAttributes,
+ hTemplateFile);
+
+ if (FileHandle == INVALID_HANDLE_VALUE)
+ {
+ Ec = zen::MakeErrorCodeFromLastError();
+
+ return;
+ }
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Fd = open(FileName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ Ec = zen::MakeErrorCodeFromLastError();
+ return;
+ }
+ fchmod(Fd, 0666);
+
+ int LockRet = flock(Fd, LOCK_EX | LOCK_NB);
+ if (LockRet < 0)
+ {
+ Ec = zen::MakeErrorCodeFromLastError();
+ close(Fd);
+ return;
+ }
+
+ void* FileHandle = (void*)uintptr_t(Fd);
+#endif
+
+ m_FileHandle = FileHandle;
+
+ BasicFile::Write(Payload.GetBuffer(), 0, Ec);
+}
+
+void
+LockFile::Update(CbObject Payload, std::error_code& Ec)
+{
+ BasicFile::Write(Payload.GetBuffer(), 0, Ec);
+}
+
+/*
+ ___________ __
+ \__ ___/___ _______/ |_ ______
+ | |_/ __ \ / ___/\ __\/ ___/
+ | |\ ___/ \___ \ | | \___ \
+ |____| \___ >____ > |__| /____ >
+ \/ \/ \/
+*/
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("BasicFile")
+{
+ ScopedCurrentDirectoryChange _;
+
+ BasicFile File1;
+ CHECK_THROWS(File1.Open("zonk", BasicFile::Mode::kRead));
+ CHECK_NOTHROW(File1.Open("zonk", BasicFile::Mode::kTruncate));
+ CHECK_NOTHROW(File1.Write("abcd", 4, 0));
+ CHECK(File1.FileSize() == 4);
+ {
+ IoBuffer Data = File1.ReadAll();
+ CHECK(Data.Size() == 4);
+ CHECK_EQ(memcmp(Data.Data(), "abcd", 4), 0);
+ }
+ CHECK_NOTHROW(File1.Write("efgh", 4, 2));
+ CHECK(File1.FileSize() == 6);
+ {
+ IoBuffer Data = File1.ReadAll();
+ CHECK(Data.Size() == 6);
+ CHECK_EQ(memcmp(Data.Data(), "abefgh", 6), 0);
+ }
+}
+
+TEST_CASE("TemporaryFile")
+{
+ ScopedCurrentDirectoryChange _;
+
+ SUBCASE("DeleteOnClose")
+ {
+ TemporaryFile TmpFile;
+ std::error_code Ec;
+ TmpFile.CreateTemporary(std::filesystem::current_path(), Ec);
+ CHECK(!Ec);
+ CHECK(std::filesystem::exists(TmpFile.GetPath()));
+ TmpFile.Close();
+ CHECK(std::filesystem::exists(TmpFile.GetPath()) == false);
+ }
+
+ SUBCASE("MoveIntoPlace")
+ {
+ TemporaryFile TmpFile;
+ std::error_code Ec;
+ TmpFile.CreateTemporary(std::filesystem::current_path(), Ec);
+ CHECK(!Ec);
+ std::filesystem::path TempPath = TmpFile.GetPath();
+ std::filesystem::path FinalPath = std::filesystem::current_path() / "final";
+ CHECK(std::filesystem::exists(TempPath));
+ TmpFile.MoveTemporaryIntoPlace(FinalPath, Ec);
+ CHECK(!Ec);
+ CHECK(std::filesystem::exists(TempPath) == false);
+ CHECK(std::filesystem::exists(FinalPath));
+ }
+}
+
+void
+basicfile_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenutil/cache/cachekey.cpp b/src/zenutil/cache/cachekey.cpp
new file mode 100644
index 000000000..545b47f11
--- /dev/null
+++ b/src/zenutil/cache/cachekey.cpp
@@ -0,0 +1,9 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cache/cachekey.h>
+
+namespace zen {
+
+const CacheKey CacheKey::Empty = CacheKey{.Bucket = std::string(), .Hash = IoHash()};
+
+} // namespace zen
diff --git a/src/zenutil/cache/cachepolicy.cpp b/src/zenutil/cache/cachepolicy.cpp
new file mode 100644
index 000000000..3bca363bb
--- /dev/null
+++ b/src/zenutil/cache/cachepolicy.cpp
@@ -0,0 +1,282 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cache/cachepolicy.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/enumflags.h>
+#include <zencore/string.h>
+
+#include <algorithm>
+#include <unordered_map>
+
+namespace zen::Private {
+class CacheRecordPolicyShared;
+}
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace DerivedData::Private {
+
+ constinit char CachePolicyDelimiter = ',';
+
+ struct CachePolicyToTextData
+ {
+ CachePolicy Policy;
+ std::string_view Text;
+ };
+
+ constinit CachePolicyToTextData CachePolicyToText[]{
+ // Flags with multiple bits are ordered by bit count to minimize token count in the text format.
+ {CachePolicy::Default, "Default"sv},
+ {CachePolicy::Remote, "Remote"sv},
+ {CachePolicy::Local, "Local"sv},
+ {CachePolicy::Store, "Store"sv},
+ {CachePolicy::Query, "Query"sv},
+ // Flags with only one bit can be in any order. Match the order in CachePolicy.
+ {CachePolicy::QueryLocal, "QueryLocal"sv},
+ {CachePolicy::QueryRemote, "QueryRemote"sv},
+ {CachePolicy::StoreLocal, "StoreLocal"sv},
+ {CachePolicy::StoreRemote, "StoreRemote"sv},
+ {CachePolicy::SkipMeta, "SkipMeta"sv},
+ {CachePolicy::SkipData, "SkipData"sv},
+ {CachePolicy::PartialRecord, "PartialRecord"sv},
+ {CachePolicy::KeepAlive, "KeepAlive"sv},
+ // None must be last because it matches every policy.
+ {CachePolicy::None, "None"sv},
+ };
+
+ constinit CachePolicy CachePolicyKnownFlags =
+ CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive;
+
+ StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy)
+ {
+ // Mask out unknown flags. None will be written if no flags are known.
+ Policy &= CachePolicyKnownFlags;
+ for (const CachePolicyToTextData& Pair : CachePolicyToText)
+ {
+ if (EnumHasAllFlags(Policy, Pair.Policy))
+ {
+ EnumRemoveFlags(Policy, Pair.Policy);
+ Builder << Pair.Text << CachePolicyDelimiter;
+ if (Policy == CachePolicy::None)
+ {
+ break;
+ }
+ }
+ }
+ Builder.RemoveSuffix(1);
+ return Builder;
+ }
+
+ CachePolicy ParseCachePolicy(const std::string_view Text)
+ {
+ ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string
+ CachePolicy Policy = CachePolicy::None;
+ ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable {
+ const int32_t EndIndex = Index;
+ for (; size_t(Index) < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index)
+ {
+ if (CachePolicyToText[Index].Text == Token)
+ {
+ Policy |= CachePolicyToText[Index].Policy;
+ ++Index;
+ return true;
+ }
+ }
+ for (Index = 0; Index < EndIndex; ++Index)
+ {
+ if (CachePolicyToText[Index].Text == Token)
+ {
+ Policy |= CachePolicyToText[Index].Policy;
+ ++Index;
+ return true;
+ }
+ }
+ return true;
+ });
+ return Policy;
+ }
+
+} // namespace DerivedData::Private
+
+StringBuilderBase&
+operator<<(StringBuilderBase& Builder, CachePolicy Policy)
+{
+ return DerivedData::Private::CachePolicyToString(Builder, Policy);
+}
+
+CachePolicy
+ParseCachePolicy(std::string_view Text)
+{
+ return DerivedData::Private::ParseCachePolicy(Text);
+}
+
+CachePolicy
+ConvertToUpstream(CachePolicy Policy)
+{
+ // Set Local flags equal to downstream's Remote flags.
+ // Delete Skip flags if StoreLocal is true, otherwise use the downstream value.
+ // Use the downstream value for all other flags.
+
+ CachePolicy UpstreamPolicy = CachePolicy::None;
+
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ UpstreamPolicy |= CachePolicy::QueryLocal;
+ }
+
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ UpstreamPolicy |= CachePolicy::StoreLocal;
+ }
+
+ if (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
+ {
+ UpstreamPolicy |= (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta));
+ }
+
+ UpstreamPolicy |= Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta);
+
+ return UpstreamPolicy;
+}
+
+class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared
+{
+public:
+ inline void AddValuePolicy(const CacheValuePolicy& Value) final
+ {
+ ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null.
+ const auto Insert =
+ std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) {
+ return Existing.Id < New.Id;
+ });
+ ZEN_ASSERT(
+ !(Insert < Values.end() &&
+ Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ")
+ Values.insert(Insert, Value);
+ }
+
+ inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; }
+
+private:
+ std::vector<CacheValuePolicy> Values;
+};
+
+CachePolicy
+CacheRecordPolicy::GetValuePolicy(const Oid& Id) const
+{
+ if (Shared)
+ {
+ const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
+ const auto Iter =
+ std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; });
+ if (Iter != Values.end() && Iter->Id == Id)
+ {
+ return Iter->Policy;
+ }
+ }
+ return DefaultValuePolicy;
+}
+
+void
+CacheRecordPolicy::Save(CbWriter& Writer) const
+{
+ Writer.BeginObject();
+ // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately.
+ Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy()));
+ if (!IsUniform())
+ {
+ Writer.BeginArray("ValuePolicies"sv);
+ for (const CacheValuePolicy& Value : GetValuePolicies())
+ {
+ Writer.BeginObject();
+ Writer.AddObjectId("Id"sv, Value.Id);
+ Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy));
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+}
+
+OptionalCacheRecordPolicy
+CacheRecordPolicy::Load(const CbObjectView Object)
+{
+ std::string_view BasePolicyText = Object["BasePolicy"sv].AsString();
+ if (BasePolicyText.empty())
+ {
+ return {};
+ }
+
+ CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText));
+ for (CbFieldView ValueField : Object["ValuePolicies"sv])
+ {
+ const CbObjectView Value = ValueField.AsObjectView();
+ const Oid Id = Value["Id"sv].AsObjectId();
+ const std::string_view PolicyText = Value["Policy"sv].AsString();
+ if (!Id || PolicyText.empty())
+ {
+ return {};
+ }
+ CachePolicy Policy = ParseCachePolicy(PolicyText);
+ if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask))
+ {
+ return {};
+ }
+ Builder.AddValuePolicy(Id, Policy);
+ }
+
+ return Builder.Build();
+}
+
+CacheRecordPolicy
+CacheRecordPolicy::ConvertToUpstream() const
+{
+ CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy()));
+ for (const CacheValuePolicy& ValuePolicy : GetValuePolicies())
+ {
+ Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy));
+ }
+ return Builder.Build();
+}
+
+void
+CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value)
+{
+ ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy,
+ ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s
+ if (Value.Policy == (BasePolicy & Value.PolicyMask))
+ {
+ return;
+ }
+ if (!Shared)
+ {
+ Shared = new Private::CacheRecordPolicyShared;
+ }
+ Shared->AddValuePolicy(Value);
+}
+
+CacheRecordPolicy
+CacheRecordPolicyBuilder::Build()
+{
+ CacheRecordPolicy Policy(BasePolicy);
+ if (Shared)
+ {
+ const auto Add = [](const CachePolicy A, const CachePolicy B) {
+ return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData);
+ };
+ const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
+ Policy.RecordPolicy = BasePolicy;
+ for (const CacheValuePolicy& ValuePolicy : Values)
+ {
+ Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy);
+ }
+ Policy.Shared = std::move(Shared);
+ }
+ return Policy;
+}
+
+} // namespace zen
diff --git a/src/zenutil/cache/cacherequests.cpp b/src/zenutil/cache/cacherequests.cpp
new file mode 100644
index 000000000..4c865ec22
--- /dev/null
+++ b/src/zenutil/cache/cacherequests.cpp
@@ -0,0 +1,1643 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cache/cacherequests.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/zencore.h>
+
+#include <string>
+#include <string_view>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+#endif
+
+namespace zen {
+
+namespace cacherequests {
+
+ namespace {
+ constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+
+ std::optional<std::string> GetValidNamespaceName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Namespace is invalid, empty namespace is not allowed");
+ return {};
+ }
+
+ if (Name.length() > 64)
+ {
+ ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name);
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet))
+ {
+ ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<std::string> GetValidBucketName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed");
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet))
+ {
+ ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<IoHash> GetValidIoHash(std::string_view Hash)
+ {
+ if (Hash.length() != IoHash::StringLength)
+ {
+ return {};
+ }
+
+ IoHash KeyHash;
+ if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash))
+ {
+ return {};
+ }
+ return KeyHash;
+ }
+
+ std::optional<CacheRecordPolicy> Convert(const OptionalCacheRecordPolicy& Policy)
+ {
+ return Policy.IsValid() ? Policy.Get() : std::optional<CacheRecordPolicy>{};
+ };
+ } // namespace
+
+ std::optional<std::string> GetRequestNamespace(const CbObjectView& Params)
+ {
+ CbFieldView NamespaceField = Params["Namespace"];
+ if (!NamespaceField)
+ {
+ return std::string("!default!"); // ZenCacheStore::DefaultNamespace);
+ }
+
+ if (NamespaceField.HasError())
+ {
+ return {};
+ }
+ if (!NamespaceField.IsString())
+ {
+ return {};
+ }
+ return GetValidNamespaceName(NamespaceField.AsString());
+ }
+
+ bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key)
+ {
+ CbFieldView BucketField = KeyView["Bucket"];
+ if (BucketField.HasError())
+ {
+ return false;
+ }
+ if (!BucketField.IsString())
+ {
+ return false;
+ }
+ std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString());
+ if (!Bucket.has_value())
+ {
+ return false;
+ }
+ CbFieldView HashField = KeyView["Hash"];
+ if (HashField.HasError())
+ {
+ return false;
+ }
+ if (!HashField.IsHash())
+ {
+ return false;
+ }
+ Key.Bucket = *Bucket;
+ Key.Hash = HashField.AsHash();
+ return true;
+ }
+
+ void WriteCacheRequestKey(CbObjectWriter& Writer, const CacheKey& Value)
+ {
+ Writer.BeginObject("Key");
+ {
+ Writer << "Bucket" << Value.Bucket;
+ Writer << "Hash" << Value.Hash;
+ }
+ Writer.EndObject();
+ }
+
+ void WriteOptionalCacheRequestPolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CacheRecordPolicy>& Policy)
+ {
+ if (Policy)
+ {
+ Writer.SetName(FieldName);
+ Policy->Save(Writer);
+ }
+ }
+
+ std::optional<CachePolicy> GetCachePolicy(CbObjectView ObjectView, std::string_view FieldName)
+ {
+ std::string_view DefaultPolicyText = ObjectView[FieldName].AsString();
+ if (DefaultPolicyText.empty())
+ {
+ return {};
+ }
+ return ParseCachePolicy(DefaultPolicyText);
+ }
+
+ void WriteCachePolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CachePolicy>& Policy)
+ {
+ if (Policy)
+ {
+ Writer << FieldName << WriteToString<128>(*Policy);
+ }
+ }
+
+ bool PutCacheRecordsRequest::Parse(const CbPackage& Package)
+ {
+ CbObjectView BatchObject = Package.GetObject();
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheRecords");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestFieldArray = Params["Requests"].AsArrayView();
+ Requests.resize(RequestFieldArray.Num());
+ for (size_t RequestIndex = 0; CbFieldView RequestField : RequestFieldArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView RecordObject = RequestObject["Record"].AsObjectView();
+ CbObjectView KeyView = RecordObject["Key"].AsObjectView();
+
+ PutCacheRecordRequest& Request = Requests[RequestIndex++];
+
+ if (!GetRequestCacheKey(KeyView, Request.Key))
+ {
+ return false;
+ }
+
+ Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView()));
+
+ std::unordered_map<IoHash, size_t, IoHash::Hasher> RawHashToAttachmentIndex;
+
+ CbArrayView ValuesArray = RecordObject["Values"].AsArrayView();
+ Request.Values.resize(ValuesArray.Num());
+ RawHashToAttachmentIndex.reserve(ValuesArray.Num());
+ for (size_t Index = 0; CbFieldView Value : ValuesArray)
+ {
+ CbObjectView ObjectView = Value.AsObjectView();
+ IoHash AttachmentHash = ObjectView["RawHash"].AsHash();
+ RawHashToAttachmentIndex[AttachmentHash] = Index;
+ Request.Values[Index++] = {.Id = ObjectView["Id"].AsObjectId(), .RawHash = AttachmentHash};
+ }
+
+ RecordObject.IterateAttachments([&](CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ if (const CbAttachment* Attachment = Package.FindAttachment(ValueHash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ auto It = RawHashToAttachmentIndex.find(ValueHash);
+ ZEN_ASSERT(It != RawHashToAttachmentIndex.end());
+ PutCacheRecordRequestValue& Value = Request.Values[It->second];
+ ZEN_ASSERT(Value.RawHash == ValueHash);
+ Value.Body = Attachment->AsCompressedBinary();
+ ZEN_ASSERT_SLOW(Value.Body.DecodeRawHash() == Value.RawHash);
+ }
+ }
+ });
+ }
+
+ return true;
+ }
+
+ bool PutCacheRecordsRequest::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "PutCacheRecords";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("Requests");
+ for (const PutCacheRecordRequest& RecordRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ Writer.BeginObject("Record");
+ {
+ WriteCacheRequestKey(Writer, RecordRequest.Key);
+ Writer.BeginArray("Values");
+ for (const PutCacheRecordRequestValue& Value : RecordRequest.Values)
+ {
+ Writer.BeginObject();
+ {
+ Writer.AddObjectId("Id", Value.Id);
+ const CompressedBuffer& Buffer = Value.Body;
+ if (Buffer)
+ {
+ IoHash AttachmentHash = Buffer.DecodeRawHash(); // TODO: Slow!
+ Writer.AddBinaryAttachment("RawHash", AttachmentHash);
+ OutPackage.AddAttachment(CbAttachment(Buffer, AttachmentHash));
+ Writer.AddInteger("RawSize", Buffer.DecodeRawSize()); // TODO: Slow!
+ }
+ else
+ {
+ if (Value.RawHash == IoHash::Zero)
+ {
+ return false;
+ }
+ Writer.AddBinaryAttachment("RawHash", Value.RawHash);
+ }
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+ WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+ OutPackage.SetObject(Writer.Save());
+
+ return true;
+ }
+
+ bool PutCacheRecordsResult::Parse(const CbPackage& Package)
+ {
+ CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+ CbFieldViewIterator It = ResultsArray.CreateViewIterator();
+ while (It.HasValue())
+ {
+ Success.push_back(It.AsBool());
+ It++;
+ }
+ return true;
+ }
+
+ bool PutCacheRecordsResult::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result");
+ for (bool Value : Success)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ OutPackage.SetObject(ResponseObject.Save());
+ return true;
+ }
+
+ bool GetCacheRecordsRequest::Parse(const CbObjectView& RpcRequest)
+ {
+ ZEN_ASSERT(RpcRequest["Method"].AsString() == "GetCacheRecords");
+ AcceptMagic = RpcRequest["AcceptType"].AsUInt32(0);
+ AcceptOptions = RpcRequest["AcceptFlags"].AsUInt16(0);
+ ProcessPid = RpcRequest["Pid"].AsInt32(0);
+
+ CbObjectView Params = RpcRequest["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["Requests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ GetCacheRecordRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView()));
+ }
+ return true;
+ }
+
+ bool GetCacheRecordsRequest::Parse(const CbPackage& RpcRequest) { return Parse(RpcRequest.GetObject()); }
+
+ bool GetCacheRecordsRequest::Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter) const
+ {
+ Writer << "Method"
+ << "GetCacheRecords";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+ if (AcceptOptions != 0)
+ {
+ Writer << "AcceptFlags" << AcceptOptions;
+ }
+ if (ProcessPid != 0)
+ {
+ Writer << "Pid" << ProcessPid;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+ Writer.BeginArray("Requests");
+ if (OptionalRecordFilter.empty())
+ {
+ for (const GetCacheRecordRequest& RecordRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, RecordRequest.Key);
+ WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ else
+ {
+ for (size_t Index : OptionalRecordFilter)
+ {
+ const GetCacheRecordRequest& RecordRequest = Requests[Index];
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, RecordRequest.Key);
+ WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ return true;
+ }
+
+ bool GetCacheRecordsRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter) const
+ {
+ CbObjectWriter Writer;
+ if (!Format(Writer, OptionalRecordFilter))
+ {
+ return false;
+ }
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool GetCacheRecordsResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes)
+ {
+ CbObject ResponseObject = Package.GetObject();
+ CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+
+ Results.reserve(ResultsArray.Num());
+ if (!OptionalRecordResultIndexes.empty() && ResultsArray.Num() != OptionalRecordResultIndexes.size())
+ {
+ return false;
+ }
+ for (size_t Index = 0; CbFieldView RecordView : ResultsArray)
+ {
+ size_t ResultIndex = OptionalRecordResultIndexes.empty() ? Index : OptionalRecordResultIndexes[Index];
+ Index++;
+
+ if (Results.size() <= ResultIndex)
+ {
+ Results.resize(ResultIndex + 1);
+ }
+ if (RecordView.IsNull())
+ {
+ continue;
+ }
+ Results[ResultIndex] = GetCacheRecordResult{};
+ GetCacheRecordResult& Request = Results[ResultIndex].value();
+ CbObjectView RecordObject = RecordView.AsObjectView();
+ CbObjectView KeyObject = RecordObject["Key"].AsObjectView();
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ CbArrayView ValuesArray = RecordObject["Values"].AsArrayView();
+ Request.Values.reserve(ValuesArray.Num());
+ for (CbFieldView Value : ValuesArray)
+ {
+ CbObjectView ValueObject = Value.AsObjectView();
+ IoHash RawHash = ValueObject["RawHash"].AsHash();
+ uint64_t RawSize = ValueObject["RawSize"].AsUInt64();
+ Oid Id = ValueObject["Id"].AsObjectId();
+ const CbAttachment* Attachment = Package.FindAttachment(RawHash);
+ if (!Attachment)
+ {
+ Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = {}});
+ continue;
+ }
+ if (!Attachment->IsCompressedBinary())
+ {
+ return false;
+ }
+ Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = Attachment->AsCompressedBinary()});
+ }
+ }
+ return true;
+ }
+
+ bool GetCacheRecordsResult::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+
+ Writer.BeginArray("Result");
+ for (const std::optional<GetCacheRecordResult>& RecordResult : Results)
+ {
+ if (!RecordResult.has_value())
+ {
+ Writer.AddNull();
+ continue;
+ }
+ Writer.BeginObject();
+ WriteCacheRequestKey(Writer, RecordResult->Key);
+
+ Writer.BeginArray("Values");
+ for (const GetCacheRecordResultValue& Value : RecordResult->Values)
+ {
+ IoHash AttachmentHash = Value.Body ? Value.Body.DecodeRawHash() : Value.RawHash;
+ Writer.BeginObject();
+ {
+ Writer.AddObjectId("Id", Value.Id);
+ Writer.AddHash("RawHash", AttachmentHash);
+ Writer.AddInteger("RawSize", Value.Body ? Value.Body.DecodeRawSize() : Value.RawSize);
+ }
+ Writer.EndObject();
+ if (Value.Body)
+ {
+ OutPackage.AddAttachment(CbAttachment(Value.Body, AttachmentHash));
+ }
+ }
+
+ Writer.EndArray();
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool PutCacheValuesRequest::Parse(const CbPackage& Package)
+ {
+ CbObjectView BatchObject = Package.GetObject();
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheValues");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["Requests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ PutCacheValueRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.RawHash = RequestObject["RawHash"].AsBinaryAttachment();
+ Request.Policy = GetCachePolicy(RequestObject, "Policy");
+
+ if (const CbAttachment* Attachment = Package.FindAttachment(Request.RawHash))
+ {
+ if (!Attachment->IsCompressedBinary())
+ {
+ return false;
+ }
+ Request.Body = Attachment->AsCompressedBinary();
+ }
+ }
+ return true;
+ }
+
+ bool PutCacheValuesRequest::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "PutCacheValues";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("Requests");
+ for (const PutCacheValueRequest& ValueRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+ if (ValueRequest.Body)
+ {
+ IoHash AttachmentHash = ValueRequest.Body.DecodeRawHash();
+ if (ValueRequest.RawHash != IoHash::Zero && AttachmentHash != ValueRequest.RawHash)
+ {
+ return false;
+ }
+ Writer.AddBinaryAttachment("RawHash", AttachmentHash);
+ OutPackage.AddAttachment(CbAttachment(ValueRequest.Body, AttachmentHash));
+ }
+ else if (ValueRequest.RawHash != IoHash::Zero)
+ {
+ Writer.AddBinaryAttachment("RawHash", ValueRequest.RawHash);
+ }
+ else
+ {
+ return false;
+ }
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool PutCacheValuesResult::Parse(const CbPackage& Package)
+ {
+ CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+ CbFieldViewIterator It = ResultsArray.CreateViewIterator();
+ while (It.HasValue())
+ {
+ Success.push_back(It.AsBool());
+ It++;
+ }
+ return true;
+ }
+
+ bool PutCacheValuesResult::Format(CbPackage& OutPackage) const
+ {
+ if (Success.empty())
+ {
+ return false;
+ }
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result");
+ for (bool Value : Success)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ OutPackage.SetObject(ResponseObject.Save());
+ return true;
+ }
+
+ bool GetCacheValuesRequest::Parse(const CbObjectView& BatchObject)
+ {
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheValues");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+ AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0);
+ ProcessPid = BatchObject["Pid"].AsInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["Requests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ GetCacheValueRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.Policy = GetCachePolicy(RequestObject, "Policy");
+ }
+ return true;
+ }
+
+ bool GetCacheValuesRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "GetCacheValues";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+ if (AcceptOptions != 0)
+ {
+ Writer << "AcceptFlags" << AcceptOptions;
+ }
+ if (ProcessPid != 0)
+ {
+ Writer << "Pid" << ProcessPid;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("Requests");
+ if (OptionalValueFilter.empty())
+ {
+ for (const GetCacheValueRequest& ValueRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ else
+ {
+ for (size_t Index : OptionalValueFilter)
+ {
+ const GetCacheValueRequest& ValueRequest = Requests[Index];
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool CacheValuesResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes)
+ {
+ CbObject ResponseObject = Package.GetObject();
+ CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+ Results.reserve(ResultsArray.Num());
+ if (!OptionalValueResultIndexes.empty() && ResultsArray.Num() != OptionalValueResultIndexes.size())
+ {
+ return false;
+ }
+ for (size_t Index = 0; CbFieldView RecordView : ResultsArray)
+ {
+ size_t ResultIndex = OptionalValueResultIndexes.empty() ? Index : OptionalValueResultIndexes[Index];
+ Index++;
+
+ if (Results.size() <= ResultIndex)
+ {
+ Results.resize(ResultIndex + 1);
+ }
+ if (RecordView.IsNull())
+ {
+ continue;
+ }
+
+ CacheValueResult& ValueResult = Results[ResultIndex];
+ CbObjectView RecordObject = RecordView.AsObjectView();
+
+ CbFieldView RawHashField = RecordObject["RawHash"];
+ ValueResult.RawHash = RawHashField.AsHash();
+ bool Succeeded = !RawHashField.HasError();
+ if (Succeeded)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(ValueResult.RawHash);
+ ValueResult.Body = Attachment ? Attachment->AsCompressedBinary() : CompressedBuffer();
+ if (ValueResult.Body)
+ {
+ ValueResult.RawSize = ValueResult.Body.DecodeRawSize();
+ }
+ else
+ {
+ ValueResult.RawSize = RecordObject["RawSize"].AsUInt64(UINT64_MAX);
+ }
+ }
+ }
+ return true;
+ }
+
+ bool CacheValuesResult::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginArray("Result");
+ for (const CacheValueResult& ValueResult : Results)
+ {
+ ResponseObject.BeginObject();
+ if (ValueResult.RawHash != IoHash::Zero)
+ {
+ ResponseObject.AddHash("RawHash", ValueResult.RawHash);
+ if (ValueResult.Body)
+ {
+ OutPackage.AddAttachment(CbAttachment(ValueResult.Body, ValueResult.RawHash));
+ }
+ else
+ {
+ ResponseObject.AddInteger("RawSize", ValueResult.RawSize);
+ }
+ }
+ ResponseObject.EndObject();
+ }
+ ResponseObject.EndArray();
+
+ OutPackage.SetObject(ResponseObject.Save());
+ return true;
+ }
+
+ bool GetCacheChunksRequest::Parse(const CbObjectView& BatchObject)
+ {
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheChunks");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+ AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0);
+ ProcessPid = BatchObject["Pid"].AsInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["ChunkRequests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ GetCacheChunkRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.ValueId = RequestObject["ValueId"].AsObjectId();
+ Request.ChunkId = RequestObject["ChunkId"].AsHash();
+ Request.RawOffset = RequestObject["RawOffset"].AsUInt64();
+ Request.RawSize = RequestObject["RawSize"].AsUInt64(UINT64_MAX);
+
+ Request.Policy = GetCachePolicy(RequestObject, "Policy");
+ }
+ return true;
+ }
+
+ bool GetCacheChunksRequest::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "GetCacheChunks";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+ if (AcceptOptions != 0)
+ {
+ Writer << "AcceptFlags" << AcceptOptions;
+ }
+ if (ProcessPid != 0)
+ {
+ Writer << "Pid" << ProcessPid;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("ChunkRequests");
+ for (const GetCacheChunkRequest& ValueRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+
+ Writer.AddObjectId("ValueId", ValueRequest.ValueId);
+ Writer.AddHash("ChunkId", ValueRequest.ChunkId);
+ Writer.AddInteger("RawOffset", ValueRequest.RawOffset);
+ Writer.AddInteger("RawSize", ValueRequest.RawSize);
+
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data)
+ {
+ std::vector<std::string_view> Tokens;
+ uint32_t TokenCount = zen::ForEachStrTok(Key, '/', [&](const std::string_view& Token) {
+ Tokens.push_back(Token);
+ return true;
+ });
+
+ switch (TokenCount)
+ {
+ case 1:
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ return Data.Namespace.has_value();
+ case 2:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ case 3:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key/valueid request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ Data.ValueContentId = GetValidIoHash(Tokens[2]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey)
+ {
+ return false;
+ }
+ return true;
+ }
+ case 4:
+ {
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey.has_value())
+ {
+ return false;
+ }
+
+ Data.ValueContentId = GetValidIoHash(Tokens[3]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ default:
+ return false;
+ }
+ }
+
+ // bool CacheRecord::Parse(CbObjectView& Reader)
+ // {
+ // CbObjectView KeyView = Reader["Key"].AsObjectView();
+ //
+ // if (!GetRequestCacheKey(KeyView, Key))
+ // {
+ // return false;
+ // }
+ // CbArrayView ValuesArray = Reader["Values"].AsArrayView();
+ // Values.reserve(ValuesArray.Num());
+ // for (CbFieldView Value : ValuesArray)
+ // {
+ // CbObjectView ObjectView = Value.AsObjectView();
+ // Values.push_back({.Id = ObjectView["Id"].AsObjectId(),
+ // .RawHash = ObjectView["RawHash"].AsHash(),
+ // .RawSize = ObjectView["RawSize"].AsUInt64()});
+ // }
+ // return true;
+ // }
+ //
+ // bool CacheRecord::Format(CbObjectWriter& Writer) const
+ // {
+ // WriteCacheRequestKey(Writer, Key);
+ // Writer.BeginArray("Values");
+ // for (const CacheRecordValue& Value : Values)
+ // {
+ // Writer.BeginObject();
+ // {
+ // Writer.AddObjectId("Id", Value.Id);
+ // Writer.AddHash("RawHash", Value.RawHash);
+ // Writer.AddInteger("RawSize", Value.RawSize);
+ // }
+ // Writer.EndObject();
+ // }
+ // Writer.EndArray();
+ // return true;
+ // }
+
+#if ZEN_WITH_TESTS
+
+ static bool operator==(const PutCacheRecordRequestValue& Lhs, const PutCacheRecordRequestValue& Rhs)
+ {
+ const IoHash LhsRawHash = Lhs.RawHash != IoHash::Zero ? Lhs.RawHash : Lhs.Body.DecodeRawHash();
+ const IoHash RhsRawHash = Rhs.RawHash != IoHash::Zero ? Rhs.RawHash : Rhs.Body.DecodeRawHash();
+ return Lhs.Id == Rhs.Id && LhsRawHash == RhsRawHash &&
+ Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView());
+ }
+
+ static bool operator==(const zen::CacheValuePolicy& Lhs, const zen::CacheValuePolicy& Rhs)
+ {
+ return (Lhs.Id == Rhs.Id) && (Lhs.Policy == Rhs.Policy);
+ }
+
+ static bool operator==(const std::span<const zen::CacheValuePolicy>& Lhs, const std::span<const zen::CacheValuePolicy>& Rhs)
+ {
+ if (Lhs.size() != Lhs.size())
+ {
+ return false;
+ }
+ for (size_t Idx = 0; Idx < Lhs.size(); ++Idx)
+ {
+ if (Lhs[Idx] != Rhs[Idx])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ static bool operator==(const zen::CacheRecordPolicy& Lhs, const zen::CacheRecordPolicy& Rhs)
+ {
+ return (Lhs.GetRecordPolicy() == Rhs.GetRecordPolicy()) && (Lhs.GetBasePolicy() == Rhs.GetBasePolicy()) &&
+ (Lhs.GetValuePolicies() == Rhs.GetValuePolicies());
+ }
+
+ static bool operator==(const std::optional<CacheRecordPolicy>& Lhs, const std::optional<CacheRecordPolicy>& Rhs)
+ {
+ return (Lhs.has_value() == Rhs.has_value()) && (!Lhs || (*Lhs == *Rhs));
+ }
+
+ static bool operator==(const PutCacheRecordRequest& Lhs, const PutCacheRecordRequest& Rhs)
+ {
+ return (Lhs.Key == Rhs.Key) && (Lhs.Values == Rhs.Values) && (Lhs.Policy == Rhs.Policy);
+ }
+
+ static bool operator==(const PutCacheRecordsRequest& Lhs, const PutCacheRecordsRequest& Rhs)
+ {
+ return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests);
+ }
+
+ static bool operator==(const PutCacheRecordsResult& Lhs, const PutCacheRecordsResult& Rhs) { return (Lhs.Success == Rhs.Success); }
+
+ static bool operator==(const GetCacheRecordRequest& Lhs, const GetCacheRecordRequest& Rhs)
+ {
+ return (Lhs.Key == Rhs.Key) && (Lhs.Policy == Rhs.Policy);
+ }
+
+ static bool operator==(const GetCacheRecordsRequest& Lhs, const GetCacheRecordsRequest& Rhs)
+ {
+ return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests);
+ }
+
+ static bool operator==(const GetCacheRecordResultValue& Lhs, const GetCacheRecordResultValue& Rhs)
+ {
+ if ((Lhs.Id != Rhs.Id) || (Lhs.RawHash != Rhs.RawHash) || (Lhs.RawSize != Rhs.RawSize))
+ {
+ return false;
+ }
+ if (bool(Lhs.Body) != bool(Rhs.Body))
+ {
+ return false;
+ }
+ if (bool(Lhs.Body) && Lhs.Body.DecodeRawHash() != Rhs.Body.DecodeRawHash())
+ {
+ return false;
+ }
+ return true;
+ }
+
+ static bool operator==(const GetCacheRecordResult& Lhs, const GetCacheRecordResult& Rhs)
+ {
+ return Lhs.Key == Rhs.Key && Lhs.Values == Rhs.Values;
+ }
+
+ static bool operator==(const std::optional<GetCacheRecordResult>& Lhs, const std::optional<GetCacheRecordResult>& Rhs)
+ {
+ if (Lhs.has_value() != Rhs.has_value())
+ {
+ return false;
+ }
+ return *Lhs == Rhs;
+ }
+
+ static bool operator==(const GetCacheRecordsResult& Lhs, const GetCacheRecordsResult& Rhs) { return Lhs.Results == Rhs.Results; }
+
+ static bool operator==(const PutCacheValueRequest& Lhs, const PutCacheValueRequest& Rhs)
+ {
+ if ((Lhs.Key != Rhs.Key) || (Lhs.RawHash != Rhs.RawHash))
+ {
+ return false;
+ }
+
+ if (bool(Lhs.Body) != bool(Rhs.Body))
+ {
+ return false;
+ }
+ if (!Lhs.Body)
+ {
+ return true;
+ }
+ return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView());
+ }
+
+ static bool operator==(const PutCacheValuesRequest& Lhs, const PutCacheValuesRequest& Rhs)
+ {
+ return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests);
+ }
+
+ static bool operator==(const PutCacheValuesResult& Lhs, const PutCacheValuesResult& Rhs) { return (Lhs.Success == Rhs.Success); }
+
+ static bool operator==(const GetCacheValueRequest& Lhs, const GetCacheValueRequest& Rhs)
+ {
+ return Lhs.Key == Rhs.Key && Lhs.Policy == Rhs.Policy;
+ }
+
+ static bool operator==(const GetCacheValuesRequest& Lhs, const GetCacheValuesRequest& Rhs)
+ {
+ return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests;
+ }
+
+ static bool operator==(const CacheValueResult& Lhs, const CacheValueResult& Rhs)
+ {
+ if (Lhs.RawHash != Rhs.RawHash)
+ {
+ return false;
+ };
+ if (Lhs.Body)
+ {
+ if (!Rhs.Body)
+ {
+ return false;
+ }
+ return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView());
+ }
+ return Lhs.RawSize == Rhs.RawSize;
+ }
+
+ static bool operator==(const CacheValuesResult& Lhs, const CacheValuesResult& Rhs) { return Lhs.Results == Rhs.Results; }
+
+ static bool operator==(const GetCacheChunkRequest& Lhs, const GetCacheChunkRequest& Rhs)
+ {
+ return Lhs.Key == Rhs.Key && Lhs.ValueId == Rhs.ValueId && Lhs.ChunkId == Rhs.ChunkId && Lhs.RawOffset == Rhs.RawOffset &&
+ Lhs.RawSize == Rhs.RawSize && Lhs.Policy == Rhs.Policy;
+ }
+
+ static bool operator==(const GetCacheChunksRequest& Lhs, const GetCacheChunksRequest& Rhs)
+ {
+ return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests;
+ }
+
+ static CompressedBuffer MakeCompressedBuffer(size_t Size) { return CompressedBuffer::Compress(SharedBuffer(IoBuffer(Size))); };
+
+ TEST_CASE("cacherequests.put.cache.records")
+ {
+ PutCacheRecordsRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ PutCacheRecordsRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required
+
+ PutCacheRecordsRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::Remote,
+ .Namespace = "the_namespace",
+ .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}},
+ .Policy = CachePolicy::StoreLocal},
+ {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}},
+ .Policy = CachePolicy::Store},
+ {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ PutCacheRecordsRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage));
+ CHECK(FullRequest == FullRequestCopy);
+
+ PutCacheRecordsResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ PutCacheRecordsResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ PutCacheRecordsResult FullResult = {.Success = {true, false, true, true, false}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ PutCacheRecordsResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+ }
+
+ TEST_CASE("cacherequests.get.cache.records")
+ {
+ GetCacheRecordsRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ GetCacheRecordsRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required
+
+ GetCacheRecordsRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Policy = CachePolicy::Remote},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ GetCacheRecordsRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage));
+ CHECK(FullRequest == FullRequestCopy);
+
+ CbPackage PartialRequestPackage;
+ CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2}));
+ GetCacheRecordsRequest PartialRequest = FullRequest;
+ PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1);
+ GetCacheRecordsRequest PartialRequestCopy;
+ CHECK(PartialRequestCopy.Parse(PartialRequestPackage));
+ CHECK(PartialRequest == PartialRequestCopy);
+
+ GetCacheRecordsResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ GetCacheRecordsResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ PutCacheRecordsRequest FullPutRequest = {
+ .DefaultPolicy = CachePolicy::Remote,
+ .Namespace = "the_namespace",
+ .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}},
+ .Policy = CachePolicy::StoreLocal},
+ {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}},
+ .Policy = CachePolicy::Store},
+ {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}};
+
+ CbPackage FullPutRequestPackage;
+ CHECK(FullPutRequest.Format(FullPutRequestPackage));
+ PutCacheRecordsRequest FullPutRequestCopy;
+ CHECK(FullPutRequestCopy.Parse(FullPutRequestPackage));
+
+ GetCacheRecordsResult FullResult = {
+ {GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[0].Key,
+ .Values = {{.Id = FullPutRequestCopy.Requests[0].Values[0].Id,
+ .RawHash = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[0].Values[0].Body},
+ {.Id = FullPutRequestCopy.Requests[0].Values[1].Id,
+
+ .RawHash = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[0].Values[1].Body},
+ {.Id = FullPutRequestCopy.Requests[0].Values[2].Id,
+ .RawHash = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[0].Values[2].Body}}},
+ {}, // Simulate not have!
+ GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[2].Key,
+ .Values = {{.Id = FullPutRequestCopy.Requests[2].Values[0].Id,
+ .RawHash = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[2].Values[0].Body},
+ {.Id = FullPutRequestCopy.Requests[2].Values[1].Id,
+ .RawHash = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawSize(),
+ .Body = {}}, // Simulate not have
+ {.Id = FullPutRequestCopy.Requests[2].Values[2].Id,
+ .RawHash = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[2].Values[2].Body}}}}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ GetCacheRecordsResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult.Results[0] == FullResultCopy.Results[0]);
+ CHECK(!FullResultCopy.Results[1]);
+ CHECK(FullResult.Results[2] == FullResultCopy.Results[2]);
+
+ GetCacheRecordsResult PartialResultCopy;
+ CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4}));
+ CHECK(FullResult.Results[0] == PartialResultCopy.Results[0]);
+ CHECK(!PartialResultCopy.Results[1]);
+ CHECK(!PartialResultCopy.Results[2]);
+ CHECK(!PartialResultCopy.Results[3]);
+ CHECK(FullResult.Results[2] == PartialResultCopy.Results[4]);
+ }
+
+ TEST_CASE("cacherequests.put.cache.values")
+ {
+ PutCacheValuesRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ PutCacheValuesRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required
+
+ CompressedBuffer Buffers[3] = {MakeCompressedBuffer(969), MakeCompressedBuffer(3469), MakeCompressedBuffer(9)};
+ PutCacheValuesRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .RawHash = Buffers[0].DecodeRawHash(),
+ .Body = Buffers[0],
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .RawHash = Buffers[1].DecodeRawHash(),
+ .Body = Buffers[1],
+ .Policy = CachePolicy::Remote},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .RawHash = Buffers[2].DecodeRawHash()}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ PutCacheValuesRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage));
+ CHECK(FullRequest == FullRequestCopy);
+
+ PutCacheValuesResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(!EmptyResult.Format(EmptyResponsePackage));
+
+ PutCacheValuesResult FullResult = {.Success = {true, false, true}};
+
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ PutCacheValuesResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+ }
+
+ TEST_CASE("cacherequests.get.cache.values")
+ {
+ GetCacheValuesRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ GetCacheValuesRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required
+
+ GetCacheValuesRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Policy = CachePolicy::Remote},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ GetCacheValuesRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject()));
+ CHECK(FullRequest == FullRequestCopy);
+
+ CbPackage PartialRequestPackage;
+ CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2}));
+ GetCacheValuesRequest PartialRequest = FullRequest;
+ PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1);
+ GetCacheValuesRequest PartialRequestCopy;
+ CHECK(PartialRequestCopy.Parse(PartialRequestPackage.GetObject()));
+ CHECK(PartialRequest == PartialRequestCopy);
+
+ CacheValuesResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ CacheValuesResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)},
+ {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)},
+ {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}};
+ CacheValuesResult FullResult = {
+ .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]},
+ CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ CacheValuesResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+
+ CacheValuesResult PartialResultCopy;
+ CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4, 5, 6, 9}));
+ CHECK(PartialResultCopy.Results[0] == FullResult.Results[0]);
+ CHECK(PartialResultCopy.Results[1].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[2].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[3] == FullResult.Results[1]);
+ CHECK(PartialResultCopy.Results[4] == FullResult.Results[2]);
+ CHECK(PartialResultCopy.Results[5] == FullResult.Results[3]);
+ CHECK(PartialResultCopy.Results[6] == FullResult.Results[4]);
+ CHECK(PartialResultCopy.Results[7].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[8].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[9] == FullResult.Results[5]);
+ }
+
+ TEST_CASE("cacherequests.get.cache.chunks")
+ {
+ GetCacheChunksRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ GetCacheChunksRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required
+
+ GetCacheChunksRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .ValueId = Oid::NewOid(),
+ .ChunkId = IoHash::FromHexString("ab3917854bfef7e7af2c372d795bb907a15cab15"),
+ .RawOffset = 77,
+ .RawSize = 33,
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .ValueId = Oid::NewOid(),
+ .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"),
+ .Policy = CachePolicy::Remote},
+ {
+ .Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"),
+ }}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ GetCacheChunksRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject()));
+ CHECK(FullRequest == FullRequestCopy);
+
+ GetCacheChunksResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ GetCacheChunksResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)},
+ {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)},
+ {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}};
+ GetCacheChunksResult FullResult = {
+ .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]},
+ CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ GetCacheChunksResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+ }
+
+ TEST_CASE("z$service.parse.relative.Uri")
+ {
+ HttpRequestData LegacyBucketRequestBecomesNamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest));
+ CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test");
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value());
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value());
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyHashKeyRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest));
+ CHECK(!LegacyHashKeyRequest.Namespace);
+ CHECK(LegacyHashKeyRequest.Bucket == "test");
+ CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(!LegacyHashKeyRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyValueContentIdRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+ LegacyValueContentIdRequest));
+ CHECK(!LegacyValueContentIdRequest.Namespace);
+ CHECK(LegacyValueContentIdRequest.Bucket == "test");
+ CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"));
+
+ HttpRequestData V2DefaultNamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest));
+ CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc");
+ CHECK(!V2DefaultNamespaceRequest.Bucket.has_value());
+ CHECK(!V2DefaultNamespaceRequest.HashKey.has_value());
+ CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData V2NamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest));
+ CHECK(V2NamespaceRequest.Namespace == "nicenamespace");
+ CHECK(!V2NamespaceRequest.Bucket.has_value());
+ CHECK(!V2NamespaceRequest.HashKey.has_value());
+ CHECK(!V2NamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData V2BucketRequestWithDefaultNamespace;
+ CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace));
+ CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc");
+ CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test");
+ CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value());
+ CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value());
+
+ HttpRequestData V2BucketRequestWithNamespace;
+ CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace));
+ CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace");
+ CHECK(V2BucketRequestWithNamespace.Bucket == "test");
+ CHECK(!V2BucketRequestWithNamespace.HashKey.has_value());
+ CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value());
+
+ HttpRequestData V2HashKeyRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest));
+ CHECK(!V2HashKeyRequest.Namespace);
+ CHECK(V2HashKeyRequest.Bucket == "test");
+ CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(!V2HashKeyRequest.ValueContentId.has_value());
+
+ HttpRequestData V2ValueContentIdRequest;
+ CHECK(HttpRequestParseRelativeUri(
+ "nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+ V2ValueContentIdRequest));
+ CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace");
+ CHECK(V2ValueContentIdRequest.Bucket == "test");
+ CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"));
+
+ HttpRequestData Invalid;
+ CHECK(!HttpRequestParseRelativeUri("", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("/", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid));
+ CHECK(!HttpRequestParseRelativeUri(
+ "namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789",
+ Invalid));
+ }
+#endif
+} // namespace cacherequests
+
+void
+cacherequests_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenutil/cache/rpcrecording.cpp b/src/zenutil/cache/rpcrecording.cpp
new file mode 100644
index 000000000..4958a27f6
--- /dev/null
+++ b/src/zenutil/cache/rpcrecording.cpp
@@ -0,0 +1,210 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/basicfile.h>
+#include <zenutil/cache/rpcrecording.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::cache {
+struct RecordedRequest
+{
+ uint64_t Offset;
+ uint64_t Length;
+ ZenContentType ContentType;
+ ZenContentType AcceptType;
+};
+
+const uint64_t RecordedRequestBlockSize = 1ull << 31u;
+
+struct RecordedRequestsWriter
+{
+ void BeginWrite(const std::filesystem::path& BasePath)
+ {
+ m_BasePath = BasePath;
+ std::filesystem::create_directories(m_BasePath);
+ }
+
+ void EndWrite()
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_BlockFiles.clear();
+
+ IoBuffer IndexBuffer(IoBuffer::Wrap, m_Entries.data(), m_Entries.size() * sizeof(RecordedRequest));
+ BasicFile IndexFile;
+ IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kTruncate);
+ std::error_code Ec;
+ IndexFile.WriteAll(IndexBuffer, Ec);
+ IndexFile.Close();
+ m_Entries.clear();
+ }
+
+ uint64_t WriteRequest(ZenContentType ContentType, ZenContentType AcceptType, const IoBuffer& RequestBuffer)
+ {
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+ uint64_t RequestIndex = m_Entries.size();
+ RecordedRequest& Entry = m_Entries.emplace_back(
+ RecordedRequest{.Offset = ~0ull, .Length = RequestBuffer.Size(), .ContentType = ContentType, .AcceptType = AcceptType});
+ if (Entry.Length < 1 * 1024 * 1024)
+ {
+ uint32_t BlockIndex = gsl::narrow<uint32_t>((m_ChunkOffset + Entry.Length) / RecordedRequestBlockSize);
+ if (BlockIndex == m_BlockFiles.size())
+ {
+ std::unique_ptr<BasicFile>& NewBlockFile = m_BlockFiles.emplace_back(std::make_unique<BasicFile>());
+ NewBlockFile->Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kTruncate);
+ m_ChunkOffset = BlockIndex * RecordedRequestBlockSize;
+ }
+ ZEN_ASSERT(BlockIndex < m_BlockFiles.size());
+ BasicFile* BlockFile = m_BlockFiles[BlockIndex].get();
+ ZEN_ASSERT(BlockFile != nullptr);
+
+ Entry.Offset = m_ChunkOffset;
+ m_ChunkOffset = RoundUp(m_ChunkOffset + Entry.Length, 1u << 4u);
+ Lock.ReleaseNow();
+
+ std::error_code Ec;
+ BlockFile->Write(RequestBuffer.Data(), RequestBuffer.Size(), Entry.Offset - BlockIndex * RecordedRequestBlockSize, Ec);
+ if (Ec)
+ {
+ Entry.Length = 0;
+ return ~0ull;
+ }
+ return RequestIndex;
+ }
+ Lock.ReleaseNow();
+
+ BasicFile RequestFile;
+ RequestFile.Open(m_BasePath / fmt::format("request{}.bin", RequestIndex), BasicFile::Mode::kTruncate);
+ std::error_code Ec;
+ RequestFile.WriteAll(RequestBuffer, Ec);
+ if (Ec)
+ {
+ Entry.Length = 0;
+ return ~0ull;
+ }
+ return RequestIndex;
+ }
+
+ std::filesystem::path m_BasePath;
+ mutable RwLock m_Lock;
+ std::vector<RecordedRequest> m_Entries;
+ std::vector<std::unique_ptr<BasicFile>> m_BlockFiles;
+ uint64_t m_ChunkOffset;
+};
+
+struct RecordedRequestsReader
+{
+ uint64_t BeginRead(const std::filesystem::path& BasePath, bool InMemory)
+ {
+ m_BasePath = BasePath;
+ BasicFile IndexFile;
+ IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kRead);
+ m_Entries.resize(IndexFile.FileSize() / sizeof(RecordedRequest));
+ IndexFile.Read(m_Entries.data(), IndexFile.FileSize(), 0);
+ uint64_t MaxChunkPosition = 0;
+ for (const RecordedRequest& R : m_Entries)
+ {
+ if (R.Offset != ~0ull)
+ {
+ MaxChunkPosition = Max(MaxChunkPosition, R.Offset + R.Length);
+ }
+ }
+ uint32_t BlockCount = gsl::narrow<uint32_t>(MaxChunkPosition / RecordedRequestBlockSize) + 1;
+ m_BlockFiles.resize(BlockCount);
+ for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; ++BlockIndex)
+ {
+ if (InMemory)
+ {
+ BasicFile Chunk;
+ Chunk.Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kRead);
+ m_BlockFiles[BlockIndex] = Chunk.ReadAll();
+ continue;
+ }
+ m_BlockFiles[BlockIndex] = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("chunks{}.bin", BlockIndex));
+ }
+ return m_Entries.size();
+ }
+ void EndRead() { m_BlockFiles.clear(); }
+
+ std::pair<ZenContentType, ZenContentType> ReadRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) const
+ {
+ if (RequestIndex >= m_Entries.size())
+ {
+ return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType};
+ }
+ const RecordedRequest& Entry = m_Entries[RequestIndex];
+ if (Entry.Length == 0)
+ {
+ return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType};
+ }
+ if (Entry.Offset != ~0ull)
+ {
+ uint32_t BlockIndex = gsl::narrow<uint32_t>((Entry.Offset + Entry.Length) / RecordedRequestBlockSize);
+ uint64_t ChunkOffset = Entry.Offset - (BlockIndex * RecordedRequestBlockSize);
+ OutBuffer = IoBuffer(m_BlockFiles[BlockIndex], ChunkOffset, Entry.Length);
+ return {Entry.ContentType, Entry.AcceptType};
+ }
+ OutBuffer = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("request{}.bin", RequestIndex));
+ return {Entry.ContentType, Entry.AcceptType};
+ }
+
+ std::filesystem::path m_BasePath;
+ std::vector<RecordedRequest> m_Entries;
+ std::vector<IoBuffer> m_BlockFiles;
+};
+
+class DiskRequestRecorder : public IRpcRequestRecorder
+{
+public:
+ DiskRequestRecorder(const std::filesystem::path& BasePath) { m_RecordedRequests.BeginWrite(BasePath); }
+ virtual ~DiskRequestRecorder() { m_RecordedRequests.EndWrite(); }
+
+private:
+ virtual uint64_t RecordRequest(const ZenContentType ContentType,
+ const ZenContentType AcceptType,
+ const IoBuffer& RequestBuffer) override
+ {
+ return m_RecordedRequests.WriteRequest(ContentType, AcceptType, RequestBuffer);
+ }
+ virtual void RecordResponse(uint64_t, const ZenContentType, const IoBuffer&) override {}
+ virtual void RecordResponse(uint64_t, const ZenContentType, const CompositeBuffer&) override {}
+ RecordedRequestsWriter m_RecordedRequests;
+};
+
+class DiskRequestReplayer : public IRpcRequestReplayer
+{
+public:
+ DiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory)
+ {
+ m_RequestCount = m_RequestBuffer.BeginRead(BasePath, InMemory);
+ }
+ virtual ~DiskRequestReplayer() { m_RequestBuffer.EndRead(); }
+
+private:
+ virtual uint64_t GetRequestCount() const override { return m_RequestCount; }
+
+ virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) override
+ {
+ return m_RequestBuffer.ReadRequest(RequestIndex, OutBuffer);
+ }
+ virtual ZenContentType GetResponse(uint64_t, IoBuffer&) override { return ZenContentType::kUnknownContentType; }
+
+ std::uint64_t m_RequestCount;
+ RecordedRequestsReader m_RequestBuffer;
+};
+
+std::unique_ptr<cache::IRpcRequestRecorder>
+MakeDiskRequestRecorder(const std::filesystem::path& BasePath)
+{
+ return std::make_unique<DiskRequestRecorder>(BasePath);
+}
+
+std::unique_ptr<cache::IRpcRequestReplayer>
+MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory)
+{
+ return std::make_unique<DiskRequestReplayer>(BasePath, InMemory);
+}
+
+} // namespace zen::cache
diff --git a/src/zenutil/include/zenutil/basicfile.h b/src/zenutil/include/zenutil/basicfile.h
new file mode 100644
index 000000000..877df0f92
--- /dev/null
+++ b/src/zenutil/include/zenutil/basicfile.h
@@ -0,0 +1,113 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+
+#include <filesystem>
+#include <functional>
+
+namespace zen {
+
+class CbObject;
+
+/**
+ * Probably the most basic file abstraction in the universe
+ *
+ * One thing of note is that there is no notion of a "current file position"
+ * in this API -- all reads and writes are done from explicit offsets in
+ * the file. This avoids concurrency issues which can occur otherwise.
+ *
+ */
+
+class BasicFile
+{
+public:
+ BasicFile() = default;
+ ~BasicFile();
+
+ BasicFile(const BasicFile&) = delete;
+ BasicFile& operator=(const BasicFile&) = delete;
+
+ enum class Mode : uint32_t
+ {
+ kRead = 0, // Opens a existing file for read only
+ kWrite = 1, // Opens (or creates) a file for read and write
+ kTruncate = 2, // Opens (or creates) a file for read and write and sets the size to zero
+ kDelete = 3, // Opens (or creates) a file for read and write allowing .DeleteFile file disposition to be set
+ kTruncateDelete =
+ 4 // Opens (or creates) a file for read and write and sets the size to zero allowing .DeleteFile file disposition to be set
+ };
+
+ void Open(const std::filesystem::path& FileName, Mode Mode);
+ void Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec);
+ void Close();
+ void Read(void* Data, uint64_t Size, uint64_t FileOffset);
+ void StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun);
+ void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun);
+ void Write(MemoryView Data, uint64_t FileOffset);
+ void Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec);
+ void Write(const void* Data, uint64_t Size, uint64_t FileOffset);
+ void Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec);
+ void Flush();
+ uint64_t FileSize();
+ void SetFileSize(uint64_t FileSize);
+ IoBuffer ReadAll();
+ void WriteAll(IoBuffer Data, std::error_code& Ec);
+ void* Detach();
+
+ inline void* Handle() { return m_FileHandle; }
+
+protected:
+ void* m_FileHandle = nullptr; // This is either null or valid
+private:
+};
+
+/**
+ * Simple abstraction for a temporary file
+ *
+ * Works like a regular BasicFile but implements a simple mechanism to allow creating
+ * a temporary file for writing in a directory which may later be moved atomically
+ * into the intended location after it has been fully written to.
+ *
+ */
+
+class TemporaryFile : public BasicFile
+{
+public:
+ TemporaryFile() = default;
+ ~TemporaryFile();
+
+ TemporaryFile(const TemporaryFile&) = delete;
+ TemporaryFile& operator=(const TemporaryFile&) = delete;
+
+ void Close();
+ void CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec);
+ void MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec);
+ const std::filesystem::path& GetPath() const { return m_TempPath; }
+
+private:
+ std::filesystem::path m_TempPath;
+
+ using BasicFile::Open;
+};
+
+/** Lock file abstraction
+
+ */
+
+class LockFile : protected BasicFile
+{
+public:
+ LockFile();
+ ~LockFile();
+
+ void Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec);
+ void Update(CbObject Payload, std::error_code& Ec);
+
+private:
+};
+
+ZENCORE_API void basicfile_forcelink();
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/cache.h b/src/zenutil/include/zenutil/cache/cache.h
new file mode 100644
index 000000000..1a1dd9386
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cache.h
@@ -0,0 +1,6 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenutil/cache/cachekey.h>
+#include <zenutil/cache/cachepolicy.h>
diff --git a/src/zenutil/include/zenutil/cache/cachekey.h b/src/zenutil/include/zenutil/cache/cachekey.h
new file mode 100644
index 000000000..741375946
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cachekey.h
@@ -0,0 +1,86 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+
+#include <zenutil/cache/cachepolicy.h>
+
+namespace zen {
+
+struct CacheKey
+{
+ std::string Bucket;
+ IoHash Hash;
+
+ static CacheKey Create(std::string_view Bucket, const IoHash& Hash) { return {.Bucket = ToLower(Bucket), .Hash = Hash}; }
+
+ auto operator<=>(const CacheKey& that) const
+ {
+ if (auto b = caseSensitiveCompareStrings(Bucket, that.Bucket); b != std::strong_ordering::equal)
+ {
+ return b;
+ }
+ return Hash <=> that.Hash;
+ }
+
+ auto operator==(const CacheKey& that) const { return (*this <=> that) == std::strong_ordering::equal; }
+
+ static const CacheKey Empty;
+};
+
+struct CacheChunkRequest
+{
+ CacheKey Key;
+ IoHash ChunkId;
+ Oid ValueId;
+ uint64_t RawOffset = 0ull;
+ uint64_t RawSize = ~uint64_t(0);
+ CachePolicy Policy = CachePolicy::Default;
+};
+
+struct CacheKeyRequest
+{
+ CacheKey Key;
+ CacheRecordPolicy Policy;
+};
+
+struct CacheValueRequest
+{
+ CacheKey Key;
+ CachePolicy Policy = CachePolicy::Default;
+};
+
+inline bool
+operator<(const CacheChunkRequest& A, const CacheChunkRequest& B)
+{
+ if (A.Key < B.Key)
+ {
+ return true;
+ }
+ if (B.Key < A.Key)
+ {
+ return false;
+ }
+ if (A.ChunkId < B.ChunkId)
+ {
+ return true;
+ }
+ if (B.ChunkId < A.ChunkId)
+ {
+ return false;
+ }
+ if (A.ValueId < B.ValueId)
+ {
+ return true;
+ }
+ if (B.ValueId < A.ValueId)
+ {
+ return false;
+ }
+ return A.RawOffset < B.RawOffset;
+}
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/cachepolicy.h b/src/zenutil/include/zenutil/cache/cachepolicy.h
new file mode 100644
index 000000000..9a745e42c
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cachepolicy.h
@@ -0,0 +1,227 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/enumflags.h>
+#include <zencore/refcount.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+
+#include <gsl/gsl-lite.hpp>
+#include <span>
+namespace zen::Private {
+class ICacheRecordPolicyShared;
+}
+namespace zen {
+
+class CbObjectView;
+class CbWriter;
+
+class OptionalCacheRecordPolicy;
+
+enum class CachePolicy : uint32_t
+{
+ /** A value with no flags. Disables access to the cache unless combined with other flags. */
+ None = 0,
+
+ /** Allow a cache request to query local caches. */
+ QueryLocal = 1 << 0,
+ /** Allow a cache request to query remote caches. */
+ QueryRemote = 1 << 1,
+ /** Allow a cache request to query any caches. */
+ Query = QueryLocal | QueryRemote,
+
+ /** Allow cache requests to query and store records and values in local caches. */
+ StoreLocal = 1 << 2,
+ /** Allow cache records and values to be stored in remote caches. */
+ StoreRemote = 1 << 3,
+ /** Allow cache records and values to be stored in any caches. */
+ Store = StoreLocal | StoreRemote,
+
+ /** Allow cache requests to query and store records and values in local caches. */
+ Local = QueryLocal | StoreLocal,
+ /** Allow cache requests to query and store records and values in remote caches. */
+ Remote = QueryRemote | StoreRemote,
+
+ /** Allow cache requests to query and store records and values in any caches. */
+ Default = Query | Store,
+
+ /** Skip fetching the data for values. */
+ SkipData = 1 << 4,
+
+ /** Skip fetching the metadata for record requests. */
+ SkipMeta = 1 << 5,
+
+ /**
+ * Partial output will be provided with the error status when a required value is missing.
+ *
+ * This is meant for cases when the missing values can be individually recovered, or rebuilt,
+ * without rebuilding the whole record. The cache automatically adds this flag when there are
+ * other cache stores that it may be able to recover missing values from.
+ *
+ * Missing values will be returned in the records, but with only the hash and size.
+ *
+ * Applying this flag for a put of a record allows a partial record to be stored.
+ */
+ PartialRecord = 1 << 6,
+
+ /**
+ * Keep records in the cache for at least the duration of the session.
+ *
+ * This is a hint that the record may be accessed again in this session. This is mainly meant
+ * to be used when subsequent accesses will not tolerate a cache miss.
+ */
+ KeepAlive = 1 << 7,
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy);
+/** Append a non-empty text version of the policy to the builder. */
+StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy);
+/** Parse non-empty text written by operator<< into a policy. */
+CachePolicy ParseCachePolicy(std::string_view Text);
+/** Return input converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */
+CachePolicy ConvertToUpstream(CachePolicy Policy);
+
+inline CachePolicy
+Union(CachePolicy A, CachePolicy B)
+{
+ constexpr CachePolicy InvertedFlags = CachePolicy::SkipData | CachePolicy::SkipMeta;
+ return (A & ~(InvertedFlags)) | (B & ~(InvertedFlags)) | (A & B & InvertedFlags);
+}
+
+/** A value ID and the cache policy to use for that value. */
+struct CacheValuePolicy
+{
+ Oid Id;
+ CachePolicy Policy = CachePolicy::Default;
+
+ /** Flags that are valid on a value policy. */
+ static constexpr CachePolicy PolicyMask = CachePolicy::Default | CachePolicy::SkipData;
+};
+
+/** Interface for the private implementation of the cache record policy. */
+class Private::ICacheRecordPolicyShared : public RefCounted
+{
+public:
+ virtual ~ICacheRecordPolicyShared() = default;
+ virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0;
+ virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0;
+};
+
+/**
+ * Flags to control the behavior of cache record requests, with optional overrides by value.
+ *
+ * Examples:
+ * - A base policy of None with value policy overrides of Default will fetch those values if they
+ * exist in the record, and skip data for any other values.
+ * - A base policy of Default, with value policy overrides of (Query | SkipData), will skip those
+ * values, but still check if they exist, and will load any other values.
+ */
+class CacheRecordPolicy
+{
+public:
+ /** Construct a cache record policy that uses the default policy. */
+ CacheRecordPolicy() = default;
+
+ /** Construct a cache record policy with a uniform policy for the record and every value. */
+ inline CacheRecordPolicy(CachePolicy BasePolicy)
+ : RecordPolicy(BasePolicy)
+ , DefaultValuePolicy(BasePolicy & CacheValuePolicy::PolicyMask)
+ {
+ }
+
+ /** Returns true if the record and every value use the same cache policy. */
+ inline bool IsUniform() const { return !Shared; }
+
+ /** Returns the cache policy to use for the record. */
+ inline CachePolicy GetRecordPolicy() const { return RecordPolicy; }
+
+ /** Returns the base cache policy that this was constructed from. */
+ inline CachePolicy GetBasePolicy() const { return DefaultValuePolicy | (RecordPolicy & ~CacheValuePolicy::PolicyMask); }
+
+ /** Returns the cache policy to use for the value. */
+ CachePolicy GetValuePolicy(const Oid& Id) const;
+
+ /** Returns the array of cache policy overrides for values, sorted by ID. */
+ inline std::span<const CacheValuePolicy> GetValuePolicies() const
+ {
+ return Shared ? Shared->GetValuePolicies() : std::span<const CacheValuePolicy>();
+ }
+
+ /** Saves the cache record policy to a compact binary object. */
+ void Save(CbWriter& Writer) const;
+
+ /** Loads a cache record policy from an object. */
+ static OptionalCacheRecordPolicy Load(CbObjectView Object);
+
+ /** Return *this converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server.
+ */
+ CacheRecordPolicy ConvertToUpstream() const;
+
+private:
+ friend class CacheRecordPolicyBuilder;
+ friend class OptionalCacheRecordPolicy;
+
+ CachePolicy RecordPolicy = CachePolicy::Default;
+ CachePolicy DefaultValuePolicy = CachePolicy::Default;
+ RefPtr<const Private::ICacheRecordPolicyShared> Shared;
+};
+
+/** A cache record policy builder is used to construct a cache record policy. */
+class CacheRecordPolicyBuilder
+{
+public:
+ /** Construct a policy builder that uses the default policy as its base policy. */
+ CacheRecordPolicyBuilder() = default;
+
+ /** Construct a policy builder that uses the provided policy for the record and values with no override. */
+ inline explicit CacheRecordPolicyBuilder(CachePolicy Policy) : BasePolicy(Policy) {}
+
+ /** Adds a cache policy override for a value. */
+ void AddValuePolicy(const CacheValuePolicy& Value);
+ inline void AddValuePolicy(const Oid& Id, CachePolicy Policy) { AddValuePolicy({Id, Policy}); }
+
+ /** Build a cache record policy, which makes this builder subsequently unusable. */
+ CacheRecordPolicy Build();
+
+private:
+ CachePolicy BasePolicy = CachePolicy::Default;
+ RefPtr<Private::ICacheRecordPolicyShared> Shared;
+};
+
+/**
+ * A cache record policy that can be null.
+ *
+ * @see CacheRecordPolicy
+ */
+class OptionalCacheRecordPolicy : private CacheRecordPolicy
+{
+public:
+ inline OptionalCacheRecordPolicy() : CacheRecordPolicy(~CachePolicy::None) {}
+
+ inline OptionalCacheRecordPolicy(CacheRecordPolicy&& InOutput) : CacheRecordPolicy(std::move(InOutput)) {}
+ inline OptionalCacheRecordPolicy(const CacheRecordPolicy& InOutput) : CacheRecordPolicy(InOutput) {}
+ inline OptionalCacheRecordPolicy& operator=(CacheRecordPolicy&& InOutput)
+ {
+ CacheRecordPolicy::operator=(std::move(InOutput));
+ return *this;
+ }
+ inline OptionalCacheRecordPolicy& operator=(const CacheRecordPolicy& InOutput)
+ {
+ CacheRecordPolicy::operator=(InOutput);
+ return *this;
+ }
+
+ /** Returns the cache record policy. The caller must check for null before using this accessor. */
+ inline const CacheRecordPolicy& Get() const& { return *this; }
+ inline CacheRecordPolicy Get() && { return std::move(*this); }
+
+ inline bool IsNull() const { return RecordPolicy == ~CachePolicy::None; }
+ inline bool IsValid() const { return !IsNull(); }
+ inline explicit operator bool() const { return !IsNull(); }
+
+ inline void Reset() { *this = OptionalCacheRecordPolicy(); }
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/cacherequests.h b/src/zenutil/include/zenutil/cache/cacherequests.h
new file mode 100644
index 000000000..f1999ebfe
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cacherequests.h
@@ -0,0 +1,279 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compress.h>
+
+#include "cachekey.h"
+#include "cachepolicy.h"
+
+#include <functional>
+
+namespace zen {
+
+class CbPackage;
+class CbObjectWriter;
+class CbObjectView;
+
+namespace cacherequests {
+ // I'd really like to get rid of std::optional<CacheRecordPolicy> (or really the class CacheRecordPolicy)
+ //
+ // CacheRecordPolicy has a record level policy but it can also contain policies for individual
+ // values inside the record.
+ //
+ // However, when we do a "PutCacheRecords" we already list the individual Values with their Id
+ // so we can just as well use an optional plain CachePolicy for each value.
+ //
+ // In "GetCacheRecords" we do not currently as for the individual values but you can add
+ // a policy on a per-value level in the std::optional<CacheRecordPolicy> Policy for each record.
+ //
+ // But as we already need to know the Ids of the values we want to set the policy for
+ // it would be simpler to add an array of requested values which each has an optional policy.
+ //
+ // We could add:
+ // struct GetCacheRecordValueRequest
+ // {
+ // Oid Id;
+ // std::optional<CachePolicy> Policy;
+ // };
+ //
+ // and change GetCacheRecordRequest to
+ // struct GetCacheRecordRequest
+ // {
+ // CacheKey Key = CacheKey::Empty;
+ // std::vector<GetCacheRecordValueRequest> ValueRequests;
+ // std::optional<CachePolicy> Policy;
+ // };
+ //
+ // This way we don't need the complex CacheRecordPolicy class and the request becomes
+ // more uniform and easier to understand.
+ //
+ // Would need to decide what the ValueRequests actually mean:
+ // Do they dictate which values to fetch or just a change of the policy?
+ // If they dictate the values to fetch you need to know all the value ids to set them
+ // and that is unlikely what we want - we want to be able to get a cache record with
+ // all its values without knowing all the Ids, right?
+ //
+
+ //////////////////////////////////////////////////////////////////////////
+ // Put 1..n structured cache records with optional attachments
+
+ struct PutCacheRecordRequestValue
+ {
+ Oid Id = Oid::Zero;
+ IoHash RawHash = IoHash::Zero; // If Body is not set, this must be set and the value must already exist in cache
+ CompressedBuffer Body = CompressedBuffer::Null;
+ };
+
+ struct PutCacheRecordRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::vector<PutCacheRecordRequestValue> Values;
+ std::optional<CacheRecordPolicy> Policy;
+ };
+
+ struct PutCacheRecordsRequest
+ {
+ uint32_t AcceptMagic = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<PutCacheRecordRequest> Requests;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ struct PutCacheRecordsResult
+ {
+ std::vector<bool> Success;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Get 1..n structured cache records with optional attachments
+ // We can get requests for a cache record where we want care about a particular
+ // value id which we now of, but we don't know the ids of the other values and
+ // we still want them.
+ // Not sure if in that case we want different policies for the different attachemnts?
+
+ struct GetCacheRecordRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::optional<CacheRecordPolicy> Policy;
+ };
+
+ struct GetCacheRecordsRequest
+ {
+ uint32_t AcceptMagic = 0;
+ uint16_t AcceptOptions = 0;
+ int32_t ProcessPid = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<GetCacheRecordRequest> Requests;
+
+ bool Parse(const CbPackage& RpcRequest);
+ bool Parse(const CbObjectView& RpcRequest);
+ bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter = {}) const;
+ bool Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter = {}) const;
+ };
+
+ struct GetCacheRecordResultValue
+ {
+ Oid Id = Oid::Zero;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+ CompressedBuffer Body = CompressedBuffer::Null;
+ };
+
+ struct GetCacheRecordResult
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::vector<GetCacheRecordResultValue> Values;
+ };
+
+ struct GetCacheRecordsResult
+ {
+ std::vector<std::optional<GetCacheRecordResult>> Results;
+
+ bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes = {});
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Put 1..n unstructured cache objects
+
+ struct PutCacheValueRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ IoHash RawHash = IoHash::Zero;
+ CompressedBuffer Body = CompressedBuffer::Null; // If not set the value is expected to already exist in cache store
+ std::optional<CachePolicy> Policy;
+ };
+
+ struct PutCacheValuesRequest
+ {
+ uint32_t AcceptMagic = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<PutCacheValueRequest> Requests;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ struct PutCacheValuesResult
+ {
+ std::vector<bool> Success;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Get 1..n unstructured cache objects (stored data may be structured or unstructured)
+
+ struct GetCacheValueRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::optional<CachePolicy> Policy;
+ };
+
+ struct GetCacheValuesRequest
+ {
+ uint32_t AcceptMagic = 0;
+ uint16_t AcceptOptions = 0;
+ int32_t ProcessPid = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<GetCacheValueRequest> Requests;
+
+ bool Parse(const CbObjectView& BatchObject);
+ bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter = {}) const;
+ };
+
+ struct CacheValueResult
+ {
+ uint64_t RawSize = 0;
+ IoHash RawHash = IoHash::Zero;
+ CompressedBuffer Body = CompressedBuffer::Null;
+ };
+
+ struct CacheValuesResult
+ {
+ std::vector<CacheValueResult> Results;
+
+ bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes = {});
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ typedef CacheValuesResult GetCacheValuesResult;
+
+ //////////////////////////////////////////////////////////////////////////
+ // Get 1..n cache record values (attachments) for 1..n records
+
+ struct GetCacheChunkRequest
+ {
+ CacheKey Key;
+ Oid ValueId = Oid::Zero; // Set if ChunkId is not known at request time
+ IoHash ChunkId = IoHash::Zero;
+ uint64_t RawOffset = 0ull;
+ uint64_t RawSize = ~uint64_t(0);
+ std::optional<CachePolicy> Policy;
+ };
+
+ struct GetCacheChunksRequest
+ {
+ uint32_t AcceptMagic = 0;
+ uint16_t AcceptOptions = 0;
+ int32_t ProcessPid = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<GetCacheChunkRequest> Requests;
+
+ bool Parse(const CbObjectView& BatchObject);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ typedef CacheValuesResult GetCacheChunksResult;
+
+ //////////////////////////////////////////////////////////////////////////
+
+ struct HttpRequestData
+ {
+ std::optional<std::string> Namespace;
+ std::optional<std::string> Bucket;
+ std::optional<IoHash> HashKey;
+ std::optional<IoHash> ValueContentId;
+ };
+
+ bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data);
+
+ // Temporarily public
+ std::optional<std::string> GetRequestNamespace(const CbObjectView& Params);
+ bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key);
+
+ //////////////////////////////////////////////////////////////////////////
+
+ // struct CacheRecordValue
+ // {
+ // Oid Id = Oid::Zero;
+ // IoHash RawHash = IoHash::Zero;
+ // uint64_t RawSize = 0;
+ // };
+ //
+ // struct CacheRecord
+ // {
+ // CacheKey Key = CacheKey::Empty;
+ // std::vector<CacheRecordValue> Values;
+ //
+ // bool Parse(CbObjectView& Reader);
+ // bool Format(CbObjectWriter& Writer) const;
+ // };
+
+} // namespace cacherequests
+
+void cacherequests_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/rpcrecording.h b/src/zenutil/include/zenutil/cache/rpcrecording.h
new file mode 100644
index 000000000..6d65a532a
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/rpcrecording.h
@@ -0,0 +1,29 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compositebuffer.h>
+#include <zencore/iobuffer.h>
+
+namespace zen::cache {
+class IRpcRequestRecorder
+{
+public:
+ virtual ~IRpcRequestRecorder() {}
+ virtual uint64_t RecordRequest(const ZenContentType ContentType, const ZenContentType AcceptType, const IoBuffer& RequestBuffer) = 0;
+ virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const IoBuffer& ResponseBuffer) = 0;
+ virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const CompositeBuffer& ResponseBuffer) = 0;
+};
+class IRpcRequestReplayer
+{
+public:
+ virtual ~IRpcRequestReplayer() {}
+ virtual uint64_t GetRequestCount() const = 0;
+ virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0;
+ virtual ZenContentType GetResponse(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0;
+};
+
+std::unique_ptr<cache::IRpcRequestRecorder> MakeDiskRequestRecorder(const std::filesystem::path& BasePath);
+std::unique_ptr<cache::IRpcRequestReplayer> MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory);
+
+} // namespace zen::cache
diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h
new file mode 100644
index 000000000..1c204c144
--- /dev/null
+++ b/src/zenutil/include/zenutil/zenserverprocess.h
@@ -0,0 +1,141 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/enumflags.h>
+#include <zencore/logging.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+
+#include <atomic>
+#include <filesystem>
+#include <optional>
+
+namespace zen {
+
+class ZenServerEnvironment
+{
+public:
+ ZenServerEnvironment();
+ ~ZenServerEnvironment();
+
+ void Initialize(std::filesystem::path ProgramBaseDir);
+ void InitializeForTest(std::filesystem::path ProgramBaseDir, std::filesystem::path TestBaseDir, std::string_view ServerClass = "");
+
+ std::filesystem::path CreateNewTestDir();
+ std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; }
+ std::filesystem::path GetTestRootDir(std::string_view Path);
+ inline bool IsInitialized() const { return m_IsInitialized; }
+ inline bool IsTestEnvironment() const { return m_IsTestInstance; }
+ inline std::string_view GetServerClass() const { return m_ServerClass; }
+
+private:
+ std::filesystem::path m_ProgramBaseDir;
+ std::filesystem::path m_TestBaseDir;
+ bool m_IsInitialized = false;
+ bool m_IsTestInstance = false;
+ std::string m_ServerClass;
+};
+
+struct ZenServerInstance
+{
+ ZenServerInstance(ZenServerEnvironment& TestEnvironment);
+ ~ZenServerInstance();
+
+ void Shutdown();
+ void SignalShutdown();
+ void WaitUntilReady();
+ [[nodiscard]] bool WaitUntilReady(int Timeout);
+ void EnableTermination() { m_Terminate = true; }
+ void Detach();
+ inline int GetPid() { return m_Process.Pid(); }
+ inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; }
+
+ void SetTestDir(std::filesystem::path TestDir)
+ {
+ ZEN_ASSERT(!m_Process.IsValid());
+ m_TestDir = TestDir;
+ }
+
+ void SpawnServer(int BasePort = 0, std::string_view AdditionalServerArgs = std::string_view());
+
+ void AttachToRunningServer(int BasePort = 0);
+
+ std::string GetBaseUri() const;
+
+private:
+ ZenServerEnvironment& m_Env;
+ ProcessHandle m_Process;
+ NamedEvent m_ReadyEvent;
+ NamedEvent m_ShutdownEvent;
+ bool m_Terminate = false;
+ std::filesystem::path m_TestDir;
+ int m_BasePort = 0;
+ std::optional<int> m_OwnerPid;
+
+ void CreateShutdownEvent(int BasePort);
+};
+
+/** Shared system state
+ *
+ * Used as a scratchpad to identify running instances etc
+ *
+ * The state lives in a memory-mapped file backed by the swapfile
+ *
+ */
+
+class ZenServerState
+{
+public:
+ ZenServerState();
+ ~ZenServerState();
+
+ struct ZenServerEntry
+ {
+ // NOTE: any changes to this should consider backwards compatibility
+ // which means you should not rearrange members only potentially
+ // add something to the end or use a different mechanism for
+ // additional state. For example, you can use the session ID
+ // to introduce additional named objects
+ std::atomic<uint32_t> Pid;
+ std::atomic<uint16_t> DesiredListenPort;
+ std::atomic<uint16_t> Flags;
+ uint8_t SessionId[12];
+ std::atomic<uint32_t> SponsorPids[8];
+ std::atomic<uint16_t> EffectiveListenPort;
+ uint8_t Padding[10];
+
+ enum class FlagsEnum : uint16_t
+ {
+ kShutdownPlease = 1 << 0,
+ kIsReady = 1 << 1,
+ };
+
+ FRIEND_ENUM_CLASS_FLAGS(FlagsEnum);
+
+ Oid GetSessionId() const { return Oid::FromMemory(SessionId); }
+ void Reset();
+ void SignalShutdownRequest();
+ void SignalReady();
+ bool AddSponsorProcess(uint32_t Pid);
+ };
+
+ static_assert(sizeof(ZenServerEntry) == 64);
+
+ void Initialize();
+ [[nodiscard]] bool InitializeReadOnly();
+ [[nodiscard]] ZenServerEntry* Lookup(int DesiredListenPort);
+ ZenServerEntry* Register(int DesiredListenPort);
+ void Sweep();
+ void Snapshot(std::function<void(const ZenServerEntry&)>&& Callback);
+ inline bool IsReadOnly() const { return m_IsReadOnly; }
+
+private:
+ void* m_hMapFile = nullptr;
+ ZenServerEntry* m_Data = nullptr;
+ int m_MaxEntryCount = 65536 / sizeof(ZenServerEntry);
+ ZenServerEntry* m_OurEntry = nullptr;
+ bool m_IsReadOnly = true;
+};
+
+} // namespace zen
diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua
new file mode 100644
index 000000000..e7d849bb2
--- /dev/null
+++ b/src/zenutil/xmake.lua
@@ -0,0 +1,9 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zenutil')
+ set_kind("static")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_includedirs("include", {public=true})
+ add_deps("zencore")
+ add_packages("vcpkg::spdlog")
diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp
new file mode 100644
index 000000000..5ecde343b
--- /dev/null
+++ b/src/zenutil/zenserverprocess.cpp
@@ -0,0 +1,677 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenutil/zenserverprocess.h"
+
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+
+#include <atomic>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <atlbase.h>
+# include <zencore/windows.h>
+#else
+# include <sys/mman.h>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+namespace zenutil {
+#if ZEN_PLATFORM_WINDOWS
+ class SecurityAttributes
+ {
+ public:
+ inline SECURITY_ATTRIBUTES* Attributes() { return &m_Attributes; }
+
+ protected:
+ SECURITY_ATTRIBUTES m_Attributes{};
+ SECURITY_DESCRIPTOR m_Sd{};
+ };
+
+ // Security attributes which allows any user access
+
+ class AnyUserSecurityAttributes : public SecurityAttributes
+ {
+ public:
+ AnyUserSecurityAttributes()
+ {
+ m_Attributes.nLength = sizeof m_Attributes;
+ m_Attributes.bInheritHandle = false; // Disable inheritance
+
+ const BOOL Success = InitializeSecurityDescriptor(&m_Sd, SECURITY_DESCRIPTOR_REVISION);
+
+ if (Success)
+ {
+ if (!SetSecurityDescriptorDacl(&m_Sd, TRUE, (PACL)NULL, FALSE))
+ {
+ ThrowLastError("SetSecurityDescriptorDacl failed");
+ }
+
+ m_Attributes.lpSecurityDescriptor = &m_Sd;
+ }
+ }
+ };
+#endif // ZEN_PLATFORM_WINDOWS
+
+} // namespace zenutil
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenServerState::ZenServerState()
+{
+}
+
+ZenServerState::~ZenServerState()
+{
+ if (m_OurEntry)
+ {
+ // Clean up our entry now that we're leaving
+
+ m_OurEntry->Reset();
+ m_OurEntry = nullptr;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (m_Data)
+ {
+ UnmapViewOfFile(m_Data);
+ }
+
+ if (m_hMapFile)
+ {
+ CloseHandle(m_hMapFile);
+ }
+#else
+ if (m_Data != nullptr)
+ {
+ munmap(m_Data, m_MaxEntryCount * sizeof(ZenServerEntry));
+ }
+
+ int Fd = int(intptr_t(m_hMapFile));
+ close(Fd);
+#endif
+
+ m_Data = nullptr;
+}
+
+void
+ZenServerState::Initialize()
+{
+ size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry);
+
+#if ZEN_PLATFORM_WINDOWS
+ // TODO: there's a small chance of a race here, this logic could be tightened up with a mutex to
+ // ensure only a single process at a time creates the mapping
+ // TODO: the fallback to Local instead of Global has a flaw where if you start a non-elevated instance
+ // first then start an elevated instance second you'll have the first instance with a local
+ // mapping and the second instance with a global mapping. This kind of elevated/non-elevated
+ // shouldn't be common, but handling for it should be improved in the future.
+
+ HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap");
+ if (hMap == NULL)
+ {
+ hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap");
+ }
+
+ if (hMap == NULL)
+ {
+ // Security attributes to enable any user to access state
+ zenutil::AnyUserSecurityAttributes Attrs;
+
+ hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file
+ Attrs.Attributes(), // allow anyone to access
+ PAGE_READWRITE, // read/write access
+ 0, // maximum object size (high-order DWORD)
+ DWORD(MapSize), // maximum object size (low-order DWORD)
+ L"Global\\ZenMap"); // name of mapping object
+
+ if (hMap == NULL)
+ {
+ hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file
+ Attrs.Attributes(), // allow anyone to access
+ PAGE_READWRITE, // read/write access
+ 0, // maximum object size (high-order DWORD)
+ m_MaxEntryCount * sizeof(ZenServerEntry), // maximum object size (low-order DWORD)
+ L"Local\\ZenMap"); // name of mapping object
+ }
+
+ if (hMap == NULL)
+ {
+ ThrowLastError("Could not open or create file mapping object for Zen server state");
+ }
+ }
+
+ void* pBuf = MapViewOfFile(hMap, // handle to map object
+ FILE_MAP_ALL_ACCESS, // read/write permission
+ 0, // offset high
+ 0, // offset low
+ DWORD(MapSize));
+
+ if (pBuf == NULL)
+ {
+ ThrowLastError("Could not map view of Zen server state");
+ }
+#else
+ int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ ThrowLastError("Could not open a shared memory object");
+ }
+ fchmod(Fd, 0666);
+ void* hMap = (void*)intptr_t(Fd);
+
+ int Result = ftruncate(Fd, MapSize);
+ ZEN_UNUSED(Result);
+
+ void* pBuf = mmap(nullptr, MapSize, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0);
+ if (pBuf == MAP_FAILED)
+ {
+ ThrowLastError("Could not map view of Zen server state");
+ }
+#endif
+
+ m_hMapFile = hMap;
+ m_Data = reinterpret_cast<ZenServerEntry*>(pBuf);
+ m_IsReadOnly = false;
+}
+
+bool
+ZenServerState::InitializeReadOnly()
+{
+ size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry);
+
+#if ZEN_PLATFORM_WINDOWS
+ HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap");
+ if (hMap == NULL)
+ {
+ hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap");
+ }
+
+ if (hMap == NULL)
+ {
+ return false;
+ }
+
+ void* pBuf = MapViewOfFile(hMap, // handle to map object
+ FILE_MAP_READ, // read permission
+ 0, // offset high
+ 0, // offset low
+ MapSize);
+
+ if (pBuf == NULL)
+ {
+ ThrowLastError("Could not map view of Zen server state");
+ }
+#else
+ int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ return false;
+ }
+ void* hMap = (void*)intptr_t(Fd);
+
+ void* pBuf = mmap(nullptr, MapSize, PROT_READ, MAP_PRIVATE, Fd, 0);
+ if (pBuf == MAP_FAILED)
+ {
+ ThrowLastError("Could not map read-only view of Zen server state");
+ }
+#endif
+
+ m_hMapFile = hMap;
+ m_Data = reinterpret_cast<ZenServerEntry*>(pBuf);
+
+ return true;
+}
+
+ZenServerState::ZenServerEntry*
+ZenServerState::Lookup(int DesiredListenPort)
+{
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ if (m_Data[i].DesiredListenPort == DesiredListenPort)
+ {
+ return &m_Data[i];
+ }
+ }
+
+ return nullptr;
+}
+
+ZenServerState::ZenServerEntry*
+ZenServerState::Register(int DesiredListenPort)
+{
+ if (m_Data == nullptr)
+ {
+ return nullptr;
+ }
+
+ // Allocate an entry
+
+ int Pid = GetCurrentProcessId();
+
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ ZenServerEntry& Entry = m_Data[i];
+
+ if (Entry.DesiredListenPort.load(std::memory_order_relaxed) == 0)
+ {
+ uint16_t Expected = 0;
+ if (Entry.DesiredListenPort.compare_exchange_strong(Expected, uint16_t(DesiredListenPort)))
+ {
+ // Successfully allocated entry
+
+ m_OurEntry = &Entry;
+
+ Entry.Pid = Pid;
+ Entry.EffectiveListenPort = 0;
+ Entry.Flags = 0;
+
+ const Oid SesId = GetSessionId();
+ memcpy(Entry.SessionId, &SesId, sizeof SesId);
+
+ return &Entry;
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+void
+ZenServerState::Sweep()
+{
+ if (m_Data == nullptr)
+ {
+ return;
+ }
+
+ ZEN_ASSERT(m_IsReadOnly == false);
+
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ ZenServerEntry& Entry = m_Data[i];
+
+ if (Entry.DesiredListenPort)
+ {
+ if (IsProcessRunning(Entry.Pid) == false)
+ {
+ ZEN_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", Entry.Pid, Entry.DesiredListenPort);
+
+ Entry.Reset();
+ }
+ }
+ }
+}
+
+void
+ZenServerState::Snapshot(std::function<void(const ZenServerEntry&)>&& Callback)
+{
+ if (m_Data == nullptr)
+ {
+ return;
+ }
+
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ ZenServerEntry& Entry = m_Data[i];
+
+ if (Entry.DesiredListenPort)
+ {
+ Callback(Entry);
+ }
+ }
+}
+
+void
+ZenServerState::ZenServerEntry::Reset()
+{
+ Pid = 0;
+ DesiredListenPort = 0;
+ Flags = 0;
+ EffectiveListenPort = 0;
+}
+
+void
+ZenServerState::ZenServerEntry::SignalShutdownRequest()
+{
+ Flags |= uint16_t(FlagsEnum::kShutdownPlease);
+}
+
+void
+ZenServerState::ZenServerEntry::SignalReady()
+{
+ Flags |= uint16_t(FlagsEnum::kIsReady);
+}
+
+bool
+ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd)
+{
+ for (std::atomic<uint32_t>& PidEntry : SponsorPids)
+ {
+ if (PidEntry.load(std::memory_order_relaxed) == 0)
+ {
+ uint32_t Expected = 0;
+ if (PidEntry.compare_exchange_strong(Expected, PidToAdd))
+ {
+ // Success!
+ return true;
+ }
+ }
+ else if (PidEntry.load(std::memory_order_relaxed) == PidToAdd)
+ {
+ // Success, the because pid is already in the list
+ return true;
+ }
+ }
+
+ return false;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+std::atomic<int> ZenServerTestCounter{0};
+
+ZenServerEnvironment::ZenServerEnvironment()
+{
+}
+
+ZenServerEnvironment::~ZenServerEnvironment()
+{
+}
+
+void
+ZenServerEnvironment::Initialize(std::filesystem::path ProgramBaseDir)
+{
+ m_ProgramBaseDir = ProgramBaseDir;
+
+ ZEN_DEBUG("Program base dir is '{}'", ProgramBaseDir);
+
+ m_IsInitialized = true;
+}
+
+void
+ZenServerEnvironment::InitializeForTest(std::filesystem::path ProgramBaseDir,
+ std::filesystem::path TestBaseDir,
+ std::string_view ServerClass)
+{
+ using namespace std::literals;
+
+ m_ProgramBaseDir = ProgramBaseDir;
+ m_TestBaseDir = TestBaseDir;
+
+ ZEN_INFO("Program base dir is '{}'", ProgramBaseDir);
+ ZEN_INFO("Cleaning test base dir '{}'", TestBaseDir);
+ DeleteDirectories(TestBaseDir.c_str());
+
+ m_IsTestInstance = true;
+ m_IsInitialized = true;
+
+ if (ServerClass.empty())
+ {
+#if ZEN_WITH_HTTPSYS
+ m_ServerClass = "httpsys"sv;
+#else
+ m_ServerClass = "asio"sv;
+#endif
+ }
+ else
+ {
+ m_ServerClass = ServerClass;
+ }
+}
+
+std::filesystem::path
+ZenServerEnvironment::CreateNewTestDir()
+{
+ using namespace std::literals;
+
+ ExtendableWideStringBuilder<256> TestDir;
+ TestDir << "test"sv << int64_t(++ZenServerTestCounter);
+
+ std::filesystem::path TestPath = m_TestBaseDir / TestDir.c_str();
+
+ ZEN_INFO("Creating new test dir @ '{}'", TestPath);
+
+ CreateDirectories(TestPath.c_str());
+
+ return TestPath;
+}
+
+std::filesystem::path
+ZenServerEnvironment::GetTestRootDir(std::string_view Path)
+{
+ std::filesystem::path Root = m_ProgramBaseDir.parent_path().parent_path();
+
+ std::filesystem::path Relative{Path};
+
+ return Root / Relative;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+std::atomic<int> ChildIdCounter{0};
+
+ZenServerInstance::ZenServerInstance(ZenServerEnvironment& TestEnvironment) : m_Env(TestEnvironment)
+{
+ ZEN_ASSERT(TestEnvironment.IsInitialized());
+}
+
+ZenServerInstance::~ZenServerInstance()
+{
+ Shutdown();
+}
+
+void
+ZenServerInstance::SignalShutdown()
+{
+ m_ShutdownEvent.Set();
+}
+
+void
+ZenServerInstance::Shutdown()
+{
+ if (m_Process.IsValid())
+ {
+ if (m_Terminate)
+ {
+ ZEN_INFO("Terminating zenserver process");
+ m_Process.Terminate(111);
+ m_Process.Reset();
+ }
+ else
+ {
+ SignalShutdown();
+ m_Process.Wait();
+ m_Process.Reset();
+ }
+ }
+}
+
+void
+ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerArgs)
+{
+ ZEN_ASSERT(!m_Process.IsValid()); // Only spawn once
+
+ const int MyPid = zen::GetCurrentProcessId();
+ const int ChildId = ++ChildIdCounter;
+
+ ExtendableStringBuilder<32> ChildEventName;
+ ChildEventName << "Zen_Child_" << ChildId;
+ NamedEvent ChildEvent{ChildEventName};
+
+ CreateShutdownEvent(BasePort);
+
+ ExtendableStringBuilder<32> LogId;
+ LogId << "Zen" << ChildId;
+
+ ExtendableStringBuilder<512> CommandLine;
+ CommandLine << "zenserver" ZEN_EXE_SUFFIX_LITERAL; // see CreateProc() call for actual binary path
+
+ const bool IsTest = m_Env.IsTestEnvironment();
+
+ if (IsTest)
+ {
+ if (!m_OwnerPid.has_value())
+ {
+ m_OwnerPid = MyPid;
+ }
+
+ CommandLine << " --test --log-id " << LogId;
+ }
+
+ if (m_OwnerPid.has_value())
+ {
+ CommandLine << " --owner-pid " << m_OwnerPid.value();
+ }
+
+ CommandLine << " --child-id " << ChildEventName;
+
+ if (std::string_view ServerClass = m_Env.GetServerClass(); ServerClass.empty() == false)
+ {
+ CommandLine << " --http " << ServerClass;
+ }
+
+ if (BasePort)
+ {
+ CommandLine << " --port " << BasePort;
+ m_BasePort = BasePort;
+ }
+
+ if (!m_TestDir.empty())
+ {
+ CommandLine << " --data-dir ";
+ PathToUtf8(m_TestDir.c_str(), CommandLine);
+ }
+
+ if (!AdditionalServerArgs.empty())
+ {
+ CommandLine << " " << AdditionalServerArgs;
+ }
+
+ std::filesystem::path CurrentDirectory = std::filesystem::current_path();
+
+ ZEN_DEBUG("Spawning server '{}'", LogId);
+
+ uint32_t CreationFlags = 0;
+ if (!IsTest)
+ {
+ CreationFlags |= CreateProcOptions::Flag_NewConsole;
+ }
+
+ const std::filesystem::path BaseDir = m_Env.ProgramBaseDir();
+ const std::filesystem::path Executable = BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL;
+ CreateProcOptions CreateOptions = {
+ .WorkingDirectory = &CurrentDirectory,
+ .Flags = CreationFlags,
+ };
+ CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
+#if ZEN_PLATFORM_WINDOWS
+ if (!ChildPid && ::GetLastError() == ERROR_ELEVATION_REQUIRED)
+ {
+ ZEN_DEBUG("Regular spawn failed - spawning elevated server");
+ CreateOptions.Flags |= CreateProcOptions::Flag_Elevated;
+ ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
+ }
+#endif
+
+ if (!ChildPid)
+ {
+ ThrowLastError("Server spawn failed");
+ }
+
+ ZEN_DEBUG("Server '{}' spawned OK", LogId);
+
+ if (IsTest)
+ {
+ m_Process.Initialize(ChildPid);
+ }
+
+ m_ReadyEvent = std::move(ChildEvent);
+}
+
+void
+ZenServerInstance::CreateShutdownEvent(int BasePort)
+{
+ ExtendableStringBuilder<32> ChildShutdownEventName;
+ ChildShutdownEventName << "Zen_" << BasePort;
+ ChildShutdownEventName << "_Shutdown";
+ NamedEvent ChildShutdownEvent{ChildShutdownEventName};
+ m_ShutdownEvent = std::move(ChildShutdownEvent);
+}
+
+void
+ZenServerInstance::AttachToRunningServer(int BasePort)
+{
+ ZenServerState State;
+ if (!State.InitializeReadOnly())
+ {
+ // TODO: return success/error code instead?
+ throw std::runtime_error("No zen state found");
+ }
+
+ const ZenServerState::ZenServerEntry* Entry = nullptr;
+
+ if (BasePort)
+ {
+ Entry = State.Lookup(BasePort);
+ }
+ else
+ {
+ State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { Entry = &InEntry; });
+ }
+
+ if (!Entry)
+ {
+ // TODO: return success/error code instead?
+ throw std::runtime_error("No server found");
+ }
+
+ m_Process.Initialize(Entry->Pid);
+ CreateShutdownEvent(Entry->EffectiveListenPort);
+}
+
+void
+ZenServerInstance::Detach()
+{
+ if (m_Process.IsValid())
+ {
+ m_Process.Reset();
+ m_ShutdownEvent.Close();
+ }
+}
+
+void
+ZenServerInstance::WaitUntilReady()
+{
+ while (m_ReadyEvent.Wait(100) == false)
+ {
+ if (!m_Process.IsRunning() || !m_Process.IsValid())
+ {
+ ZEN_INFO("Wait abandoned by invalid process (running={})", m_Process.IsRunning());
+ return;
+ }
+ }
+}
+
+bool
+ZenServerInstance::WaitUntilReady(int Timeout)
+{
+ return m_ReadyEvent.Wait(Timeout);
+}
+
+std::string
+ZenServerInstance::GetBaseUri() const
+{
+ ZEN_ASSERT(m_BasePort);
+
+ return fmt::format("http://localhost:{}", m_BasePort);
+}
+
+} // namespace zen