diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 10:01:47 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 10:01:47 +0200 |
| commit | 075d17f8ada47e990fe94606c3d21df409223465 (patch) | |
| tree | e50549b766a2f3c354798a54ff73404217b4c9af /src/zenutil | |
| parent | fix: bundle shouldn't append content zip to zen (diff) | |
| download | zen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz zen-075d17f8ada47e990fe94606c3d21df409223465.zip | |
moved source directories into `/src` (#264)
* moved source directories into `/src`
* updated bundle.lua for new `src` path
* moved some docs, icon
* removed old test trees
Diffstat (limited to 'src/zenutil')
| -rw-r--r-- | src/zenutil/basicfile.cpp | 575 | ||||
| -rw-r--r-- | src/zenutil/cache/cachekey.cpp | 9 | ||||
| -rw-r--r-- | src/zenutil/cache/cachepolicy.cpp | 282 | ||||
| -rw-r--r-- | src/zenutil/cache/cacherequests.cpp | 1643 | ||||
| -rw-r--r-- | src/zenutil/cache/rpcrecording.cpp | 210 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/basicfile.h | 113 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/cache/cache.h | 6 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/cache/cachekey.h | 86 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/cache/cachepolicy.h | 227 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/cache/cacherequests.h | 279 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/cache/rpcrecording.h | 29 | ||||
| -rw-r--r-- | src/zenutil/include/zenutil/zenserverprocess.h | 141 | ||||
| -rw-r--r-- | src/zenutil/xmake.lua | 9 | ||||
| -rw-r--r-- | src/zenutil/zenserverprocess.cpp | 677 |
14 files changed, 4286 insertions, 0 deletions
diff --git a/src/zenutil/basicfile.cpp b/src/zenutil/basicfile.cpp new file mode 100644 index 000000000..1e6043d7e --- /dev/null +++ b/src/zenutil/basicfile.cpp @@ -0,0 +1,575 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenutil/basicfile.h" + +#include <zencore/compactbinary.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <fcntl.h> +# include <sys/file.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +#include <fmt/format.h> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +BasicFile::~BasicFile() +{ + Close(); +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode Mode) +{ + std::error_code Ec; + Open(FileName, Mode, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to open file '{}'", FileName)); + } +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec) +{ + Ec.clear(); + +#if ZEN_PLATFORM_WINDOWS + DWORD dwCreationDisposition = 0; + DWORD dwDesiredAccess = 0; + switch (Mode) + { + case Mode::kRead: + dwCreationDisposition |= OPEN_EXISTING; + dwDesiredAccess |= GENERIC_READ; + break; + case Mode::kWrite: + dwCreationDisposition |= OPEN_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE); + break; + case Mode::kDelete: + dwCreationDisposition |= OPEN_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE); + break; + case Mode::kTruncate: + dwCreationDisposition |= CREATE_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE); + break; + case Mode::kTruncateDelete: + dwCreationDisposition |= CREATE_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE); + break; + } + + const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + HANDLE hTemplateFile = nullptr; + + HANDLE FileHandle = CreateFile(FileName.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + Ec = MakeErrorCodeFromLastError(); + + return; + } +#else + int OpenFlags = O_CLOEXEC; + switch (Mode) + { + case Mode::kRead: + OpenFlags |= O_RDONLY; + break; + case Mode::kWrite: + case Mode::kDelete: + OpenFlags |= (O_RDWR | O_CREAT); + break; + case Mode::kTruncate: + case Mode::kTruncateDelete: + OpenFlags |= (O_RDWR | O_CREAT | O_TRUNC); + break; + } + + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + Ec = MakeErrorCodeFromLastError(); + return; + } + if (Mode != Mode::kRead) + { + fchmod(Fd, 0666); + } + + void* FileHandle = (void*)(uintptr_t(Fd)); +#endif + + m_FileHandle = FileHandle; +} + +void +BasicFile::Close() +{ + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + ::CloseHandle(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + close(Fd); +#endif + m_FileHandle = nullptr; + } +} + +void +BasicFile::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) +{ + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (BytesToRead) + { + const uint64_t NumberOfBytesToRead = Min(BytesToRead, MaxChunkSize); + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesRead = 0; + BOOL Success = ::ReadFile(m_FileHandle, Data, DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl); + + ZEN_ASSERT(dwNumberOfBytesRead == NumberOfBytesToRead); +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesRead = pread(Fd, Data, NumberOfBytesToRead, FileOffset); + bool Success = (BytesRead > 0); +#endif + + if (!Success) + { + ThrowLastError(fmt::format("Failed to read from file '{}'", zen::PathFromHandle(m_FileHandle))); + } + + BytesToRead -= NumberOfBytesToRead; + FileOffset += NumberOfBytesToRead; + Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead; + } +} + +IoBuffer +BasicFile::ReadAll() +{ + IoBuffer Buffer(FileSize()); + Read(Buffer.MutableData(), Buffer.Size(), 0); + return Buffer; +} + +void +BasicFile::StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + StreamByteRange(0, FileSize(), std::move(ChunkFun)); +} + +void +BasicFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + const uint64_t ChunkSize = 128 * 1024; + IoBuffer ReadBuffer{ChunkSize}; + void* BufferPtr = ReadBuffer.MutableData(); + + uint64_t RemainBytes = Size; + uint64_t CurrentOffset = FileOffset; + + while (RemainBytes) + { + const uint64_t ThisChunkBytes = zen::Min(ChunkSize, RemainBytes); + + Read(BufferPtr, ThisChunkBytes, CurrentOffset); + + ChunkFun(BufferPtr, ThisChunkBytes); + + CurrentOffset += ThisChunkBytes; + RemainBytes -= ThisChunkBytes; + } +} + +void +BasicFile::Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec) +{ + Write(Data.GetData(), Data.GetSize(), FileOffset, Ec); +} + +void +BasicFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec) +{ + Ec.clear(); + + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (Size) + { + const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesWritten = 0; + + BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, FileOffset); + bool Success = (BytesWritten > 0); +#endif + + if (!Success) + { + Ec = MakeErrorCodeFromLastError(); + + return; + } + + Size -= NumberOfBytesToWrite; + FileOffset += NumberOfBytesToWrite; + Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesToWrite; + } +} + +void +BasicFile::Write(MemoryView Data, uint64_t FileOffset) +{ + Write(Data.GetData(), Data.GetSize(), FileOffset); +} + +void +BasicFile::Write(const void* Data, uint64_t Size, uint64_t Offset) +{ + std::error_code Ec; + Write(Data, Size, Offset, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_FileHandle))); + } +} + +void +BasicFile::WriteAll(IoBuffer Data, std::error_code& Ec) +{ + Write(Data.Data(), Data.Size(), 0, Ec); +} + +void +BasicFile::Flush() +{ +#if ZEN_PLATFORM_WINDOWS + FlushFileBuffers(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + fsync(Fd); +#endif +} + +uint64_t +BasicFile::FileSize() +{ +#if ZEN_PLATFORM_WINDOWS + ULARGE_INTEGER liFileSize; + liFileSize.LowPart = ::GetFileSize(m_FileHandle, &liFileSize.HighPart); + if (liFileSize.LowPart == INVALID_FILE_SIZE) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to get file size from file '{}'", PathFromHandle(m_FileHandle))); + } + } + return uint64_t(liFileSize.QuadPart); +#else + int Fd = int(uintptr_t(m_FileHandle)); + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); + struct stat Stat; + fstat(Fd, &Stat); + return uint64_t(Stat.st_size); +#endif +} + +void +BasicFile::SetFileSize(uint64_t FileSize) +{ +#if ZEN_PLATFORM_WINDOWS + LARGE_INTEGER liFileSize; + liFileSize.QuadPart = FileSize; + BOOL OK = ::SetFilePointerEx(m_FileHandle, liFileSize, 0, FILE_BEGIN); + if (OK == FALSE) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set file pointer to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } + OK = ::SetEndOfFile(m_FileHandle); + if (OK == FALSE) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set end of file to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } +#elif ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_FileHandle)); + if (ftruncate(Fd, (off_t)FileSize) < 0) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } +#else + int Fd = int(intptr_t(m_FileHandle)); + if (ftruncate64(Fd, (off64_t)FileSize) < 0) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } + if (FileSize > 0) + { + int Error = posix_fallocate64(Fd, 0, (off64_t)FileSize); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to allocate space of {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } +#endif +} + +void* +BasicFile::Detach() +{ + void* FileHandle = m_FileHandle; + m_FileHandle = 0; + return FileHandle; +} + +////////////////////////////////////////////////////////////////////////// + +TemporaryFile::~TemporaryFile() +{ + Close(); +} + +void +TemporaryFile::Close() +{ + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); +#else + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle); + unlink(FilePath.c_str()); +#endif + + BasicFile::Close(); + } +} + +void +TemporaryFile::CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec) +{ + StringBuilder<64> TempName; + Oid::NewOid().ToString(TempName); + + m_TempPath = TempDirName / TempName.c_str(); + + Open(m_TempPath, BasicFile::Mode::kTruncateDelete, Ec); +} + +void +TemporaryFile::MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec) +{ + // We intentionally call the base class Close() since otherwise we'll end up + // deleting the temporary file + BasicFile::Close(); + + std::filesystem::rename(m_TempPath, FinalFileName, Ec); +} + +////////////////////////////////////////////////////////////////////////// + +LockFile::LockFile() +{ +} + +LockFile::~LockFile() +{ +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_FileHandle)); + flock(Fd, LOCK_UN | LOCK_NB); +#endif +} + +void +LockFile::Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + Ec.clear(); + + const DWORD dwCreationDisposition = CREATE_ALWAYS; + DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE | DELETE; + const DWORD dwShareMode = FILE_SHARE_READ; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE; + HANDLE hTemplateFile = nullptr; + + HANDLE FileHandle = CreateFile(FileName.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + Ec = zen::MakeErrorCodeFromLastError(); + + return; + } +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = open(FileName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + Ec = zen::MakeErrorCodeFromLastError(); + return; + } + fchmod(Fd, 0666); + + int LockRet = flock(Fd, LOCK_EX | LOCK_NB); + if (LockRet < 0) + { + Ec = zen::MakeErrorCodeFromLastError(); + close(Fd); + return; + } + + void* FileHandle = (void*)uintptr_t(Fd); +#endif + + m_FileHandle = FileHandle; + + BasicFile::Write(Payload.GetBuffer(), 0, Ec); +} + +void +LockFile::Update(CbObject Payload, std::error_code& Ec) +{ + BasicFile::Write(Payload.GetBuffer(), 0, Ec); +} + +/* + ___________ __ + \__ ___/___ _______/ |_ ______ + | |_/ __ \ / ___/\ __\/ ___/ + | |\ ___/ \___ \ | | \___ \ + |____| \___ >____ > |__| /____ > + \/ \/ \/ +*/ + +#if ZEN_WITH_TESTS + +TEST_CASE("BasicFile") +{ + ScopedCurrentDirectoryChange _; + + BasicFile File1; + CHECK_THROWS(File1.Open("zonk", BasicFile::Mode::kRead)); + CHECK_NOTHROW(File1.Open("zonk", BasicFile::Mode::kTruncate)); + CHECK_NOTHROW(File1.Write("abcd", 4, 0)); + CHECK(File1.FileSize() == 4); + { + IoBuffer Data = File1.ReadAll(); + CHECK(Data.Size() == 4); + CHECK_EQ(memcmp(Data.Data(), "abcd", 4), 0); + } + CHECK_NOTHROW(File1.Write("efgh", 4, 2)); + CHECK(File1.FileSize() == 6); + { + IoBuffer Data = File1.ReadAll(); + CHECK(Data.Size() == 6); + CHECK_EQ(memcmp(Data.Data(), "abefgh", 6), 0); + } +} + +TEST_CASE("TemporaryFile") +{ + ScopedCurrentDirectoryChange _; + + SUBCASE("DeleteOnClose") + { + TemporaryFile TmpFile; + std::error_code Ec; + TmpFile.CreateTemporary(std::filesystem::current_path(), Ec); + CHECK(!Ec); + CHECK(std::filesystem::exists(TmpFile.GetPath())); + TmpFile.Close(); + CHECK(std::filesystem::exists(TmpFile.GetPath()) == false); + } + + SUBCASE("MoveIntoPlace") + { + TemporaryFile TmpFile; + std::error_code Ec; + TmpFile.CreateTemporary(std::filesystem::current_path(), Ec); + CHECK(!Ec); + std::filesystem::path TempPath = TmpFile.GetPath(); + std::filesystem::path FinalPath = std::filesystem::current_path() / "final"; + CHECK(std::filesystem::exists(TempPath)); + TmpFile.MoveTemporaryIntoPlace(FinalPath, Ec); + CHECK(!Ec); + CHECK(std::filesystem::exists(TempPath) == false); + CHECK(std::filesystem::exists(FinalPath)); + } +} + +void +basicfile_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenutil/cache/cachekey.cpp b/src/zenutil/cache/cachekey.cpp new file mode 100644 index 000000000..545b47f11 --- /dev/null +++ b/src/zenutil/cache/cachekey.cpp @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cachekey.h> + +namespace zen { + +const CacheKey CacheKey::Empty = CacheKey{.Bucket = std::string(), .Hash = IoHash()}; + +} // namespace zen diff --git a/src/zenutil/cache/cachepolicy.cpp b/src/zenutil/cache/cachepolicy.cpp new file mode 100644 index 000000000..3bca363bb --- /dev/null +++ b/src/zenutil/cache/cachepolicy.cpp @@ -0,0 +1,282 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cachepolicy.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/enumflags.h> +#include <zencore/string.h> + +#include <algorithm> +#include <unordered_map> + +namespace zen::Private { +class CacheRecordPolicyShared; +} + +namespace zen { + +using namespace std::literals; + +namespace DerivedData::Private { + + constinit char CachePolicyDelimiter = ','; + + struct CachePolicyToTextData + { + CachePolicy Policy; + std::string_view Text; + }; + + constinit CachePolicyToTextData CachePolicyToText[]{ + // Flags with multiple bits are ordered by bit count to minimize token count in the text format. + {CachePolicy::Default, "Default"sv}, + {CachePolicy::Remote, "Remote"sv}, + {CachePolicy::Local, "Local"sv}, + {CachePolicy::Store, "Store"sv}, + {CachePolicy::Query, "Query"sv}, + // Flags with only one bit can be in any order. Match the order in CachePolicy. + {CachePolicy::QueryLocal, "QueryLocal"sv}, + {CachePolicy::QueryRemote, "QueryRemote"sv}, + {CachePolicy::StoreLocal, "StoreLocal"sv}, + {CachePolicy::StoreRemote, "StoreRemote"sv}, + {CachePolicy::SkipMeta, "SkipMeta"sv}, + {CachePolicy::SkipData, "SkipData"sv}, + {CachePolicy::PartialRecord, "PartialRecord"sv}, + {CachePolicy::KeepAlive, "KeepAlive"sv}, + // None must be last because it matches every policy. + {CachePolicy::None, "None"sv}, + }; + + constinit CachePolicy CachePolicyKnownFlags = + CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive; + + StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy) + { + // Mask out unknown flags. None will be written if no flags are known. + Policy &= CachePolicyKnownFlags; + for (const CachePolicyToTextData& Pair : CachePolicyToText) + { + if (EnumHasAllFlags(Policy, Pair.Policy)) + { + EnumRemoveFlags(Policy, Pair.Policy); + Builder << Pair.Text << CachePolicyDelimiter; + if (Policy == CachePolicy::None) + { + break; + } + } + } + Builder.RemoveSuffix(1); + return Builder; + } + + CachePolicy ParseCachePolicy(const std::string_view Text) + { + ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string + CachePolicy Policy = CachePolicy::None; + ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable { + const int32_t EndIndex = Index; + for (; size_t(Index) < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + for (Index = 0; Index < EndIndex; ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + return true; + }); + return Policy; + } + +} // namespace DerivedData::Private + +StringBuilderBase& +operator<<(StringBuilderBase& Builder, CachePolicy Policy) +{ + return DerivedData::Private::CachePolicyToString(Builder, Policy); +} + +CachePolicy +ParseCachePolicy(std::string_view Text) +{ + return DerivedData::Private::ParseCachePolicy(Text); +} + +CachePolicy +ConvertToUpstream(CachePolicy Policy) +{ + // Set Local flags equal to downstream's Remote flags. + // Delete Skip flags if StoreLocal is true, otherwise use the downstream value. + // Use the downstream value for all other flags. + + CachePolicy UpstreamPolicy = CachePolicy::None; + + if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + UpstreamPolicy |= CachePolicy::QueryLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + UpstreamPolicy |= CachePolicy::StoreLocal; + } + + if (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + UpstreamPolicy |= (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta)); + } + + UpstreamPolicy |= Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta); + + return UpstreamPolicy; +} + +class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared +{ +public: + inline void AddValuePolicy(const CacheValuePolicy& Value) final + { + ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null. + const auto Insert = + std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) { + return Existing.Id < New.Id; + }); + ZEN_ASSERT( + !(Insert < Values.end() && + Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ") + Values.insert(Insert, Value); + } + + inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; } + +private: + std::vector<CacheValuePolicy> Values; +}; + +CachePolicy +CacheRecordPolicy::GetValuePolicy(const Oid& Id) const +{ + if (Shared) + { + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + const auto Iter = + std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; }); + if (Iter != Values.end() && Iter->Id == Id) + { + return Iter->Policy; + } + } + return DefaultValuePolicy; +} + +void +CacheRecordPolicy::Save(CbWriter& Writer) const +{ + Writer.BeginObject(); + // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately. + Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy())); + if (!IsUniform()) + { + Writer.BeginArray("ValuePolicies"sv); + for (const CacheValuePolicy& Value : GetValuePolicies()) + { + Writer.BeginObject(); + Writer.AddObjectId("Id"sv, Value.Id); + Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy)); + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); +} + +OptionalCacheRecordPolicy +CacheRecordPolicy::Load(const CbObjectView Object) +{ + std::string_view BasePolicyText = Object["BasePolicy"sv].AsString(); + if (BasePolicyText.empty()) + { + return {}; + } + + CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText)); + for (CbFieldView ValueField : Object["ValuePolicies"sv]) + { + const CbObjectView Value = ValueField.AsObjectView(); + const Oid Id = Value["Id"sv].AsObjectId(); + const std::string_view PolicyText = Value["Policy"sv].AsString(); + if (!Id || PolicyText.empty()) + { + return {}; + } + CachePolicy Policy = ParseCachePolicy(PolicyText); + if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask)) + { + return {}; + } + Builder.AddValuePolicy(Id, Policy); + } + + return Builder.Build(); +} + +CacheRecordPolicy +CacheRecordPolicy::ConvertToUpstream() const +{ + CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy())); + for (const CacheValuePolicy& ValuePolicy : GetValuePolicies()) + { + Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy)); + } + return Builder.Build(); +} + +void +CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value) +{ + ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy, + ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s + if (Value.Policy == (BasePolicy & Value.PolicyMask)) + { + return; + } + if (!Shared) + { + Shared = new Private::CacheRecordPolicyShared; + } + Shared->AddValuePolicy(Value); +} + +CacheRecordPolicy +CacheRecordPolicyBuilder::Build() +{ + CacheRecordPolicy Policy(BasePolicy); + if (Shared) + { + const auto Add = [](const CachePolicy A, const CachePolicy B) { + return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData); + }; + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + Policy.RecordPolicy = BasePolicy; + for (const CacheValuePolicy& ValuePolicy : Values) + { + Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy); + } + Policy.Shared = std::move(Shared); + } + return Policy; +} + +} // namespace zen diff --git a/src/zenutil/cache/cacherequests.cpp b/src/zenutil/cache/cacherequests.cpp new file mode 100644 index 000000000..4c865ec22 --- /dev/null +++ b/src/zenutil/cache/cacherequests.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cacherequests.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/zencore.h> + +#include <string> +#include <string_view> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif + +namespace zen { + +namespace cacherequests { + + namespace { + constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + + std::optional<std::string> GetValidNamespaceName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Namespace is invalid, empty namespace is not allowed"); + return {}; + } + + if (Name.length() > 64) + { + ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet)) + { + ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<std::string> GetValidBucketName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed"); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet)) + { + ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<IoHash> GetValidIoHash(std::string_view Hash) + { + if (Hash.length() != IoHash::StringLength) + { + return {}; + } + + IoHash KeyHash; + if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash)) + { + return {}; + } + return KeyHash; + } + + std::optional<CacheRecordPolicy> Convert(const OptionalCacheRecordPolicy& Policy) + { + return Policy.IsValid() ? Policy.Get() : std::optional<CacheRecordPolicy>{}; + }; + } // namespace + + std::optional<std::string> GetRequestNamespace(const CbObjectView& Params) + { + CbFieldView NamespaceField = Params["Namespace"]; + if (!NamespaceField) + { + return std::string("!default!"); // ZenCacheStore::DefaultNamespace); + } + + if (NamespaceField.HasError()) + { + return {}; + } + if (!NamespaceField.IsString()) + { + return {}; + } + return GetValidNamespaceName(NamespaceField.AsString()); + } + + bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key) + { + CbFieldView BucketField = KeyView["Bucket"]; + if (BucketField.HasError()) + { + return false; + } + if (!BucketField.IsString()) + { + return false; + } + std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString()); + if (!Bucket.has_value()) + { + return false; + } + CbFieldView HashField = KeyView["Hash"]; + if (HashField.HasError()) + { + return false; + } + if (!HashField.IsHash()) + { + return false; + } + Key.Bucket = *Bucket; + Key.Hash = HashField.AsHash(); + return true; + } + + void WriteCacheRequestKey(CbObjectWriter& Writer, const CacheKey& Value) + { + Writer.BeginObject("Key"); + { + Writer << "Bucket" << Value.Bucket; + Writer << "Hash" << Value.Hash; + } + Writer.EndObject(); + } + + void WriteOptionalCacheRequestPolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CacheRecordPolicy>& Policy) + { + if (Policy) + { + Writer.SetName(FieldName); + Policy->Save(Writer); + } + } + + std::optional<CachePolicy> GetCachePolicy(CbObjectView ObjectView, std::string_view FieldName) + { + std::string_view DefaultPolicyText = ObjectView[FieldName].AsString(); + if (DefaultPolicyText.empty()) + { + return {}; + } + return ParseCachePolicy(DefaultPolicyText); + } + + void WriteCachePolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CachePolicy>& Policy) + { + if (Policy) + { + Writer << FieldName << WriteToString<128>(*Policy); + } + } + + bool PutCacheRecordsRequest::Parse(const CbPackage& Package) + { + CbObjectView BatchObject = Package.GetObject(); + ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheRecords"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestFieldArray = Params["Requests"].AsArrayView(); + Requests.resize(RequestFieldArray.Num()); + for (size_t RequestIndex = 0; CbFieldView RequestField : RequestFieldArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView RecordObject = RequestObject["Record"].AsObjectView(); + CbObjectView KeyView = RecordObject["Key"].AsObjectView(); + + PutCacheRecordRequest& Request = Requests[RequestIndex++]; + + if (!GetRequestCacheKey(KeyView, Request.Key)) + { + return false; + } + + Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView())); + + std::unordered_map<IoHash, size_t, IoHash::Hasher> RawHashToAttachmentIndex; + + CbArrayView ValuesArray = RecordObject["Values"].AsArrayView(); + Request.Values.resize(ValuesArray.Num()); + RawHashToAttachmentIndex.reserve(ValuesArray.Num()); + for (size_t Index = 0; CbFieldView Value : ValuesArray) + { + CbObjectView ObjectView = Value.AsObjectView(); + IoHash AttachmentHash = ObjectView["RawHash"].AsHash(); + RawHashToAttachmentIndex[AttachmentHash] = Index; + Request.Values[Index++] = {.Id = ObjectView["Id"].AsObjectId(), .RawHash = AttachmentHash}; + } + + RecordObject.IterateAttachments([&](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(ValueHash)) + { + if (Attachment->IsCompressedBinary()) + { + auto It = RawHashToAttachmentIndex.find(ValueHash); + ZEN_ASSERT(It != RawHashToAttachmentIndex.end()); + PutCacheRecordRequestValue& Value = Request.Values[It->second]; + ZEN_ASSERT(Value.RawHash == ValueHash); + Value.Body = Attachment->AsCompressedBinary(); + ZEN_ASSERT_SLOW(Value.Body.DecodeRawHash() == Value.RawHash); + } + } + }); + } + + return true; + } + + bool PutCacheRecordsRequest::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + Writer << "Method" + << "PutCacheRecords"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("Requests"); + for (const PutCacheRecordRequest& RecordRequest : Requests) + { + Writer.BeginObject(); + { + Writer.BeginObject("Record"); + { + WriteCacheRequestKey(Writer, RecordRequest.Key); + Writer.BeginArray("Values"); + for (const PutCacheRecordRequestValue& Value : RecordRequest.Values) + { + Writer.BeginObject(); + { + Writer.AddObjectId("Id", Value.Id); + const CompressedBuffer& Buffer = Value.Body; + if (Buffer) + { + IoHash AttachmentHash = Buffer.DecodeRawHash(); // TODO: Slow! + Writer.AddBinaryAttachment("RawHash", AttachmentHash); + OutPackage.AddAttachment(CbAttachment(Buffer, AttachmentHash)); + Writer.AddInteger("RawSize", Buffer.DecodeRawSize()); // TODO: Slow! + } + else + { + if (Value.RawHash == IoHash::Zero) + { + return false; + } + Writer.AddBinaryAttachment("RawHash", Value.RawHash); + } + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + OutPackage.SetObject(Writer.Save()); + + return true; + } + + bool PutCacheRecordsResult::Parse(const CbPackage& Package) + { + CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + CbFieldViewIterator It = ResultsArray.CreateViewIterator(); + while (It.HasValue()) + { + Success.push_back(It.AsBool()); + It++; + } + return true; + } + + bool PutCacheRecordsResult::Format(CbPackage& OutPackage) const + { + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"); + for (bool Value : Success) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + OutPackage.SetObject(ResponseObject.Save()); + return true; + } + + bool GetCacheRecordsRequest::Parse(const CbObjectView& RpcRequest) + { + ZEN_ASSERT(RpcRequest["Method"].AsString() == "GetCacheRecords"); + AcceptMagic = RpcRequest["AcceptType"].AsUInt32(0); + AcceptOptions = RpcRequest["AcceptFlags"].AsUInt16(0); + ProcessPid = RpcRequest["Pid"].AsInt32(0); + + CbObjectView Params = RpcRequest["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["Requests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + GetCacheRecordRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView())); + } + return true; + } + + bool GetCacheRecordsRequest::Parse(const CbPackage& RpcRequest) { return Parse(RpcRequest.GetObject()); } + + bool GetCacheRecordsRequest::Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter) const + { + Writer << "Method" + << "GetCacheRecords"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + if (AcceptOptions != 0) + { + Writer << "AcceptFlags" << AcceptOptions; + } + if (ProcessPid != 0) + { + Writer << "Pid" << ProcessPid; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + Writer.BeginArray("Requests"); + if (OptionalRecordFilter.empty()) + { + for (const GetCacheRecordRequest& RecordRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, RecordRequest.Key); + WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy); + } + Writer.EndObject(); + } + } + else + { + for (size_t Index : OptionalRecordFilter) + { + const GetCacheRecordRequest& RecordRequest = Requests[Index]; + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, RecordRequest.Key); + WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy); + } + Writer.EndObject(); + } + } + Writer.EndArray(); + } + Writer.EndObject(); + + return true; + } + + bool GetCacheRecordsRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter) const + { + CbObjectWriter Writer; + if (!Format(Writer, OptionalRecordFilter)) + { + return false; + } + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool GetCacheRecordsResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes) + { + CbObject ResponseObject = Package.GetObject(); + CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + + Results.reserve(ResultsArray.Num()); + if (!OptionalRecordResultIndexes.empty() && ResultsArray.Num() != OptionalRecordResultIndexes.size()) + { + return false; + } + for (size_t Index = 0; CbFieldView RecordView : ResultsArray) + { + size_t ResultIndex = OptionalRecordResultIndexes.empty() ? Index : OptionalRecordResultIndexes[Index]; + Index++; + + if (Results.size() <= ResultIndex) + { + Results.resize(ResultIndex + 1); + } + if (RecordView.IsNull()) + { + continue; + } + Results[ResultIndex] = GetCacheRecordResult{}; + GetCacheRecordResult& Request = Results[ResultIndex].value(); + CbObjectView RecordObject = RecordView.AsObjectView(); + CbObjectView KeyObject = RecordObject["Key"].AsObjectView(); + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + CbArrayView ValuesArray = RecordObject["Values"].AsArrayView(); + Request.Values.reserve(ValuesArray.Num()); + for (CbFieldView Value : ValuesArray) + { + CbObjectView ValueObject = Value.AsObjectView(); + IoHash RawHash = ValueObject["RawHash"].AsHash(); + uint64_t RawSize = ValueObject["RawSize"].AsUInt64(); + Oid Id = ValueObject["Id"].AsObjectId(); + const CbAttachment* Attachment = Package.FindAttachment(RawHash); + if (!Attachment) + { + Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = {}}); + continue; + } + if (!Attachment->IsCompressedBinary()) + { + return false; + } + Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = Attachment->AsCompressedBinary()}); + } + } + return true; + } + + bool GetCacheRecordsResult::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + + Writer.BeginArray("Result"); + for (const std::optional<GetCacheRecordResult>& RecordResult : Results) + { + if (!RecordResult.has_value()) + { + Writer.AddNull(); + continue; + } + Writer.BeginObject(); + WriteCacheRequestKey(Writer, RecordResult->Key); + + Writer.BeginArray("Values"); + for (const GetCacheRecordResultValue& Value : RecordResult->Values) + { + IoHash AttachmentHash = Value.Body ? Value.Body.DecodeRawHash() : Value.RawHash; + Writer.BeginObject(); + { + Writer.AddObjectId("Id", Value.Id); + Writer.AddHash("RawHash", AttachmentHash); + Writer.AddInteger("RawSize", Value.Body ? Value.Body.DecodeRawSize() : Value.RawSize); + } + Writer.EndObject(); + if (Value.Body) + { + OutPackage.AddAttachment(CbAttachment(Value.Body, AttachmentHash)); + } + } + + Writer.EndArray(); + Writer.EndObject(); + } + Writer.EndArray(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool PutCacheValuesRequest::Parse(const CbPackage& Package) + { + CbObjectView BatchObject = Package.GetObject(); + ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheValues"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["Requests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + PutCacheValueRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.RawHash = RequestObject["RawHash"].AsBinaryAttachment(); + Request.Policy = GetCachePolicy(RequestObject, "Policy"); + + if (const CbAttachment* Attachment = Package.FindAttachment(Request.RawHash)) + { + if (!Attachment->IsCompressedBinary()) + { + return false; + } + Request.Body = Attachment->AsCompressedBinary(); + } + } + return true; + } + + bool PutCacheValuesRequest::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + Writer << "Method" + << "PutCacheValues"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("Requests"); + for (const PutCacheValueRequest& ValueRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + if (ValueRequest.Body) + { + IoHash AttachmentHash = ValueRequest.Body.DecodeRawHash(); + if (ValueRequest.RawHash != IoHash::Zero && AttachmentHash != ValueRequest.RawHash) + { + return false; + } + Writer.AddBinaryAttachment("RawHash", AttachmentHash); + OutPackage.AddAttachment(CbAttachment(ValueRequest.Body, AttachmentHash)); + } + else if (ValueRequest.RawHash != IoHash::Zero) + { + Writer.AddBinaryAttachment("RawHash", ValueRequest.RawHash); + } + else + { + return false; + } + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool PutCacheValuesResult::Parse(const CbPackage& Package) + { + CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + CbFieldViewIterator It = ResultsArray.CreateViewIterator(); + while (It.HasValue()) + { + Success.push_back(It.AsBool()); + It++; + } + return true; + } + + bool PutCacheValuesResult::Format(CbPackage& OutPackage) const + { + if (Success.empty()) + { + return false; + } + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"); + for (bool Value : Success) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + OutPackage.SetObject(ResponseObject.Save()); + return true; + } + + bool GetCacheValuesRequest::Parse(const CbObjectView& BatchObject) + { + ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheValues"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0); + ProcessPid = BatchObject["Pid"].AsInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["Requests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + GetCacheValueRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.Policy = GetCachePolicy(RequestObject, "Policy"); + } + return true; + } + + bool GetCacheValuesRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter) const + { + CbObjectWriter Writer; + Writer << "Method" + << "GetCacheValues"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + if (AcceptOptions != 0) + { + Writer << "AcceptFlags" << AcceptOptions; + } + if (ProcessPid != 0) + { + Writer << "Pid" << ProcessPid; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("Requests"); + if (OptionalValueFilter.empty()) + { + for (const GetCacheValueRequest& ValueRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + } + else + { + for (size_t Index : OptionalValueFilter) + { + const GetCacheValueRequest& ValueRequest = Requests[Index]; + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + } + Writer.EndArray(); + } + Writer.EndObject(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool CacheValuesResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes) + { + CbObject ResponseObject = Package.GetObject(); + CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + Results.reserve(ResultsArray.Num()); + if (!OptionalValueResultIndexes.empty() && ResultsArray.Num() != OptionalValueResultIndexes.size()) + { + return false; + } + for (size_t Index = 0; CbFieldView RecordView : ResultsArray) + { + size_t ResultIndex = OptionalValueResultIndexes.empty() ? Index : OptionalValueResultIndexes[Index]; + Index++; + + if (Results.size() <= ResultIndex) + { + Results.resize(ResultIndex + 1); + } + if (RecordView.IsNull()) + { + continue; + } + + CacheValueResult& ValueResult = Results[ResultIndex]; + CbObjectView RecordObject = RecordView.AsObjectView(); + + CbFieldView RawHashField = RecordObject["RawHash"]; + ValueResult.RawHash = RawHashField.AsHash(); + bool Succeeded = !RawHashField.HasError(); + if (Succeeded) + { + const CbAttachment* Attachment = Package.FindAttachment(ValueResult.RawHash); + ValueResult.Body = Attachment ? Attachment->AsCompressedBinary() : CompressedBuffer(); + if (ValueResult.Body) + { + ValueResult.RawSize = ValueResult.Body.DecodeRawSize(); + } + else + { + ValueResult.RawSize = RecordObject["RawSize"].AsUInt64(UINT64_MAX); + } + } + } + return true; + } + + bool CacheValuesResult::Format(CbPackage& OutPackage) const + { + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"); + for (const CacheValueResult& ValueResult : Results) + { + ResponseObject.BeginObject(); + if (ValueResult.RawHash != IoHash::Zero) + { + ResponseObject.AddHash("RawHash", ValueResult.RawHash); + if (ValueResult.Body) + { + OutPackage.AddAttachment(CbAttachment(ValueResult.Body, ValueResult.RawHash)); + } + else + { + ResponseObject.AddInteger("RawSize", ValueResult.RawSize); + } + } + ResponseObject.EndObject(); + } + ResponseObject.EndArray(); + + OutPackage.SetObject(ResponseObject.Save()); + return true; + } + + bool GetCacheChunksRequest::Parse(const CbObjectView& BatchObject) + { + ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheChunks"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0); + ProcessPid = BatchObject["Pid"].AsInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["ChunkRequests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + GetCacheChunkRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.ValueId = RequestObject["ValueId"].AsObjectId(); + Request.ChunkId = RequestObject["ChunkId"].AsHash(); + Request.RawOffset = RequestObject["RawOffset"].AsUInt64(); + Request.RawSize = RequestObject["RawSize"].AsUInt64(UINT64_MAX); + + Request.Policy = GetCachePolicy(RequestObject, "Policy"); + } + return true; + } + + bool GetCacheChunksRequest::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + Writer << "Method" + << "GetCacheChunks"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + if (AcceptOptions != 0) + { + Writer << "AcceptFlags" << AcceptOptions; + } + if (ProcessPid != 0) + { + Writer << "Pid" << ProcessPid; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("ChunkRequests"); + for (const GetCacheChunkRequest& ValueRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + + Writer.AddObjectId("ValueId", ValueRequest.ValueId); + Writer.AddHash("ChunkId", ValueRequest.ChunkId); + Writer.AddInteger("RawOffset", ValueRequest.RawOffset); + Writer.AddInteger("RawSize", ValueRequest.RawSize); + + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data) + { + std::vector<std::string_view> Tokens; + uint32_t TokenCount = zen::ForEachStrTok(Key, '/', [&](const std::string_view& Token) { + Tokens.push_back(Token); + return true; + }); + + switch (TokenCount) + { + case 1: + Data.Namespace = GetValidNamespaceName(Tokens[0]); + return Data.Namespace.has_value(); + case 2: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + return true; + } + case 3: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key/valueid request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + Data.ValueContentId = GetValidIoHash(Tokens[2]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey) + { + return false; + } + return true; + } + case 4: + { + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey.has_value()) + { + return false; + } + + Data.ValueContentId = GetValidIoHash(Tokens[3]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + return true; + } + default: + return false; + } + } + + // bool CacheRecord::Parse(CbObjectView& Reader) + // { + // CbObjectView KeyView = Reader["Key"].AsObjectView(); + // + // if (!GetRequestCacheKey(KeyView, Key)) + // { + // return false; + // } + // CbArrayView ValuesArray = Reader["Values"].AsArrayView(); + // Values.reserve(ValuesArray.Num()); + // for (CbFieldView Value : ValuesArray) + // { + // CbObjectView ObjectView = Value.AsObjectView(); + // Values.push_back({.Id = ObjectView["Id"].AsObjectId(), + // .RawHash = ObjectView["RawHash"].AsHash(), + // .RawSize = ObjectView["RawSize"].AsUInt64()}); + // } + // return true; + // } + // + // bool CacheRecord::Format(CbObjectWriter& Writer) const + // { + // WriteCacheRequestKey(Writer, Key); + // Writer.BeginArray("Values"); + // for (const CacheRecordValue& Value : Values) + // { + // Writer.BeginObject(); + // { + // Writer.AddObjectId("Id", Value.Id); + // Writer.AddHash("RawHash", Value.RawHash); + // Writer.AddInteger("RawSize", Value.RawSize); + // } + // Writer.EndObject(); + // } + // Writer.EndArray(); + // return true; + // } + +#if ZEN_WITH_TESTS + + static bool operator==(const PutCacheRecordRequestValue& Lhs, const PutCacheRecordRequestValue& Rhs) + { + const IoHash LhsRawHash = Lhs.RawHash != IoHash::Zero ? Lhs.RawHash : Lhs.Body.DecodeRawHash(); + const IoHash RhsRawHash = Rhs.RawHash != IoHash::Zero ? Rhs.RawHash : Rhs.Body.DecodeRawHash(); + return Lhs.Id == Rhs.Id && LhsRawHash == RhsRawHash && + Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView()); + } + + static bool operator==(const zen::CacheValuePolicy& Lhs, const zen::CacheValuePolicy& Rhs) + { + return (Lhs.Id == Rhs.Id) && (Lhs.Policy == Rhs.Policy); + } + + static bool operator==(const std::span<const zen::CacheValuePolicy>& Lhs, const std::span<const zen::CacheValuePolicy>& Rhs) + { + if (Lhs.size() != Lhs.size()) + { + return false; + } + for (size_t Idx = 0; Idx < Lhs.size(); ++Idx) + { + if (Lhs[Idx] != Rhs[Idx]) + { + return false; + } + } + return true; + } + + static bool operator==(const zen::CacheRecordPolicy& Lhs, const zen::CacheRecordPolicy& Rhs) + { + return (Lhs.GetRecordPolicy() == Rhs.GetRecordPolicy()) && (Lhs.GetBasePolicy() == Rhs.GetBasePolicy()) && + (Lhs.GetValuePolicies() == Rhs.GetValuePolicies()); + } + + static bool operator==(const std::optional<CacheRecordPolicy>& Lhs, const std::optional<CacheRecordPolicy>& Rhs) + { + return (Lhs.has_value() == Rhs.has_value()) && (!Lhs || (*Lhs == *Rhs)); + } + + static bool operator==(const PutCacheRecordRequest& Lhs, const PutCacheRecordRequest& Rhs) + { + return (Lhs.Key == Rhs.Key) && (Lhs.Values == Rhs.Values) && (Lhs.Policy == Rhs.Policy); + } + + static bool operator==(const PutCacheRecordsRequest& Lhs, const PutCacheRecordsRequest& Rhs) + { + return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests); + } + + static bool operator==(const PutCacheRecordsResult& Lhs, const PutCacheRecordsResult& Rhs) { return (Lhs.Success == Rhs.Success); } + + static bool operator==(const GetCacheRecordRequest& Lhs, const GetCacheRecordRequest& Rhs) + { + return (Lhs.Key == Rhs.Key) && (Lhs.Policy == Rhs.Policy); + } + + static bool operator==(const GetCacheRecordsRequest& Lhs, const GetCacheRecordsRequest& Rhs) + { + return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests); + } + + static bool operator==(const GetCacheRecordResultValue& Lhs, const GetCacheRecordResultValue& Rhs) + { + if ((Lhs.Id != Rhs.Id) || (Lhs.RawHash != Rhs.RawHash) || (Lhs.RawSize != Rhs.RawSize)) + { + return false; + } + if (bool(Lhs.Body) != bool(Rhs.Body)) + { + return false; + } + if (bool(Lhs.Body) && Lhs.Body.DecodeRawHash() != Rhs.Body.DecodeRawHash()) + { + return false; + } + return true; + } + + static bool operator==(const GetCacheRecordResult& Lhs, const GetCacheRecordResult& Rhs) + { + return Lhs.Key == Rhs.Key && Lhs.Values == Rhs.Values; + } + + static bool operator==(const std::optional<GetCacheRecordResult>& Lhs, const std::optional<GetCacheRecordResult>& Rhs) + { + if (Lhs.has_value() != Rhs.has_value()) + { + return false; + } + return *Lhs == Rhs; + } + + static bool operator==(const GetCacheRecordsResult& Lhs, const GetCacheRecordsResult& Rhs) { return Lhs.Results == Rhs.Results; } + + static bool operator==(const PutCacheValueRequest& Lhs, const PutCacheValueRequest& Rhs) + { + if ((Lhs.Key != Rhs.Key) || (Lhs.RawHash != Rhs.RawHash)) + { + return false; + } + + if (bool(Lhs.Body) != bool(Rhs.Body)) + { + return false; + } + if (!Lhs.Body) + { + return true; + } + return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView()); + } + + static bool operator==(const PutCacheValuesRequest& Lhs, const PutCacheValuesRequest& Rhs) + { + return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests); + } + + static bool operator==(const PutCacheValuesResult& Lhs, const PutCacheValuesResult& Rhs) { return (Lhs.Success == Rhs.Success); } + + static bool operator==(const GetCacheValueRequest& Lhs, const GetCacheValueRequest& Rhs) + { + return Lhs.Key == Rhs.Key && Lhs.Policy == Rhs.Policy; + } + + static bool operator==(const GetCacheValuesRequest& Lhs, const GetCacheValuesRequest& Rhs) + { + return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests; + } + + static bool operator==(const CacheValueResult& Lhs, const CacheValueResult& Rhs) + { + if (Lhs.RawHash != Rhs.RawHash) + { + return false; + }; + if (Lhs.Body) + { + if (!Rhs.Body) + { + return false; + } + return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView()); + } + return Lhs.RawSize == Rhs.RawSize; + } + + static bool operator==(const CacheValuesResult& Lhs, const CacheValuesResult& Rhs) { return Lhs.Results == Rhs.Results; } + + static bool operator==(const GetCacheChunkRequest& Lhs, const GetCacheChunkRequest& Rhs) + { + return Lhs.Key == Rhs.Key && Lhs.ValueId == Rhs.ValueId && Lhs.ChunkId == Rhs.ChunkId && Lhs.RawOffset == Rhs.RawOffset && + Lhs.RawSize == Rhs.RawSize && Lhs.Policy == Rhs.Policy; + } + + static bool operator==(const GetCacheChunksRequest& Lhs, const GetCacheChunksRequest& Rhs) + { + return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests; + } + + static CompressedBuffer MakeCompressedBuffer(size_t Size) { return CompressedBuffer::Compress(SharedBuffer(IoBuffer(Size))); }; + + TEST_CASE("cacherequests.put.cache.records") + { + PutCacheRecordsRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + PutCacheRecordsRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required + + PutCacheRecordsRequest FullRequest = { + .DefaultPolicy = CachePolicy::Remote, + .Namespace = "the_namespace", + .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}}, + .Policy = CachePolicy::StoreLocal}, + {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}}, + .Policy = CachePolicy::Store}, + {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + PutCacheRecordsRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage)); + CHECK(FullRequest == FullRequestCopy); + + PutCacheRecordsResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + PutCacheRecordsResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + PutCacheRecordsResult FullResult = {.Success = {true, false, true, true, false}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + PutCacheRecordsResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + } + + TEST_CASE("cacherequests.get.cache.records") + { + GetCacheRecordsRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + GetCacheRecordsRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required + + GetCacheRecordsRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Policy = CachePolicy::Remote}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + GetCacheRecordsRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage)); + CHECK(FullRequest == FullRequestCopy); + + CbPackage PartialRequestPackage; + CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2})); + GetCacheRecordsRequest PartialRequest = FullRequest; + PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1); + GetCacheRecordsRequest PartialRequestCopy; + CHECK(PartialRequestCopy.Parse(PartialRequestPackage)); + CHECK(PartialRequest == PartialRequestCopy); + + GetCacheRecordsResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + GetCacheRecordsResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + PutCacheRecordsRequest FullPutRequest = { + .DefaultPolicy = CachePolicy::Remote, + .Namespace = "the_namespace", + .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}}, + .Policy = CachePolicy::StoreLocal}, + {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}}, + .Policy = CachePolicy::Store}, + {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}}; + + CbPackage FullPutRequestPackage; + CHECK(FullPutRequest.Format(FullPutRequestPackage)); + PutCacheRecordsRequest FullPutRequestCopy; + CHECK(FullPutRequestCopy.Parse(FullPutRequestPackage)); + + GetCacheRecordsResult FullResult = { + {GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[0].Key, + .Values = {{.Id = FullPutRequestCopy.Requests[0].Values[0].Id, + .RawHash = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[0].Values[0].Body}, + {.Id = FullPutRequestCopy.Requests[0].Values[1].Id, + + .RawHash = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[0].Values[1].Body}, + {.Id = FullPutRequestCopy.Requests[0].Values[2].Id, + .RawHash = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[0].Values[2].Body}}}, + {}, // Simulate not have! + GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[2].Key, + .Values = {{.Id = FullPutRequestCopy.Requests[2].Values[0].Id, + .RawHash = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[2].Values[0].Body}, + {.Id = FullPutRequestCopy.Requests[2].Values[1].Id, + .RawHash = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawSize(), + .Body = {}}, // Simulate not have + {.Id = FullPutRequestCopy.Requests[2].Values[2].Id, + .RawHash = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[2].Values[2].Body}}}}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + GetCacheRecordsResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult.Results[0] == FullResultCopy.Results[0]); + CHECK(!FullResultCopy.Results[1]); + CHECK(FullResult.Results[2] == FullResultCopy.Results[2]); + + GetCacheRecordsResult PartialResultCopy; + CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4})); + CHECK(FullResult.Results[0] == PartialResultCopy.Results[0]); + CHECK(!PartialResultCopy.Results[1]); + CHECK(!PartialResultCopy.Results[2]); + CHECK(!PartialResultCopy.Results[3]); + CHECK(FullResult.Results[2] == PartialResultCopy.Results[4]); + } + + TEST_CASE("cacherequests.put.cache.values") + { + PutCacheValuesRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + PutCacheValuesRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required + + CompressedBuffer Buffers[3] = {MakeCompressedBuffer(969), MakeCompressedBuffer(3469), MakeCompressedBuffer(9)}; + PutCacheValuesRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .RawHash = Buffers[0].DecodeRawHash(), + .Body = Buffers[0], + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .RawHash = Buffers[1].DecodeRawHash(), + .Body = Buffers[1], + .Policy = CachePolicy::Remote}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .RawHash = Buffers[2].DecodeRawHash()}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + PutCacheValuesRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage)); + CHECK(FullRequest == FullRequestCopy); + + PutCacheValuesResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(!EmptyResult.Format(EmptyResponsePackage)); + + PutCacheValuesResult FullResult = {.Success = {true, false, true}}; + + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + PutCacheValuesResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + } + + TEST_CASE("cacherequests.get.cache.values") + { + GetCacheValuesRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + GetCacheValuesRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required + + GetCacheValuesRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Policy = CachePolicy::Remote}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + GetCacheValuesRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject())); + CHECK(FullRequest == FullRequestCopy); + + CbPackage PartialRequestPackage; + CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2})); + GetCacheValuesRequest PartialRequest = FullRequest; + PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1); + GetCacheValuesRequest PartialRequestCopy; + CHECK(PartialRequestCopy.Parse(PartialRequestPackage.GetObject())); + CHECK(PartialRequest == PartialRequestCopy); + + CacheValuesResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + CacheValuesResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)}, + {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)}, + {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}}; + CacheValuesResult FullResult = { + .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]}, + CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + CacheValuesResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + + CacheValuesResult PartialResultCopy; + CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4, 5, 6, 9})); + CHECK(PartialResultCopy.Results[0] == FullResult.Results[0]); + CHECK(PartialResultCopy.Results[1].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[2].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[3] == FullResult.Results[1]); + CHECK(PartialResultCopy.Results[4] == FullResult.Results[2]); + CHECK(PartialResultCopy.Results[5] == FullResult.Results[3]); + CHECK(PartialResultCopy.Results[6] == FullResult.Results[4]); + CHECK(PartialResultCopy.Results[7].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[8].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[9] == FullResult.Results[5]); + } + + TEST_CASE("cacherequests.get.cache.chunks") + { + GetCacheChunksRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + GetCacheChunksRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required + + GetCacheChunksRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .ValueId = Oid::NewOid(), + .ChunkId = IoHash::FromHexString("ab3917854bfef7e7af2c372d795bb907a15cab15"), + .RawOffset = 77, + .RawSize = 33, + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .ValueId = Oid::NewOid(), + .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"), + .Policy = CachePolicy::Remote}, + { + .Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"), + }}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + GetCacheChunksRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject())); + CHECK(FullRequest == FullRequestCopy); + + GetCacheChunksResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + GetCacheChunksResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)}, + {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)}, + {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}}; + GetCacheChunksResult FullResult = { + .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]}, + CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + GetCacheChunksResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + } + + TEST_CASE("z$service.parse.relative.Uri") + { + HttpRequestData LegacyBucketRequestBecomesNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest)); + CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test"); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData LegacyHashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest)); + CHECK(!LegacyHashKeyRequest.Namespace); + CHECK(LegacyHashKeyRequest.Bucket == "test"); + CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(!LegacyHashKeyRequest.ValueContentId.has_value()); + + HttpRequestData LegacyValueContentIdRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + LegacyValueContentIdRequest)); + CHECK(!LegacyValueContentIdRequest.Namespace); + CHECK(LegacyValueContentIdRequest.Bucket == "test"); + CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789")); + + HttpRequestData V2DefaultNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest)); + CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc"); + CHECK(!V2DefaultNamespaceRequest.Bucket.has_value()); + CHECK(!V2DefaultNamespaceRequest.HashKey.has_value()); + CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2NamespaceRequest; + CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest)); + CHECK(V2NamespaceRequest.Namespace == "nicenamespace"); + CHECK(!V2NamespaceRequest.Bucket.has_value()); + CHECK(!V2NamespaceRequest.HashKey.has_value()); + CHECK(!V2NamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithDefaultNamespace; + CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace)); + CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc"); + CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test"); + CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithNamespace; + CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace)); + CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace"); + CHECK(V2BucketRequestWithNamespace.Bucket == "test"); + CHECK(!V2BucketRequestWithNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value()); + + HttpRequestData V2HashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest)); + CHECK(!V2HashKeyRequest.Namespace); + CHECK(V2HashKeyRequest.Bucket == "test"); + CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(!V2HashKeyRequest.ValueContentId.has_value()); + + HttpRequestData V2ValueContentIdRequest; + CHECK(HttpRequestParseRelativeUri( + "nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + V2ValueContentIdRequest)); + CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace"); + CHECK(V2ValueContentIdRequest.Bucket == "test"); + CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789")); + + HttpRequestData Invalid; + CHECK(!HttpRequestParseRelativeUri("", Invalid)); + CHECK(!HttpRequestParseRelativeUri("/", Invalid)); + CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid)); + CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid)); + CHECK(!HttpRequestParseRelativeUri( + "namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789", + Invalid)); + } +#endif +} // namespace cacherequests + +void +cacherequests_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenutil/cache/rpcrecording.cpp b/src/zenutil/cache/rpcrecording.cpp new file mode 100644 index 000000000..4958a27f6 --- /dev/null +++ b/src/zenutil/cache/rpcrecording.cpp @@ -0,0 +1,210 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/basicfile.h> +#include <zenutil/cache/rpcrecording.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::cache { +struct RecordedRequest +{ + uint64_t Offset; + uint64_t Length; + ZenContentType ContentType; + ZenContentType AcceptType; +}; + +const uint64_t RecordedRequestBlockSize = 1ull << 31u; + +struct RecordedRequestsWriter +{ + void BeginWrite(const std::filesystem::path& BasePath) + { + m_BasePath = BasePath; + std::filesystem::create_directories(m_BasePath); + } + + void EndWrite() + { + RwLock::ExclusiveLockScope _(m_Lock); + m_BlockFiles.clear(); + + IoBuffer IndexBuffer(IoBuffer::Wrap, m_Entries.data(), m_Entries.size() * sizeof(RecordedRequest)); + BasicFile IndexFile; + IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kTruncate); + std::error_code Ec; + IndexFile.WriteAll(IndexBuffer, Ec); + IndexFile.Close(); + m_Entries.clear(); + } + + uint64_t WriteRequest(ZenContentType ContentType, ZenContentType AcceptType, const IoBuffer& RequestBuffer) + { + RwLock::ExclusiveLockScope Lock(m_Lock); + uint64_t RequestIndex = m_Entries.size(); + RecordedRequest& Entry = m_Entries.emplace_back( + RecordedRequest{.Offset = ~0ull, .Length = RequestBuffer.Size(), .ContentType = ContentType, .AcceptType = AcceptType}); + if (Entry.Length < 1 * 1024 * 1024) + { + uint32_t BlockIndex = gsl::narrow<uint32_t>((m_ChunkOffset + Entry.Length) / RecordedRequestBlockSize); + if (BlockIndex == m_BlockFiles.size()) + { + std::unique_ptr<BasicFile>& NewBlockFile = m_BlockFiles.emplace_back(std::make_unique<BasicFile>()); + NewBlockFile->Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kTruncate); + m_ChunkOffset = BlockIndex * RecordedRequestBlockSize; + } + ZEN_ASSERT(BlockIndex < m_BlockFiles.size()); + BasicFile* BlockFile = m_BlockFiles[BlockIndex].get(); + ZEN_ASSERT(BlockFile != nullptr); + + Entry.Offset = m_ChunkOffset; + m_ChunkOffset = RoundUp(m_ChunkOffset + Entry.Length, 1u << 4u); + Lock.ReleaseNow(); + + std::error_code Ec; + BlockFile->Write(RequestBuffer.Data(), RequestBuffer.Size(), Entry.Offset - BlockIndex * RecordedRequestBlockSize, Ec); + if (Ec) + { + Entry.Length = 0; + return ~0ull; + } + return RequestIndex; + } + Lock.ReleaseNow(); + + BasicFile RequestFile; + RequestFile.Open(m_BasePath / fmt::format("request{}.bin", RequestIndex), BasicFile::Mode::kTruncate); + std::error_code Ec; + RequestFile.WriteAll(RequestBuffer, Ec); + if (Ec) + { + Entry.Length = 0; + return ~0ull; + } + return RequestIndex; + } + + std::filesystem::path m_BasePath; + mutable RwLock m_Lock; + std::vector<RecordedRequest> m_Entries; + std::vector<std::unique_ptr<BasicFile>> m_BlockFiles; + uint64_t m_ChunkOffset; +}; + +struct RecordedRequestsReader +{ + uint64_t BeginRead(const std::filesystem::path& BasePath, bool InMemory) + { + m_BasePath = BasePath; + BasicFile IndexFile; + IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kRead); + m_Entries.resize(IndexFile.FileSize() / sizeof(RecordedRequest)); + IndexFile.Read(m_Entries.data(), IndexFile.FileSize(), 0); + uint64_t MaxChunkPosition = 0; + for (const RecordedRequest& R : m_Entries) + { + if (R.Offset != ~0ull) + { + MaxChunkPosition = Max(MaxChunkPosition, R.Offset + R.Length); + } + } + uint32_t BlockCount = gsl::narrow<uint32_t>(MaxChunkPosition / RecordedRequestBlockSize) + 1; + m_BlockFiles.resize(BlockCount); + for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; ++BlockIndex) + { + if (InMemory) + { + BasicFile Chunk; + Chunk.Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kRead); + m_BlockFiles[BlockIndex] = Chunk.ReadAll(); + continue; + } + m_BlockFiles[BlockIndex] = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("chunks{}.bin", BlockIndex)); + } + return m_Entries.size(); + } + void EndRead() { m_BlockFiles.clear(); } + + std::pair<ZenContentType, ZenContentType> ReadRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) const + { + if (RequestIndex >= m_Entries.size()) + { + return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType}; + } + const RecordedRequest& Entry = m_Entries[RequestIndex]; + if (Entry.Length == 0) + { + return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType}; + } + if (Entry.Offset != ~0ull) + { + uint32_t BlockIndex = gsl::narrow<uint32_t>((Entry.Offset + Entry.Length) / RecordedRequestBlockSize); + uint64_t ChunkOffset = Entry.Offset - (BlockIndex * RecordedRequestBlockSize); + OutBuffer = IoBuffer(m_BlockFiles[BlockIndex], ChunkOffset, Entry.Length); + return {Entry.ContentType, Entry.AcceptType}; + } + OutBuffer = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("request{}.bin", RequestIndex)); + return {Entry.ContentType, Entry.AcceptType}; + } + + std::filesystem::path m_BasePath; + std::vector<RecordedRequest> m_Entries; + std::vector<IoBuffer> m_BlockFiles; +}; + +class DiskRequestRecorder : public IRpcRequestRecorder +{ +public: + DiskRequestRecorder(const std::filesystem::path& BasePath) { m_RecordedRequests.BeginWrite(BasePath); } + virtual ~DiskRequestRecorder() { m_RecordedRequests.EndWrite(); } + +private: + virtual uint64_t RecordRequest(const ZenContentType ContentType, + const ZenContentType AcceptType, + const IoBuffer& RequestBuffer) override + { + return m_RecordedRequests.WriteRequest(ContentType, AcceptType, RequestBuffer); + } + virtual void RecordResponse(uint64_t, const ZenContentType, const IoBuffer&) override {} + virtual void RecordResponse(uint64_t, const ZenContentType, const CompositeBuffer&) override {} + RecordedRequestsWriter m_RecordedRequests; +}; + +class DiskRequestReplayer : public IRpcRequestReplayer +{ +public: + DiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory) + { + m_RequestCount = m_RequestBuffer.BeginRead(BasePath, InMemory); + } + virtual ~DiskRequestReplayer() { m_RequestBuffer.EndRead(); } + +private: + virtual uint64_t GetRequestCount() const override { return m_RequestCount; } + + virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) override + { + return m_RequestBuffer.ReadRequest(RequestIndex, OutBuffer); + } + virtual ZenContentType GetResponse(uint64_t, IoBuffer&) override { return ZenContentType::kUnknownContentType; } + + std::uint64_t m_RequestCount; + RecordedRequestsReader m_RequestBuffer; +}; + +std::unique_ptr<cache::IRpcRequestRecorder> +MakeDiskRequestRecorder(const std::filesystem::path& BasePath) +{ + return std::make_unique<DiskRequestRecorder>(BasePath); +} + +std::unique_ptr<cache::IRpcRequestReplayer> +MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory) +{ + return std::make_unique<DiskRequestReplayer>(BasePath, InMemory); +} + +} // namespace zen::cache diff --git a/src/zenutil/include/zenutil/basicfile.h b/src/zenutil/include/zenutil/basicfile.h new file mode 100644 index 000000000..877df0f92 --- /dev/null +++ b/src/zenutil/include/zenutil/basicfile.h @@ -0,0 +1,113 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> + +#include <filesystem> +#include <functional> + +namespace zen { + +class CbObject; + +/** + * Probably the most basic file abstraction in the universe + * + * One thing of note is that there is no notion of a "current file position" + * in this API -- all reads and writes are done from explicit offsets in + * the file. This avoids concurrency issues which can occur otherwise. + * + */ + +class BasicFile +{ +public: + BasicFile() = default; + ~BasicFile(); + + BasicFile(const BasicFile&) = delete; + BasicFile& operator=(const BasicFile&) = delete; + + enum class Mode : uint32_t + { + kRead = 0, // Opens a existing file for read only + kWrite = 1, // Opens (or creates) a file for read and write + kTruncate = 2, // Opens (or creates) a file for read and write and sets the size to zero + kDelete = 3, // Opens (or creates) a file for read and write allowing .DeleteFile file disposition to be set + kTruncateDelete = + 4 // Opens (or creates) a file for read and write and sets the size to zero allowing .DeleteFile file disposition to be set + }; + + void Open(const std::filesystem::path& FileName, Mode Mode); + void Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec); + void Close(); + void Read(void* Data, uint64_t Size, uint64_t FileOffset); + void StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + void Write(MemoryView Data, uint64_t FileOffset); + void Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec); + void Flush(); + uint64_t FileSize(); + void SetFileSize(uint64_t FileSize); + IoBuffer ReadAll(); + void WriteAll(IoBuffer Data, std::error_code& Ec); + void* Detach(); + + inline void* Handle() { return m_FileHandle; } + +protected: + void* m_FileHandle = nullptr; // This is either null or valid +private: +}; + +/** + * Simple abstraction for a temporary file + * + * Works like a regular BasicFile but implements a simple mechanism to allow creating + * a temporary file for writing in a directory which may later be moved atomically + * into the intended location after it has been fully written to. + * + */ + +class TemporaryFile : public BasicFile +{ +public: + TemporaryFile() = default; + ~TemporaryFile(); + + TemporaryFile(const TemporaryFile&) = delete; + TemporaryFile& operator=(const TemporaryFile&) = delete; + + void Close(); + void CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec); + void MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec); + const std::filesystem::path& GetPath() const { return m_TempPath; } + +private: + std::filesystem::path m_TempPath; + + using BasicFile::Open; +}; + +/** Lock file abstraction + + */ + +class LockFile : protected BasicFile +{ +public: + LockFile(); + ~LockFile(); + + void Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec); + void Update(CbObject Payload, std::error_code& Ec); + +private: +}; + +ZENCORE_API void basicfile_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/cache.h b/src/zenutil/include/zenutil/cache/cache.h new file mode 100644 index 000000000..1a1dd9386 --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cache.h @@ -0,0 +1,6 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenutil/cache/cachekey.h> +#include <zenutil/cache/cachepolicy.h> diff --git a/src/zenutil/include/zenutil/cache/cachekey.h b/src/zenutil/include/zenutil/cache/cachekey.h new file mode 100644 index 000000000..741375946 --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cachekey.h @@ -0,0 +1,86 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <zenutil/cache/cachepolicy.h> + +namespace zen { + +struct CacheKey +{ + std::string Bucket; + IoHash Hash; + + static CacheKey Create(std::string_view Bucket, const IoHash& Hash) { return {.Bucket = ToLower(Bucket), .Hash = Hash}; } + + auto operator<=>(const CacheKey& that) const + { + if (auto b = caseSensitiveCompareStrings(Bucket, that.Bucket); b != std::strong_ordering::equal) + { + return b; + } + return Hash <=> that.Hash; + } + + auto operator==(const CacheKey& that) const { return (*this <=> that) == std::strong_ordering::equal; } + + static const CacheKey Empty; +}; + +struct CacheChunkRequest +{ + CacheKey Key; + IoHash ChunkId; + Oid ValueId; + uint64_t RawOffset = 0ull; + uint64_t RawSize = ~uint64_t(0); + CachePolicy Policy = CachePolicy::Default; +}; + +struct CacheKeyRequest +{ + CacheKey Key; + CacheRecordPolicy Policy; +}; + +struct CacheValueRequest +{ + CacheKey Key; + CachePolicy Policy = CachePolicy::Default; +}; + +inline bool +operator<(const CacheChunkRequest& A, const CacheChunkRequest& B) +{ + if (A.Key < B.Key) + { + return true; + } + if (B.Key < A.Key) + { + return false; + } + if (A.ChunkId < B.ChunkId) + { + return true; + } + if (B.ChunkId < A.ChunkId) + { + return false; + } + if (A.ValueId < B.ValueId) + { + return true; + } + if (B.ValueId < A.ValueId) + { + return false; + } + return A.RawOffset < B.RawOffset; +} + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/cachepolicy.h b/src/zenutil/include/zenutil/cache/cachepolicy.h new file mode 100644 index 000000000..9a745e42c --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cachepolicy.h @@ -0,0 +1,227 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/enumflags.h> +#include <zencore/refcount.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <gsl/gsl-lite.hpp> +#include <span> +namespace zen::Private { +class ICacheRecordPolicyShared; +} +namespace zen { + +class CbObjectView; +class CbWriter; + +class OptionalCacheRecordPolicy; + +enum class CachePolicy : uint32_t +{ + /** A value with no flags. Disables access to the cache unless combined with other flags. */ + None = 0, + + /** Allow a cache request to query local caches. */ + QueryLocal = 1 << 0, + /** Allow a cache request to query remote caches. */ + QueryRemote = 1 << 1, + /** Allow a cache request to query any caches. */ + Query = QueryLocal | QueryRemote, + + /** Allow cache requests to query and store records and values in local caches. */ + StoreLocal = 1 << 2, + /** Allow cache records and values to be stored in remote caches. */ + StoreRemote = 1 << 3, + /** Allow cache records and values to be stored in any caches. */ + Store = StoreLocal | StoreRemote, + + /** Allow cache requests to query and store records and values in local caches. */ + Local = QueryLocal | StoreLocal, + /** Allow cache requests to query and store records and values in remote caches. */ + Remote = QueryRemote | StoreRemote, + + /** Allow cache requests to query and store records and values in any caches. */ + Default = Query | Store, + + /** Skip fetching the data for values. */ + SkipData = 1 << 4, + + /** Skip fetching the metadata for record requests. */ + SkipMeta = 1 << 5, + + /** + * Partial output will be provided with the error status when a required value is missing. + * + * This is meant for cases when the missing values can be individually recovered, or rebuilt, + * without rebuilding the whole record. The cache automatically adds this flag when there are + * other cache stores that it may be able to recover missing values from. + * + * Missing values will be returned in the records, but with only the hash and size. + * + * Applying this flag for a put of a record allows a partial record to be stored. + */ + PartialRecord = 1 << 6, + + /** + * Keep records in the cache for at least the duration of the session. + * + * This is a hint that the record may be accessed again in this session. This is mainly meant + * to be used when subsequent accesses will not tolerate a cache miss. + */ + KeepAlive = 1 << 7, +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy); +/** Append a non-empty text version of the policy to the builder. */ +StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy); +/** Parse non-empty text written by operator<< into a policy. */ +CachePolicy ParseCachePolicy(std::string_view Text); +/** Return input converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */ +CachePolicy ConvertToUpstream(CachePolicy Policy); + +inline CachePolicy +Union(CachePolicy A, CachePolicy B) +{ + constexpr CachePolicy InvertedFlags = CachePolicy::SkipData | CachePolicy::SkipMeta; + return (A & ~(InvertedFlags)) | (B & ~(InvertedFlags)) | (A & B & InvertedFlags); +} + +/** A value ID and the cache policy to use for that value. */ +struct CacheValuePolicy +{ + Oid Id; + CachePolicy Policy = CachePolicy::Default; + + /** Flags that are valid on a value policy. */ + static constexpr CachePolicy PolicyMask = CachePolicy::Default | CachePolicy::SkipData; +}; + +/** Interface for the private implementation of the cache record policy. */ +class Private::ICacheRecordPolicyShared : public RefCounted +{ +public: + virtual ~ICacheRecordPolicyShared() = default; + virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0; + virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0; +}; + +/** + * Flags to control the behavior of cache record requests, with optional overrides by value. + * + * Examples: + * - A base policy of None with value policy overrides of Default will fetch those values if they + * exist in the record, and skip data for any other values. + * - A base policy of Default, with value policy overrides of (Query | SkipData), will skip those + * values, but still check if they exist, and will load any other values. + */ +class CacheRecordPolicy +{ +public: + /** Construct a cache record policy that uses the default policy. */ + CacheRecordPolicy() = default; + + /** Construct a cache record policy with a uniform policy for the record and every value. */ + inline CacheRecordPolicy(CachePolicy BasePolicy) + : RecordPolicy(BasePolicy) + , DefaultValuePolicy(BasePolicy & CacheValuePolicy::PolicyMask) + { + } + + /** Returns true if the record and every value use the same cache policy. */ + inline bool IsUniform() const { return !Shared; } + + /** Returns the cache policy to use for the record. */ + inline CachePolicy GetRecordPolicy() const { return RecordPolicy; } + + /** Returns the base cache policy that this was constructed from. */ + inline CachePolicy GetBasePolicy() const { return DefaultValuePolicy | (RecordPolicy & ~CacheValuePolicy::PolicyMask); } + + /** Returns the cache policy to use for the value. */ + CachePolicy GetValuePolicy(const Oid& Id) const; + + /** Returns the array of cache policy overrides for values, sorted by ID. */ + inline std::span<const CacheValuePolicy> GetValuePolicies() const + { + return Shared ? Shared->GetValuePolicies() : std::span<const CacheValuePolicy>(); + } + + /** Saves the cache record policy to a compact binary object. */ + void Save(CbWriter& Writer) const; + + /** Loads a cache record policy from an object. */ + static OptionalCacheRecordPolicy Load(CbObjectView Object); + + /** Return *this converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. + */ + CacheRecordPolicy ConvertToUpstream() const; + +private: + friend class CacheRecordPolicyBuilder; + friend class OptionalCacheRecordPolicy; + + CachePolicy RecordPolicy = CachePolicy::Default; + CachePolicy DefaultValuePolicy = CachePolicy::Default; + RefPtr<const Private::ICacheRecordPolicyShared> Shared; +}; + +/** A cache record policy builder is used to construct a cache record policy. */ +class CacheRecordPolicyBuilder +{ +public: + /** Construct a policy builder that uses the default policy as its base policy. */ + CacheRecordPolicyBuilder() = default; + + /** Construct a policy builder that uses the provided policy for the record and values with no override. */ + inline explicit CacheRecordPolicyBuilder(CachePolicy Policy) : BasePolicy(Policy) {} + + /** Adds a cache policy override for a value. */ + void AddValuePolicy(const CacheValuePolicy& Value); + inline void AddValuePolicy(const Oid& Id, CachePolicy Policy) { AddValuePolicy({Id, Policy}); } + + /** Build a cache record policy, which makes this builder subsequently unusable. */ + CacheRecordPolicy Build(); + +private: + CachePolicy BasePolicy = CachePolicy::Default; + RefPtr<Private::ICacheRecordPolicyShared> Shared; +}; + +/** + * A cache record policy that can be null. + * + * @see CacheRecordPolicy + */ +class OptionalCacheRecordPolicy : private CacheRecordPolicy +{ +public: + inline OptionalCacheRecordPolicy() : CacheRecordPolicy(~CachePolicy::None) {} + + inline OptionalCacheRecordPolicy(CacheRecordPolicy&& InOutput) : CacheRecordPolicy(std::move(InOutput)) {} + inline OptionalCacheRecordPolicy(const CacheRecordPolicy& InOutput) : CacheRecordPolicy(InOutput) {} + inline OptionalCacheRecordPolicy& operator=(CacheRecordPolicy&& InOutput) + { + CacheRecordPolicy::operator=(std::move(InOutput)); + return *this; + } + inline OptionalCacheRecordPolicy& operator=(const CacheRecordPolicy& InOutput) + { + CacheRecordPolicy::operator=(InOutput); + return *this; + } + + /** Returns the cache record policy. The caller must check for null before using this accessor. */ + inline const CacheRecordPolicy& Get() const& { return *this; } + inline CacheRecordPolicy Get() && { return std::move(*this); } + + inline bool IsNull() const { return RecordPolicy == ~CachePolicy::None; } + inline bool IsValid() const { return !IsNull(); } + inline explicit operator bool() const { return !IsNull(); } + + inline void Reset() { *this = OptionalCacheRecordPolicy(); } +}; + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/cacherequests.h b/src/zenutil/include/zenutil/cache/cacherequests.h new file mode 100644 index 000000000..f1999ebfe --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cacherequests.h @@ -0,0 +1,279 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compress.h> + +#include "cachekey.h" +#include "cachepolicy.h" + +#include <functional> + +namespace zen { + +class CbPackage; +class CbObjectWriter; +class CbObjectView; + +namespace cacherequests { + // I'd really like to get rid of std::optional<CacheRecordPolicy> (or really the class CacheRecordPolicy) + // + // CacheRecordPolicy has a record level policy but it can also contain policies for individual + // values inside the record. + // + // However, when we do a "PutCacheRecords" we already list the individual Values with their Id + // so we can just as well use an optional plain CachePolicy for each value. + // + // In "GetCacheRecords" we do not currently as for the individual values but you can add + // a policy on a per-value level in the std::optional<CacheRecordPolicy> Policy for each record. + // + // But as we already need to know the Ids of the values we want to set the policy for + // it would be simpler to add an array of requested values which each has an optional policy. + // + // We could add: + // struct GetCacheRecordValueRequest + // { + // Oid Id; + // std::optional<CachePolicy> Policy; + // }; + // + // and change GetCacheRecordRequest to + // struct GetCacheRecordRequest + // { + // CacheKey Key = CacheKey::Empty; + // std::vector<GetCacheRecordValueRequest> ValueRequests; + // std::optional<CachePolicy> Policy; + // }; + // + // This way we don't need the complex CacheRecordPolicy class and the request becomes + // more uniform and easier to understand. + // + // Would need to decide what the ValueRequests actually mean: + // Do they dictate which values to fetch or just a change of the policy? + // If they dictate the values to fetch you need to know all the value ids to set them + // and that is unlikely what we want - we want to be able to get a cache record with + // all its values without knowing all the Ids, right? + // + + ////////////////////////////////////////////////////////////////////////// + // Put 1..n structured cache records with optional attachments + + struct PutCacheRecordRequestValue + { + Oid Id = Oid::Zero; + IoHash RawHash = IoHash::Zero; // If Body is not set, this must be set and the value must already exist in cache + CompressedBuffer Body = CompressedBuffer::Null; + }; + + struct PutCacheRecordRequest + { + CacheKey Key = CacheKey::Empty; + std::vector<PutCacheRecordRequestValue> Values; + std::optional<CacheRecordPolicy> Policy; + }; + + struct PutCacheRecordsRequest + { + uint32_t AcceptMagic = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<PutCacheRecordRequest> Requests; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + struct PutCacheRecordsResult + { + std::vector<bool> Success; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + ////////////////////////////////////////////////////////////////////////// + // Get 1..n structured cache records with optional attachments + // We can get requests for a cache record where we want care about a particular + // value id which we now of, but we don't know the ids of the other values and + // we still want them. + // Not sure if in that case we want different policies for the different attachemnts? + + struct GetCacheRecordRequest + { + CacheKey Key = CacheKey::Empty; + std::optional<CacheRecordPolicy> Policy; + }; + + struct GetCacheRecordsRequest + { + uint32_t AcceptMagic = 0; + uint16_t AcceptOptions = 0; + int32_t ProcessPid = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<GetCacheRecordRequest> Requests; + + bool Parse(const CbPackage& RpcRequest); + bool Parse(const CbObjectView& RpcRequest); + bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter = {}) const; + bool Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter = {}) const; + }; + + struct GetCacheRecordResultValue + { + Oid Id = Oid::Zero; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + CompressedBuffer Body = CompressedBuffer::Null; + }; + + struct GetCacheRecordResult + { + CacheKey Key = CacheKey::Empty; + std::vector<GetCacheRecordResultValue> Values; + }; + + struct GetCacheRecordsResult + { + std::vector<std::optional<GetCacheRecordResult>> Results; + + bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes = {}); + bool Format(CbPackage& OutPackage) const; + }; + + ////////////////////////////////////////////////////////////////////////// + // Put 1..n unstructured cache objects + + struct PutCacheValueRequest + { + CacheKey Key = CacheKey::Empty; + IoHash RawHash = IoHash::Zero; + CompressedBuffer Body = CompressedBuffer::Null; // If not set the value is expected to already exist in cache store + std::optional<CachePolicy> Policy; + }; + + struct PutCacheValuesRequest + { + uint32_t AcceptMagic = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<PutCacheValueRequest> Requests; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + struct PutCacheValuesResult + { + std::vector<bool> Success; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + ////////////////////////////////////////////////////////////////////////// + // Get 1..n unstructured cache objects (stored data may be structured or unstructured) + + struct GetCacheValueRequest + { + CacheKey Key = CacheKey::Empty; + std::optional<CachePolicy> Policy; + }; + + struct GetCacheValuesRequest + { + uint32_t AcceptMagic = 0; + uint16_t AcceptOptions = 0; + int32_t ProcessPid = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<GetCacheValueRequest> Requests; + + bool Parse(const CbObjectView& BatchObject); + bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter = {}) const; + }; + + struct CacheValueResult + { + uint64_t RawSize = 0; + IoHash RawHash = IoHash::Zero; + CompressedBuffer Body = CompressedBuffer::Null; + }; + + struct CacheValuesResult + { + std::vector<CacheValueResult> Results; + + bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes = {}); + bool Format(CbPackage& OutPackage) const; + }; + + typedef CacheValuesResult GetCacheValuesResult; + + ////////////////////////////////////////////////////////////////////////// + // Get 1..n cache record values (attachments) for 1..n records + + struct GetCacheChunkRequest + { + CacheKey Key; + Oid ValueId = Oid::Zero; // Set if ChunkId is not known at request time + IoHash ChunkId = IoHash::Zero; + uint64_t RawOffset = 0ull; + uint64_t RawSize = ~uint64_t(0); + std::optional<CachePolicy> Policy; + }; + + struct GetCacheChunksRequest + { + uint32_t AcceptMagic = 0; + uint16_t AcceptOptions = 0; + int32_t ProcessPid = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<GetCacheChunkRequest> Requests; + + bool Parse(const CbObjectView& BatchObject); + bool Format(CbPackage& OutPackage) const; + }; + + typedef CacheValuesResult GetCacheChunksResult; + + ////////////////////////////////////////////////////////////////////////// + + struct HttpRequestData + { + std::optional<std::string> Namespace; + std::optional<std::string> Bucket; + std::optional<IoHash> HashKey; + std::optional<IoHash> ValueContentId; + }; + + bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data); + + // Temporarily public + std::optional<std::string> GetRequestNamespace(const CbObjectView& Params); + bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key); + + ////////////////////////////////////////////////////////////////////////// + + // struct CacheRecordValue + // { + // Oid Id = Oid::Zero; + // IoHash RawHash = IoHash::Zero; + // uint64_t RawSize = 0; + // }; + // + // struct CacheRecord + // { + // CacheKey Key = CacheKey::Empty; + // std::vector<CacheRecordValue> Values; + // + // bool Parse(CbObjectView& Reader); + // bool Format(CbObjectWriter& Writer) const; + // }; + +} // namespace cacherequests + +void cacherequests_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/rpcrecording.h b/src/zenutil/include/zenutil/cache/rpcrecording.h new file mode 100644 index 000000000..6d65a532a --- /dev/null +++ b/src/zenutil/include/zenutil/cache/rpcrecording.h @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compositebuffer.h> +#include <zencore/iobuffer.h> + +namespace zen::cache { +class IRpcRequestRecorder +{ +public: + virtual ~IRpcRequestRecorder() {} + virtual uint64_t RecordRequest(const ZenContentType ContentType, const ZenContentType AcceptType, const IoBuffer& RequestBuffer) = 0; + virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const IoBuffer& ResponseBuffer) = 0; + virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const CompositeBuffer& ResponseBuffer) = 0; +}; +class IRpcRequestReplayer +{ +public: + virtual ~IRpcRequestReplayer() {} + virtual uint64_t GetRequestCount() const = 0; + virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0; + virtual ZenContentType GetResponse(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0; +}; + +std::unique_ptr<cache::IRpcRequestRecorder> MakeDiskRequestRecorder(const std::filesystem::path& BasePath); +std::unique_ptr<cache::IRpcRequestReplayer> MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory); + +} // namespace zen::cache diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h new file mode 100644 index 000000000..1c204c144 --- /dev/null +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -0,0 +1,141 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/enumflags.h> +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zencore/uid.h> + +#include <atomic> +#include <filesystem> +#include <optional> + +namespace zen { + +class ZenServerEnvironment +{ +public: + ZenServerEnvironment(); + ~ZenServerEnvironment(); + + void Initialize(std::filesystem::path ProgramBaseDir); + void InitializeForTest(std::filesystem::path ProgramBaseDir, std::filesystem::path TestBaseDir, std::string_view ServerClass = ""); + + std::filesystem::path CreateNewTestDir(); + std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; } + std::filesystem::path GetTestRootDir(std::string_view Path); + inline bool IsInitialized() const { return m_IsInitialized; } + inline bool IsTestEnvironment() const { return m_IsTestInstance; } + inline std::string_view GetServerClass() const { return m_ServerClass; } + +private: + std::filesystem::path m_ProgramBaseDir; + std::filesystem::path m_TestBaseDir; + bool m_IsInitialized = false; + bool m_IsTestInstance = false; + std::string m_ServerClass; +}; + +struct ZenServerInstance +{ + ZenServerInstance(ZenServerEnvironment& TestEnvironment); + ~ZenServerInstance(); + + void Shutdown(); + void SignalShutdown(); + void WaitUntilReady(); + [[nodiscard]] bool WaitUntilReady(int Timeout); + void EnableTermination() { m_Terminate = true; } + void Detach(); + inline int GetPid() { return m_Process.Pid(); } + inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; } + + void SetTestDir(std::filesystem::path TestDir) + { + ZEN_ASSERT(!m_Process.IsValid()); + m_TestDir = TestDir; + } + + void SpawnServer(int BasePort = 0, std::string_view AdditionalServerArgs = std::string_view()); + + void AttachToRunningServer(int BasePort = 0); + + std::string GetBaseUri() const; + +private: + ZenServerEnvironment& m_Env; + ProcessHandle m_Process; + NamedEvent m_ReadyEvent; + NamedEvent m_ShutdownEvent; + bool m_Terminate = false; + std::filesystem::path m_TestDir; + int m_BasePort = 0; + std::optional<int> m_OwnerPid; + + void CreateShutdownEvent(int BasePort); +}; + +/** Shared system state + * + * Used as a scratchpad to identify running instances etc + * + * The state lives in a memory-mapped file backed by the swapfile + * + */ + +class ZenServerState +{ +public: + ZenServerState(); + ~ZenServerState(); + + struct ZenServerEntry + { + // NOTE: any changes to this should consider backwards compatibility + // which means you should not rearrange members only potentially + // add something to the end or use a different mechanism for + // additional state. For example, you can use the session ID + // to introduce additional named objects + std::atomic<uint32_t> Pid; + std::atomic<uint16_t> DesiredListenPort; + std::atomic<uint16_t> Flags; + uint8_t SessionId[12]; + std::atomic<uint32_t> SponsorPids[8]; + std::atomic<uint16_t> EffectiveListenPort; + uint8_t Padding[10]; + + enum class FlagsEnum : uint16_t + { + kShutdownPlease = 1 << 0, + kIsReady = 1 << 1, + }; + + FRIEND_ENUM_CLASS_FLAGS(FlagsEnum); + + Oid GetSessionId() const { return Oid::FromMemory(SessionId); } + void Reset(); + void SignalShutdownRequest(); + void SignalReady(); + bool AddSponsorProcess(uint32_t Pid); + }; + + static_assert(sizeof(ZenServerEntry) == 64); + + void Initialize(); + [[nodiscard]] bool InitializeReadOnly(); + [[nodiscard]] ZenServerEntry* Lookup(int DesiredListenPort); + ZenServerEntry* Register(int DesiredListenPort); + void Sweep(); + void Snapshot(std::function<void(const ZenServerEntry&)>&& Callback); + inline bool IsReadOnly() const { return m_IsReadOnly; } + +private: + void* m_hMapFile = nullptr; + ZenServerEntry* m_Data = nullptr; + int m_MaxEntryCount = 65536 / sizeof(ZenServerEntry); + ZenServerEntry* m_OurEntry = nullptr; + bool m_IsReadOnly = true; +}; + +} // namespace zen diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua new file mode 100644 index 000000000..e7d849bb2 --- /dev/null +++ b/src/zenutil/xmake.lua @@ -0,0 +1,9 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenutil') + set_kind("static") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore") + add_packages("vcpkg::spdlog") diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp new file mode 100644 index 000000000..5ecde343b --- /dev/null +++ b/src/zenutil/zenserverprocess.cpp @@ -0,0 +1,677 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenutil/zenserverprocess.h" + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/string.h> +#include <zencore/thread.h> + +#include <atomic> + +#if ZEN_PLATFORM_WINDOWS +# include <atlbase.h> +# include <zencore/windows.h> +#else +# include <sys/mman.h> +#endif + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +namespace zenutil { +#if ZEN_PLATFORM_WINDOWS + class SecurityAttributes + { + public: + inline SECURITY_ATTRIBUTES* Attributes() { return &m_Attributes; } + + protected: + SECURITY_ATTRIBUTES m_Attributes{}; + SECURITY_DESCRIPTOR m_Sd{}; + }; + + // Security attributes which allows any user access + + class AnyUserSecurityAttributes : public SecurityAttributes + { + public: + AnyUserSecurityAttributes() + { + m_Attributes.nLength = sizeof m_Attributes; + m_Attributes.bInheritHandle = false; // Disable inheritance + + const BOOL Success = InitializeSecurityDescriptor(&m_Sd, SECURITY_DESCRIPTOR_REVISION); + + if (Success) + { + if (!SetSecurityDescriptorDacl(&m_Sd, TRUE, (PACL)NULL, FALSE)) + { + ThrowLastError("SetSecurityDescriptorDacl failed"); + } + + m_Attributes.lpSecurityDescriptor = &m_Sd; + } + } + }; +#endif // ZEN_PLATFORM_WINDOWS + +} // namespace zenutil + +////////////////////////////////////////////////////////////////////////// + +ZenServerState::ZenServerState() +{ +} + +ZenServerState::~ZenServerState() +{ + if (m_OurEntry) + { + // Clean up our entry now that we're leaving + + m_OurEntry->Reset(); + m_OurEntry = nullptr; + } + +#if ZEN_PLATFORM_WINDOWS + if (m_Data) + { + UnmapViewOfFile(m_Data); + } + + if (m_hMapFile) + { + CloseHandle(m_hMapFile); + } +#else + if (m_Data != nullptr) + { + munmap(m_Data, m_MaxEntryCount * sizeof(ZenServerEntry)); + } + + int Fd = int(intptr_t(m_hMapFile)); + close(Fd); +#endif + + m_Data = nullptr; +} + +void +ZenServerState::Initialize() +{ + size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry); + +#if ZEN_PLATFORM_WINDOWS + // TODO: there's a small chance of a race here, this logic could be tightened up with a mutex to + // ensure only a single process at a time creates the mapping + // TODO: the fallback to Local instead of Global has a flaw where if you start a non-elevated instance + // first then start an elevated instance second you'll have the first instance with a local + // mapping and the second instance with a global mapping. This kind of elevated/non-elevated + // shouldn't be common, but handling for it should be improved in the future. + + HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap"); + if (hMap == NULL) + { + hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap"); + } + + if (hMap == NULL) + { + // Security attributes to enable any user to access state + zenutil::AnyUserSecurityAttributes Attrs; + + hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file + Attrs.Attributes(), // allow anyone to access + PAGE_READWRITE, // read/write access + 0, // maximum object size (high-order DWORD) + DWORD(MapSize), // maximum object size (low-order DWORD) + L"Global\\ZenMap"); // name of mapping object + + if (hMap == NULL) + { + hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file + Attrs.Attributes(), // allow anyone to access + PAGE_READWRITE, // read/write access + 0, // maximum object size (high-order DWORD) + m_MaxEntryCount * sizeof(ZenServerEntry), // maximum object size (low-order DWORD) + L"Local\\ZenMap"); // name of mapping object + } + + if (hMap == NULL) + { + ThrowLastError("Could not open or create file mapping object for Zen server state"); + } + } + + void* pBuf = MapViewOfFile(hMap, // handle to map object + FILE_MAP_ALL_ACCESS, // read/write permission + 0, // offset high + 0, // offset low + DWORD(MapSize)); + + if (pBuf == NULL) + { + ThrowLastError("Could not map view of Zen server state"); + } +#else + int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + ThrowLastError("Could not open a shared memory object"); + } + fchmod(Fd, 0666); + void* hMap = (void*)intptr_t(Fd); + + int Result = ftruncate(Fd, MapSize); + ZEN_UNUSED(Result); + + void* pBuf = mmap(nullptr, MapSize, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + ThrowLastError("Could not map view of Zen server state"); + } +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<ZenServerEntry*>(pBuf); + m_IsReadOnly = false; +} + +bool +ZenServerState::InitializeReadOnly() +{ + size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry); + +#if ZEN_PLATFORM_WINDOWS + HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap"); + if (hMap == NULL) + { + hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap"); + } + + if (hMap == NULL) + { + return false; + } + + void* pBuf = MapViewOfFile(hMap, // handle to map object + FILE_MAP_READ, // read permission + 0, // offset high + 0, // offset low + MapSize); + + if (pBuf == NULL) + { + ThrowLastError("Could not map view of Zen server state"); + } +#else + int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666); + if (Fd < 0) + { + return false; + } + void* hMap = (void*)intptr_t(Fd); + + void* pBuf = mmap(nullptr, MapSize, PROT_READ, MAP_PRIVATE, Fd, 0); + if (pBuf == MAP_FAILED) + { + ThrowLastError("Could not map read-only view of Zen server state"); + } +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<ZenServerEntry*>(pBuf); + + return true; +} + +ZenServerState::ZenServerEntry* +ZenServerState::Lookup(int DesiredListenPort) +{ + for (int i = 0; i < m_MaxEntryCount; ++i) + { + if (m_Data[i].DesiredListenPort == DesiredListenPort) + { + return &m_Data[i]; + } + } + + return nullptr; +} + +ZenServerState::ZenServerEntry* +ZenServerState::Register(int DesiredListenPort) +{ + if (m_Data == nullptr) + { + return nullptr; + } + + // Allocate an entry + + int Pid = GetCurrentProcessId(); + + for (int i = 0; i < m_MaxEntryCount; ++i) + { + ZenServerEntry& Entry = m_Data[i]; + + if (Entry.DesiredListenPort.load(std::memory_order_relaxed) == 0) + { + uint16_t Expected = 0; + if (Entry.DesiredListenPort.compare_exchange_strong(Expected, uint16_t(DesiredListenPort))) + { + // Successfully allocated entry + + m_OurEntry = &Entry; + + Entry.Pid = Pid; + Entry.EffectiveListenPort = 0; + Entry.Flags = 0; + + const Oid SesId = GetSessionId(); + memcpy(Entry.SessionId, &SesId, sizeof SesId); + + return &Entry; + } + } + } + + return nullptr; +} + +void +ZenServerState::Sweep() +{ + if (m_Data == nullptr) + { + return; + } + + ZEN_ASSERT(m_IsReadOnly == false); + + for (int i = 0; i < m_MaxEntryCount; ++i) + { + ZenServerEntry& Entry = m_Data[i]; + + if (Entry.DesiredListenPort) + { + if (IsProcessRunning(Entry.Pid) == false) + { + ZEN_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", Entry.Pid, Entry.DesiredListenPort); + + Entry.Reset(); + } + } + } +} + +void +ZenServerState::Snapshot(std::function<void(const ZenServerEntry&)>&& Callback) +{ + if (m_Data == nullptr) + { + return; + } + + for (int i = 0; i < m_MaxEntryCount; ++i) + { + ZenServerEntry& Entry = m_Data[i]; + + if (Entry.DesiredListenPort) + { + Callback(Entry); + } + } +} + +void +ZenServerState::ZenServerEntry::Reset() +{ + Pid = 0; + DesiredListenPort = 0; + Flags = 0; + EffectiveListenPort = 0; +} + +void +ZenServerState::ZenServerEntry::SignalShutdownRequest() +{ + Flags |= uint16_t(FlagsEnum::kShutdownPlease); +} + +void +ZenServerState::ZenServerEntry::SignalReady() +{ + Flags |= uint16_t(FlagsEnum::kIsReady); +} + +bool +ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd) +{ + for (std::atomic<uint32_t>& PidEntry : SponsorPids) + { + if (PidEntry.load(std::memory_order_relaxed) == 0) + { + uint32_t Expected = 0; + if (PidEntry.compare_exchange_strong(Expected, PidToAdd)) + { + // Success! + return true; + } + } + else if (PidEntry.load(std::memory_order_relaxed) == PidToAdd) + { + // Success, the because pid is already in the list + return true; + } + } + + return false; +} + +////////////////////////////////////////////////////////////////////////// + +std::atomic<int> ZenServerTestCounter{0}; + +ZenServerEnvironment::ZenServerEnvironment() +{ +} + +ZenServerEnvironment::~ZenServerEnvironment() +{ +} + +void +ZenServerEnvironment::Initialize(std::filesystem::path ProgramBaseDir) +{ + m_ProgramBaseDir = ProgramBaseDir; + + ZEN_DEBUG("Program base dir is '{}'", ProgramBaseDir); + + m_IsInitialized = true; +} + +void +ZenServerEnvironment::InitializeForTest(std::filesystem::path ProgramBaseDir, + std::filesystem::path TestBaseDir, + std::string_view ServerClass) +{ + using namespace std::literals; + + m_ProgramBaseDir = ProgramBaseDir; + m_TestBaseDir = TestBaseDir; + + ZEN_INFO("Program base dir is '{}'", ProgramBaseDir); + ZEN_INFO("Cleaning test base dir '{}'", TestBaseDir); + DeleteDirectories(TestBaseDir.c_str()); + + m_IsTestInstance = true; + m_IsInitialized = true; + + if (ServerClass.empty()) + { +#if ZEN_WITH_HTTPSYS + m_ServerClass = "httpsys"sv; +#else + m_ServerClass = "asio"sv; +#endif + } + else + { + m_ServerClass = ServerClass; + } +} + +std::filesystem::path +ZenServerEnvironment::CreateNewTestDir() +{ + using namespace std::literals; + + ExtendableWideStringBuilder<256> TestDir; + TestDir << "test"sv << int64_t(++ZenServerTestCounter); + + std::filesystem::path TestPath = m_TestBaseDir / TestDir.c_str(); + + ZEN_INFO("Creating new test dir @ '{}'", TestPath); + + CreateDirectories(TestPath.c_str()); + + return TestPath; +} + +std::filesystem::path +ZenServerEnvironment::GetTestRootDir(std::string_view Path) +{ + std::filesystem::path Root = m_ProgramBaseDir.parent_path().parent_path(); + + std::filesystem::path Relative{Path}; + + return Root / Relative; +} + +////////////////////////////////////////////////////////////////////////// + +std::atomic<int> ChildIdCounter{0}; + +ZenServerInstance::ZenServerInstance(ZenServerEnvironment& TestEnvironment) : m_Env(TestEnvironment) +{ + ZEN_ASSERT(TestEnvironment.IsInitialized()); +} + +ZenServerInstance::~ZenServerInstance() +{ + Shutdown(); +} + +void +ZenServerInstance::SignalShutdown() +{ + m_ShutdownEvent.Set(); +} + +void +ZenServerInstance::Shutdown() +{ + if (m_Process.IsValid()) + { + if (m_Terminate) + { + ZEN_INFO("Terminating zenserver process"); + m_Process.Terminate(111); + m_Process.Reset(); + } + else + { + SignalShutdown(); + m_Process.Wait(); + m_Process.Reset(); + } + } +} + +void +ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerArgs) +{ + ZEN_ASSERT(!m_Process.IsValid()); // Only spawn once + + const int MyPid = zen::GetCurrentProcessId(); + const int ChildId = ++ChildIdCounter; + + ExtendableStringBuilder<32> ChildEventName; + ChildEventName << "Zen_Child_" << ChildId; + NamedEvent ChildEvent{ChildEventName}; + + CreateShutdownEvent(BasePort); + + ExtendableStringBuilder<32> LogId; + LogId << "Zen" << ChildId; + + ExtendableStringBuilder<512> CommandLine; + CommandLine << "zenserver" ZEN_EXE_SUFFIX_LITERAL; // see CreateProc() call for actual binary path + + const bool IsTest = m_Env.IsTestEnvironment(); + + if (IsTest) + { + if (!m_OwnerPid.has_value()) + { + m_OwnerPid = MyPid; + } + + CommandLine << " --test --log-id " << LogId; + } + + if (m_OwnerPid.has_value()) + { + CommandLine << " --owner-pid " << m_OwnerPid.value(); + } + + CommandLine << " --child-id " << ChildEventName; + + if (std::string_view ServerClass = m_Env.GetServerClass(); ServerClass.empty() == false) + { + CommandLine << " --http " << ServerClass; + } + + if (BasePort) + { + CommandLine << " --port " << BasePort; + m_BasePort = BasePort; + } + + if (!m_TestDir.empty()) + { + CommandLine << " --data-dir "; + PathToUtf8(m_TestDir.c_str(), CommandLine); + } + + if (!AdditionalServerArgs.empty()) + { + CommandLine << " " << AdditionalServerArgs; + } + + std::filesystem::path CurrentDirectory = std::filesystem::current_path(); + + ZEN_DEBUG("Spawning server '{}'", LogId); + + uint32_t CreationFlags = 0; + if (!IsTest) + { + CreationFlags |= CreateProcOptions::Flag_NewConsole; + } + + const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); + const std::filesystem::path Executable = BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL; + CreateProcOptions CreateOptions = { + .WorkingDirectory = &CurrentDirectory, + .Flags = CreationFlags, + }; + CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); +#if ZEN_PLATFORM_WINDOWS + if (!ChildPid && ::GetLastError() == ERROR_ELEVATION_REQUIRED) + { + ZEN_DEBUG("Regular spawn failed - spawning elevated server"); + CreateOptions.Flags |= CreateProcOptions::Flag_Elevated; + ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); + } +#endif + + if (!ChildPid) + { + ThrowLastError("Server spawn failed"); + } + + ZEN_DEBUG("Server '{}' spawned OK", LogId); + + if (IsTest) + { + m_Process.Initialize(ChildPid); + } + + m_ReadyEvent = std::move(ChildEvent); +} + +void +ZenServerInstance::CreateShutdownEvent(int BasePort) +{ + ExtendableStringBuilder<32> ChildShutdownEventName; + ChildShutdownEventName << "Zen_" << BasePort; + ChildShutdownEventName << "_Shutdown"; + NamedEvent ChildShutdownEvent{ChildShutdownEventName}; + m_ShutdownEvent = std::move(ChildShutdownEvent); +} + +void +ZenServerInstance::AttachToRunningServer(int BasePort) +{ + ZenServerState State; + if (!State.InitializeReadOnly()) + { + // TODO: return success/error code instead? + throw std::runtime_error("No zen state found"); + } + + const ZenServerState::ZenServerEntry* Entry = nullptr; + + if (BasePort) + { + Entry = State.Lookup(BasePort); + } + else + { + State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { Entry = &InEntry; }); + } + + if (!Entry) + { + // TODO: return success/error code instead? + throw std::runtime_error("No server found"); + } + + m_Process.Initialize(Entry->Pid); + CreateShutdownEvent(Entry->EffectiveListenPort); +} + +void +ZenServerInstance::Detach() +{ + if (m_Process.IsValid()) + { + m_Process.Reset(); + m_ShutdownEvent.Close(); + } +} + +void +ZenServerInstance::WaitUntilReady() +{ + while (m_ReadyEvent.Wait(100) == false) + { + if (!m_Process.IsRunning() || !m_Process.IsValid()) + { + ZEN_INFO("Wait abandoned by invalid process (running={})", m_Process.IsRunning()); + return; + } + } +} + +bool +ZenServerInstance::WaitUntilReady(int Timeout) +{ + return m_ReadyEvent.Wait(Timeout); +} + +std::string +ZenServerInstance::GetBaseUri() const +{ + ZEN_ASSERT(m_BasePort); + + return fmt::format("http://localhost:{}", m_BasePort); +} + +} // namespace zen |